You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
217 lines
8.1 KiB
Python
217 lines
8.1 KiB
Python
|
|
# Copyright (C) 2019-2020 Intel Corporation
|
|
#
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
import codecs
|
|
from collections import OrderedDict
|
|
import hashlib
|
|
import logging as log
|
|
import os
|
|
import os.path as osp
|
|
import string
|
|
|
|
from datumaro.components.extractor import (AnnotationType, DEFAULT_SUBSET_NAME,
|
|
LabelCategories
|
|
)
|
|
from datumaro.components.converter import Converter
|
|
from datumaro.util.image import encode_image
|
|
from datumaro.util.annotation_util import (max_bbox,
|
|
find_group_leader, find_instances)
|
|
from datumaro.util.mask_tools import merge_masks
|
|
from datumaro.util.tf_util import import_tf as _import_tf
|
|
|
|
from .format import DetectionApiPath
|
|
# TensorFlow is obtained through datumaro's `import_tf` helper (imported above
# as `_import_tf`) rather than a plain `import tensorflow`; the module-level
# `tf` name is what the feature helpers and the converter class below use.
tf = _import_tf()
|
|
|
|
|
|
# filter out non-ASCII characters, otherwise training will crash
|
|
_printable = set(string.printable)
|
|
def _make_printable(s):
|
|
return ''.join(filter(lambda x: x in _printable, s))
|
|
|
|
def int64_feature(value):
    """Build a ``tf.train.Feature`` holding the single integer *value*."""
    packed = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=packed)
|
|
|
|
def int64_list_feature(value):
    """Build a ``tf.train.Feature`` holding the integer sequence *value*."""
    packed = tf.train.Int64List(value=value)
    return tf.train.Feature(int64_list=packed)
|
|
|
|
def bytes_feature(value):
    """Build a ``tf.train.Feature`` holding the single bytes object *value*."""
    packed = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=packed)
|
|
|
|
def bytes_list_feature(value):
    """Build a ``tf.train.Feature`` holding the sequence of bytes *value*."""
    packed = tf.train.BytesList(value=value)
    return tf.train.Feature(bytes_list=packed)
|
|
|
|
def float_list_feature(value):
    """Build a ``tf.train.Feature`` holding the float sequence *value*."""
    packed = tf.train.FloatList(value=value)
    return tf.train.Feature(float_list=packed)
|
|
|
|
class TfDetectionApiConverter(Converter):
    """Exports a dataset in the TensorFlow Object Detection API format.

    Output (relative to the save directory):
      - ``DetectionApiPath.LABELMAP_FILE``: a label map with one
        ``item { id: ..., name: '...' }`` entry per label category;
      - ``<subset>.tfrecord``: one serialized ``tf.train.Example`` per dataset
        item, for every subset (``DEFAULT_SUBSET_NAME`` when the dataset has
        no subsets).
    """

    DEFAULT_IMAGE_EXT = DetectionApiPath.DEFAULT_IMAGE_EXT

    @classmethod
    def build_cmdline_parser(cls, **kwargs):
        """Returns the base converter parser extended with ``--save-masks``."""
        parser = super().build_cmdline_parser(**kwargs)
        parser.add_argument('--save-masks', action='store_true',
            help="Include instance masks (default: %(default)s)")
        return parser

    def __init__(self, extractor, save_dir, save_masks=False, **kwargs):
        """
        Args:
            extractor: the dataset to export.
            save_dir: the output directory.
            save_masks: when true, a PNG-encoded instance mask is stored in
                each record under ``image/object/mask``.
            **kwargs: forwarded to the base ``Converter``.
        """
        super().__init__(extractor, save_dir, **kwargs)

        self._save_masks = save_masks

    def apply(self):
        """Writes the label map once, then one TFRecord file per subset."""
        os.makedirs(self._save_dir, exist_ok=True)

        label_categories = self._extractor.categories().get(
            AnnotationType.label, LabelCategories())

        def get_label(label_id):
            # Annotations without a label get an empty class name
            if label_id is None:
                return ''
            return label_categories.items[label_id].name

        # Label map ids are 1-based, in category order. Labels missing from the
        # map are exported with id 0 (presumably because the TF OD API keeps
        # id 0 for the background class - confirm).
        label_ids = OrderedDict((label.name, 1 + idx)
            for idx, label in enumerate(label_categories.items))

        def map_label_id(label_id):
            return label_ids.get(get_label(label_id), 0)

        self._get_label = get_label
        self._get_label_id = map_label_id

        # The label map does not depend on the subset, so it is written once
        # rather than being rewritten on every subset iteration.
        labelmap_path = osp.join(self._save_dir, DetectionApiPath.LABELMAP_FILE)
        with codecs.open(labelmap_path, 'w', encoding='utf8') as f:
            for label, idx in label_ids.items():
                f.write(
                    'item {\n' +
                    ('\tid: %s\n' % (idx)) +
                    ("\tname: '%s'\n" % (label)) +
                    '}\n\n'
                )

        subsets = self._extractor.subsets()
        if len(subsets) == 0:
            # A dataset without subsets is exported as one default subset
            subsets = [None]

        for subset_name in subsets:
            if subset_name:
                subset = self._extractor.get_subset(subset_name)
            else:
                subset_name = DEFAULT_SUBSET_NAME
                subset = self._extractor

            anno_path = osp.join(self._save_dir, '%s.tfrecord' % (subset_name))
            with tf.io.TFRecordWriter(anno_path) as writer:
                for item in subset:
                    tf_example = self._make_tf_example(item)
                    writer.write(tf_example.SerializeToString())

    @staticmethod
    def _find_instances(annotations):
        """Groups the bbox and mask annotations of an item into instances.

        Annotations of other types are ignored; the grouping itself is done
        by ``find_instances``.
        """
        return find_instances(a for a in annotations
            if a.type in { AnnotationType.bbox, AnnotationType.mask })

    def _find_instance_parts(self, group, img_width, img_height):
        """Reduces one annotation group to ``[leader, mask, bbox]``.

        ``leader`` is picked by ``find_group_leader`` and supplies the label;
        ``bbox`` is ``max_bbox`` over all boxes and masks of the group and is
        consumed as ``(x, y, w, h)`` by ``_export_instances``; ``mask`` is the
        ``merge_masks`` result of the group's masks, or None when masks are
        not being saved.

        ``img_width`` and ``img_height`` are currently unused.
        """
        boxes = [a for a in group if a.type == AnnotationType.bbox]
        masks = [a for a in group if a.type == AnnotationType.mask]

        anns = boxes + masks
        leader = find_group_leader(anns)
        bbox = max_bbox(anns)

        mask = None
        if self._save_masks:
            mask = merge_masks([m.image for m in masks])

        return [leader, mask, bbox]

    def _export_instances(self, instances, width, height):
        """Builds the ``image/object/*`` features for a list of instances.

        Args:
            instances: ``[leader, mask, box]`` triples as returned by
                ``_find_instance_parts``; ``box`` is ``(x, y, w, h)`` in pixels.
            width, height: image size used to normalize the box coordinates.

        Returns:
            A dict of feature name -> ``tf.train.Feature``; empty when there
            are no instances.
        """
        xmins = [] # normalized left x coordinates of bounding boxes (1 per box)
        xmaxs = [] # normalized right x coordinates of bounding boxes (1 per box)
        ymins = [] # normalized top y coordinates of bounding boxes (1 per box)
        ymaxs = [] # normalized bottom y coordinates of bounding boxes (1 per box)
        classes_text = [] # class names of bounding boxes (1 per box)
        classes = [] # class ids of bounding boxes (1 per box)
        masks = [] # PNG-encoded instance masks (1 per box, if saved)

        for leader, mask, box in instances:
            label = _make_printable(self._get_label(leader.label))
            classes_text.append(label.encode('utf-8'))
            classes.append(self._get_label_id(leader.label))

            xmins.append(box[0] / width)
            xmaxs.append((box[0] + box[2]) / width)
            ymins.append(box[1] / height)
            ymaxs.append((box[1] + box[3]) / height)

            if self._save_masks:
                # Keep one entry per box, even when there is no mask
                if mask is not None:
                    mask = encode_image(mask, '.png')
                else:
                    mask = b''
                masks.append(mask)

        if not classes:
            return {}

        result = {
            'image/object/bbox/xmin': float_list_feature(xmins),
            'image/object/bbox/xmax': float_list_feature(xmaxs),
            'image/object/bbox/ymin': float_list_feature(ymins),
            'image/object/bbox/ymax': float_list_feature(ymaxs),
            'image/object/class/text': bytes_list_feature(classes_text),
            'image/object/class/label': int64_list_feature(classes),
        }
        if masks:
            result['image/object/mask'] = bytes_list_feature(masks)
        return result

    def _make_tf_example(self, item):
        """Converts a dataset item into a ``tf.train.Example``.

        Raises:
            Exception: if the item has no image info; the image size is
                required to normalize the box coordinates.
        """
        features = {
            'image/source_id': bytes_feature(
                str(item.attributes.get('source_id') or '').encode('utf-8')
            ),
        }

        filename = self._make_image_filename(item)
        features['image/filename'] = bytes_feature(filename.encode('utf-8'))

        if not item.has_image:
            raise Exception("Failed to export dataset item '%s': "
                "item has no image info" % item.id)
        # The image size is unpacked as (height, width)
        height, width = item.image.size

        features.update({
            'image/height': int64_feature(height),
            'image/width': int64_feature(width),
        })

        # The image fields are always present; they stay empty unless the
        # image data is embedded below.
        features.update({
            'image/encoded': bytes_feature(b''),
            'image/format': bytes_feature(b''),
            'image/key/sha256': bytes_feature(b''),
        })
        if self._save_images:
            # item.has_image is guaranteed by the check above
            if item.image.has_data:
                buffer, fmt = self._save_image(item, filename)
                key = hashlib.sha256(buffer).hexdigest()

                features.update({
                    'image/encoded': bytes_feature(buffer),
                    'image/format': bytes_feature(fmt.encode('utf-8')),
                    'image/key/sha256': bytes_feature(key.encode('utf8')),
                })
            else:
                log.warning("Item '%s' has no image", item.id)

        instances = self._find_instances(item.annotations)
        instances = [self._find_instance_parts(i, width, height)
            for i in instances]
        features.update(self._export_instances(instances, width, height))

        return tf.train.Example(features=tf.train.Features(feature=features))

    def _save_image(self, item, path=None):
        """Encodes the item's image data for embedding into a record.

        Args:
            item: the dataset item; callers are expected to have checked
                ``item.image.has_data``.
            path: the target file name. Only its extension is used, to pick
                the encoding and the ``image/format`` string. Defaults to
                ``self._make_image_filename(item)``, the same name that
                ``_make_tf_example`` passes.

        Returns:
            ``(buffer, fmt)``: the encoded image bytes and the format string
            from ``DetectionApiPath.IMAGE_EXT_FORMAT``, or '' when the
            extension is not listed there.
        """
        if path is None:
            path = self._make_image_filename(item)
        dst_ext = osp.splitext(osp.basename(path))[1]

        # Default to '' (not None): the caller encodes this string, and the
        # warning promises an empty field rather than a crash.
        fmt = DetectionApiPath.IMAGE_EXT_FORMAT.get(dst_ext, '')
        if not fmt:
            log.warning("Item '%s': can't find format string for the '%s' "
                "image extension, the corresponding field will be empty.",
                item.id, dst_ext)

        buffer = encode_image(item.image.data, dst_ext)
        return buffer, fmt