[Datumaro] Added tf detection api tfrecord import and export (#894)

* Added tf detection api tfrecord import and export
* Added export button in dashboard
* Add tf to requirements
* Extend test
* Add tf dependency
* Require images in tfrecord export
* Add video task case handling
* Maintain image order in CVAT export
* Fix Task image id-path conversions
* Update tfrecord tests
* Extend image utilities
* Update tfrecord format
* Fix image loading bug
* Add some logs
* Add 'copy' option to project import command
* Reduce default cache size
Branch: main
Authored by zhiltsov-max, committed by Nikita Manovich
parent 310ee37b53
commit 0f5cf2064d

@ -28,7 +28,7 @@ class CvatImagesDirExtractor(datumaro.Extractor):
id=item_id, image=lazy_image(path))
items.append((item.id, item))
-        items = sorted(items, key=lambda e: e[0])
+        items = sorted(items, key=lambda e: int(e[0]))
items = OrderedDict(items)
self._items = items
@ -74,7 +74,7 @@ class CvatTaskExtractor(datumaro.Extractor):
id=cvat_anno.frame, annotations=dm_anno)
dm_annotations.append((dm_item.id, dm_item))
-        dm_annotations = sorted(dm_annotations, key=lambda e: e[0])
+        dm_annotations = sorted(dm_annotations, key=lambda e: int(e[0]))
self._items = OrderedDict(dm_annotations)
self._subsets = None

@ -102,7 +102,7 @@ class cvat_rest_api_task_images(datumaro.Extractor):
id=item_id, image=self._make_image_loader(item_id))
items.append((item.id, item))
-        items = sorted(items, key=lambda e: e[0])
+        items = sorted(items, key=lambda e: int(e[0]))
items = OrderedDict(items)
self._items = items

@ -242,13 +242,23 @@ class TaskProject:
images_meta = {
'images': items,
}
-        for db_image in self._db_task.image_set.all():
-            frame_info = {
-                'id': db_image.frame,
-                'width': db_image.width,
-                'height': db_image.height,
-            }
-            items.append(frame_info)
+        db_video = getattr(self._db_task, 'video', None)
+        if db_video is not None:
+            for i in range(self._db_task.size):
+                frame_info = {
+                    'id': str(i),
+                    'width': db_video.width,
+                    'height': db_video.height,
+                }
+                items.append(frame_info)
+        else:
+            for db_image in self._db_task.image_set.all():
+                frame_info = {
+                    'id': db_image.frame,
+                    'width': db_image.width,
+                    'height': db_image.height,
+                }
+                items.append(frame_info)
with open(osp.join(save_dir, 'config.json'), 'w') as config_file:
json.dump(config, config_file)
@ -385,6 +395,11 @@ EXPORT_FORMATS = [
'tag': 'yolo',
'is_default': False,
},
{
'name': 'TF Detection API TFrecord',
'tag': 'tf_detection_api',
'is_default': False,
},
]
def get_export_formats():

@ -90,12 +90,7 @@ class Task(models.Model):
def get_image_frame(image_path):
assert image_path.endswith('.jpg')
index = os.path.splitext(os.path.basename(image_path))[0]
-        path = os.path.dirname(image_path)
-        d2 = os.path.basename(path)
-        d1 = os.path.basename(os.path.dirname(path))
-        return int(d1) * 10000 + int(d2) * 100 + int(index)
+        return int(index)
def get_frame_step(self):
match = re.search("step\s*=\s*([1-9]\d*)", self.frame_filter)

@ -0,0 +1,25 @@
# Copyright (C) 2018 Intel Corporation
#
# SPDX-License-Identifier: MIT
import os.path as osp
from django.test import TestCase
from cvat.apps.engine.models import Task
class TaskModelTest(TestCase):
def test_frame_id_path_conversions(self):
task_id = 1
task = Task(task_id)
for i in [10 ** p for p in range(6)]:
src_path_expected = osp.join(
str(i // 10000), str(i // 100), '%s.jpg' % i)
src_path = task.get_frame_path(i)
dst_frame = task.get_image_frame(src_path)
self.assertTrue(src_path.endswith(src_path_expected),
'%s vs. %s' % (src_path, src_path_expected))
self.assertEqual(i, dst_frame)
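For illustration, a minimal sketch of the two conversions this test exercises; get_frame_path is reconstructed from the layout the test asserts and is hypothetical, since the real model method is not part of this diff:

import os.path as osp

def get_frame_path(frame, base_dir=''):
    # Frames are sharded on disk as <frame // 10000>/<frame // 100>/<frame>.jpg,
    # matching src_path_expected in the test above.
    frame = int(frame)
    return osp.join(base_dir,
        str(frame // 10000), str(frame // 100), '%d.jpg' % frame)

def get_image_frame(image_path):
    # The inverse conversion, as simplified by this commit: the frame id is
    # just the file name, so the directory components can be ignored.
    assert image_path.endswith('.jpg')
    return int(osp.splitext(osp.basename(image_path))[0])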

@ -48,7 +48,7 @@ def build_import_parser(parser):
import datumaro.components.importers as importers_module
importers_list = [name for name, cls in importers_module.items]
-    parser.add_argument('source_path',
+    parser.add_argument('-s', '--source', required=True,
help="Path to import a project from")
parser.add_argument('-f', '--format', required=True,
help="Source project format (options: %s)" % (', '.join(importers_list)))
@ -60,6 +60,8 @@ def build_import_parser(parser):
help="Name of the new project (default: same as project dir)")
parser.add_argument('--overwrite', action='store_true',
help="Overwrite existing files in the save directory")
parser.add_argument('--copy', action='store_true',
help="Make a deep copy instead of saving source links")
return parser
def import_command(args):
@ -74,14 +76,19 @@ def import_command(args):
project_name = osp.basename(project_dir)
log.info("Importing project from '%s' as '%s'" % \
-        (args.source_path, args.format))
+        (args.source, args.format))
-    source_path = osp.abspath(args.source_path)
-    project = Project.import_from(source_path, args.format)
+    source = osp.abspath(args.source)
+    project = Project.import_from(source, args.format)
project.config.project_name = project_name
project.config.project_dir = project_dir
-    project = project.make_dataset()
-    project.save(merge=True, save_images=False)
+    dataset = project.make_dataset()
+    if args.copy:
+        log.info("Cloning data...")
+        dataset.save(merge=True, save_images=True)
+    else:
+        project.save()
log.info("Project has been created at '%s'" % (project_dir))
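With the new options, a typical import might look like this (a sketch: only the -s, -f and --copy flags appear in this diff; the datum entry point and subcommand name are assumed):

# Import a TF Detection API dataset, copying the data into the project
# instead of keeping links to the source files:
datum project import -s /data/tf_dataset -f tf_detection_api --copy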

@ -25,6 +25,10 @@ from datumaro.components.converters.voc import (
from datumaro.components.converters.yolo import YoloConverter
from datumaro.components.converters.tfrecord import (
DetectionApiConverter,
)
items = [
('datumaro', DatumaroConverter),
@ -44,4 +48,6 @@ items = [
('voc_layout', VocLayoutConverter),
('yolo', YoloConverter),
('tf_detection_api', DetectionApiConverter),
]

@ -0,0 +1,146 @@
# Copyright (C) 2019 Intel Corporation
#
# SPDX-License-Identifier: MIT
import codecs
from collections import OrderedDict
import os
import os.path as osp
import string
from datumaro.components.extractor import AnnotationType, DEFAULT_SUBSET_NAME
from datumaro.components.formats.tfrecord import DetectionApiPath
from datumaro.util.image import encode_image
from datumaro.util.tf_util import import_tf as _import_tf
# we need it to filter out non-ASCII characters, otherwise training will crash
_printable = set(string.printable)
def _make_printable(s):
return ''.join(filter(lambda x: x in _printable, s))
def _make_tf_example(item, get_label_id, get_label, save_images=False):
tf = _import_tf()
def int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def int64_list_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
def bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def bytes_list_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
def float_list_feature(value):
return tf.train.Feature(float_list=tf.train.FloatList(value=value))
features = {
'image/source_id': bytes_feature(str(item.id).encode('utf-8')),
'image/filename': bytes_feature(
('%s%s' % (item.id, DetectionApiPath.IMAGE_EXT)).encode('utf-8')),
}
if not item.has_image:
raise Exception(
"Failed to export dataset item '%s': item has no image" % item.id)
height, width, _ = item.image.shape
features.update({
'image/height': int64_feature(height),
'image/width': int64_feature(width),
})
if save_images and item.has_image:
fmt = DetectionApiPath.IMAGE_FORMAT
buffer = encode_image(item.image, DetectionApiPath.IMAGE_EXT)
features.update({
'image/encoded': bytes_feature(buffer),
'image/format': bytes_feature(fmt.encode('utf-8')),
})
xmins = [] # List of normalized left x coordinates in bounding box (1 per box)
xmaxs = [] # List of normalized right x coordinates in bounding box (1 per box)
ymins = [] # List of normalized top y coordinates in bounding box (1 per box)
ymaxs = [] # List of normalized bottom y coordinates in bounding box (1 per box)
classes_text = [] # List of string class name of bounding box (1 per box)
classes = [] # List of integer class id of bounding box (1 per box)
boxes = [ann for ann in item.annotations if ann.type is AnnotationType.bbox]
for box in boxes:
box_label = _make_printable(get_label(box.label))
xmins.append(box.points[0] / width)
xmaxs.append(box.points[2] / width)
ymins.append(box.points[1] / height)
ymaxs.append(box.points[3] / height)
classes_text.append(box_label.encode('utf-8'))
classes.append(get_label_id(box.label))
if boxes:
features.update({
'image/object/bbox/xmin': float_list_feature(xmins),
'image/object/bbox/xmax': float_list_feature(xmaxs),
'image/object/bbox/ymin': float_list_feature(ymins),
'image/object/bbox/ymax': float_list_feature(ymaxs),
'image/object/class/text': bytes_list_feature(classes_text),
'image/object/class/label': int64_list_feature(classes),
})
tf_example = tf.train.Example(
features=tf.train.Features(feature=features))
return tf_example
class DetectionApiConverter:
def __init__(self, save_images=True):
self.save_images = save_images
def __call__(self, extractor, save_dir):
tf = _import_tf()
os.makedirs(save_dir, exist_ok=True)
subsets = extractor.subsets()
if len(subsets) == 0:
subsets = [ None ]
for subset_name in subsets:
if subset_name:
subset = extractor.get_subset(subset_name)
else:
subset_name = DEFAULT_SUBSET_NAME
subset = extractor
label_categories = subset.categories()[AnnotationType.label]
get_label = lambda label_id: label_categories.items[label_id].name \
if label_id is not None else ''
label_ids = OrderedDict((label.name, 1 + idx)
for idx, label in enumerate(label_categories.items))
map_label_id = lambda label_id: label_ids.get(get_label(label_id), 0)
labelmap_path = osp.join(save_dir, DetectionApiPath.LABELMAP_FILE)
with codecs.open(labelmap_path, 'w', encoding='utf8') as f:
for label, idx in label_ids.items():
f.write(
'item {\n' +
('\tid: %s\n' % (idx)) +
("\tname: '%s'\n" % (label)) +
'}\n\n'
)
anno_path = osp.join(save_dir, '%s.tfrecord' % (subset_name))
with tf.io.TFRecordWriter(anno_path) as writer:
for item in subset:
tf_example = _make_tf_example(
item,
get_label=get_label,
get_label_id=map_label_id,
save_images=self.save_images,
)
writer.write(tf_example.SerializeToString())
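For reference, the converter above emits the label map in the usual pbtxt form; with two labels (names illustrative) the output would be:

item {
	id: 1
	name: 'label_0'
}

item {
	id: 2
	name: 'label_1'
}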

@ -30,6 +30,11 @@ from datumaro.components.extractors.yolo import (
YoloExtractor,
)
from datumaro.components.extractors.tfrecord import (
DetectionApiExtractor,
)
items = [
('datumaro', DatumaroExtractor),
@ -52,4 +57,6 @@ items = [
('voc_comp_9_10', VocComp_9_10_Extractor),
('yolo', YoloExtractor),
('tf_detection_api', DetectionApiExtractor),
]

@ -0,0 +1,206 @@
# Copyright (C) 2019 Intel Corporation
#
# SPDX-License-Identifier: MIT
from collections import OrderedDict
import numpy as np
import os.path as osp
import re
from datumaro.components.extractor import AnnotationType, DEFAULT_SUBSET_NAME, \
LabelCategories, BboxObject, DatasetItem, Extractor
from datumaro.components.formats.tfrecord import DetectionApiPath
from datumaro.util.image import lazy_image, decode_image
from datumaro.util.tf_util import import_tf as _import_tf
def clamp(value, _min, _max):
return max(min(_max, value), _min)
class DetectionApiExtractor(Extractor):
class Subset(Extractor):
def __init__(self, name, parent):
super().__init__()
self._name = name
self._parent = parent
self.items = OrderedDict()
def __iter__(self):
for item in self.items.values():
yield item
def __len__(self):
return len(self.items)
def categories(self):
return self._parent.categories()
def __init__(self, path, images_dir=None):
super().__init__()
root_dir = osp.dirname(osp.abspath(path))
if osp.basename(root_dir) == DetectionApiPath.ANNOTATIONS_DIR:
root_dir = osp.dirname(root_dir)
images_dir = osp.join(root_dir, DetectionApiPath.IMAGES_DIR)
if not osp.isdir(images_dir):
images_dir = None
self._images_dir = images_dir
self._subsets = {}
subset_name = osp.splitext(osp.basename(path))[0]
if subset_name == DEFAULT_SUBSET_NAME:
subset_name = None
subset = DetectionApiExtractor.Subset(subset_name, self)
items, labels = self._parse_tfrecord_file(path, subset_name, images_dir)
subset.items = items
self._subsets[subset_name] = subset
label_categories = LabelCategories()
labels = sorted(labels.items(), key=lambda item: item[1])
for label, _ in labels:
label_categories.add(label)
self._categories = {
AnnotationType.label: label_categories
}
@classmethod
def _parse_labelmap(cls, text):
id_pattern = r'(?:id\s*:\s*(?P<id>\d+))'
name_pattern = r'(?:name\s*:\s*[\'\"](?P<name>.*?)[\'\"])'
entry_pattern = r'(\{(?:[\s\n]*(?:%(id)s|%(name)s)[\s\n]*){2}\})+' % \
{'id': id_pattern, 'name': name_pattern}
matches = re.finditer(entry_pattern, text)
labelmap = {}
for match in matches:
label_id = match.group('id')
label_name = match.group('name')
if label_id is not None and label_name is not None:
labelmap[label_name] = int(label_id)
return labelmap
@classmethod
def _parse_tfrecord_file(cls, filepath, subset_name, images_dir):
tf = _import_tf()
dataset = tf.data.TFRecordDataset(filepath)
features = {
'image/filename': tf.io.FixedLenFeature([], tf.string),
'image/source_id': tf.io.FixedLenFeature([], tf.string),
'image/height': tf.io.FixedLenFeature([], tf.int64),
'image/width': tf.io.FixedLenFeature([], tf.int64),
'image/encoded': tf.io.FixedLenFeature([], tf.string),
'image/format': tf.io.FixedLenFeature([], tf.string),
# Object boxes and classes.
'image/object/bbox/xmin': tf.io.VarLenFeature(tf.float32),
'image/object/bbox/xmax': tf.io.VarLenFeature(tf.float32),
'image/object/bbox/ymin': tf.io.VarLenFeature(tf.float32),
'image/object/bbox/ymax': tf.io.VarLenFeature(tf.float32),
'image/object/class/label': tf.io.VarLenFeature(tf.int64),
'image/object/class/text': tf.io.VarLenFeature(tf.string),
'image/object/mask': tf.io.VarLenFeature(tf.string),
}
dataset_labels = OrderedDict()
labelmap_path = osp.join(osp.dirname(filepath),
DetectionApiPath.LABELMAP_FILE)
if osp.exists(labelmap_path):
with open(labelmap_path, 'r', encoding='utf-8') as f:
labelmap_text = f.read()
dataset_labels.update({ label: id - 1
for label, id in cls._parse_labelmap(labelmap_text).items()
})
dataset_items = OrderedDict()
for record in dataset:
parsed_record = tf.io.parse_single_example(record, features)
frame_id = parsed_record['image/source_id'].numpy().decode('utf-8')
frame_filename = \
parsed_record['image/filename'].numpy().decode('utf-8')
frame_height = tf.cast(
parsed_record['image/height'], tf.int64).numpy().item()
frame_width = tf.cast(
parsed_record['image/width'], tf.int64).numpy().item()
frame_image = parsed_record['image/encoded'].numpy()
frame_format = parsed_record['image/format'].numpy().decode('utf-8')
xmins = tf.sparse.to_dense(
parsed_record['image/object/bbox/xmin']).numpy()
ymins = tf.sparse.to_dense(
parsed_record['image/object/bbox/ymin']).numpy()
xmaxs = tf.sparse.to_dense(
parsed_record['image/object/bbox/xmax']).numpy()
ymaxs = tf.sparse.to_dense(
parsed_record['image/object/bbox/ymax']).numpy()
label_ids = tf.sparse.to_dense(
parsed_record['image/object/class/label']).numpy()
labels = tf.sparse.to_dense(
parsed_record['image/object/class/text'],
default_value=b'').numpy()
for label, label_id in zip(labels, label_ids):
label = label.decode('utf-8')
if not label:
continue
if label_id <= 0:
continue
if label in dataset_labels:
continue
dataset_labels[label] = label_id - 1
item_id = frame_id
if not item_id:
item_id = osp.splitext(frame_filename)[0]
annotations = []
for index, shape in enumerate(
np.dstack((labels, xmins, ymins, xmaxs, ymaxs))[0]):
label = shape[0].decode('utf-8')
x = clamp(shape[1] * frame_width, 0, frame_width)
y = clamp(shape[2] * frame_height, 0, frame_height)
w = clamp(shape[3] * frame_width, 0, frame_width) - x
h = clamp(shape[4] * frame_height, 0, frame_height) - y
annotations.append(BboxObject(x, y, w, h,
label=dataset_labels.get(label, None), id=index
))
image = None
if image is None and frame_image and frame_format:
image = lazy_image(frame_image, loader=decode_image)
if image is None and frame_filename and images_dir:
image_path = osp.join(images_dir, frame_filename)
if osp.exists(image_path):
image = lazy_image(image_path)
dataset_items[item_id] = DatasetItem(id=item_id, subset=subset_name,
image=image, annotations=annotations)
return dataset_items, dataset_labels
def categories(self):
return self._categories
def __iter__(self):
for subset in self._subsets.values():
for item in subset:
yield item
def __len__(self):
length = 0
for subset in self._subsets.values():
length += len(subset)
return length
def subsets(self):
return list(self._subsets)
def get_subset(self, name):
return self._subsets[name]
def get(self, item_id, subset=None, path=None):
if path is not None:
return None
return self.get_subset(subset).items.get(item_id, None)

@ -0,0 +1,13 @@
# Copyright (C) 2019 Intel Corporation
#
# SPDX-License-Identifier: MIT
class DetectionApiPath:
IMAGES_DIR = 'images'
ANNOTATIONS_DIR = 'annotations'
IMAGE_EXT = '.jpg'
IMAGE_FORMAT = 'jpeg'
LABELMAP_FILE = 'label_map.pbtxt'

@ -14,6 +14,11 @@ from datumaro.components.importers.voc import (
VocResultsImporter,
)
from datumaro.components.importers.tfrecord import (
DetectionApiImporter,
)
items = [
('datumaro', DatumaroImporter),
@ -21,4 +26,6 @@ items = [
('voc', VocImporter),
('voc_results', VocResultsImporter),
('tf_detection_api', DetectionApiImporter),
]

@ -0,0 +1,35 @@
# Copyright (C) 2019 Intel Corporation
#
# SPDX-License-Identifier: MIT
from glob import glob
import os.path as osp
class DetectionApiImporter:
EXTRACTOR_NAME = 'tf_detection_api'
def __call__(self, path):
from datumaro.components.project import Project # cyclic import
project = Project()
subset_paths = glob(osp.join(path, '*.tfrecord'))
for subset_path in subset_paths:
if not osp.isfile(subset_path):
continue
subset_name = osp.splitext(osp.basename(subset_path))[0]
project.add_source(subset_name, {
'url': subset_path,
'format': self.EXTRACTOR_NAME,
})
if len(project.config.sources) == 0:
raise Exception(
"Failed to find 'tf_detection_api' dataset at '%s'" % path)
return project
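Putting the importer and extractor together, a dataset is expected to be laid out roughly as follows (names illustrative; the label map and the images directory are both optional, as handled above):

<dataset>/
    train.tfrecord
    val.tfrecord
    label_map.pbtxt    # optional, looked up next to each .tfrecord
    images/            # optional, used when records do not embed image data
        1.jpg
        2.jpg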

@ -351,6 +351,7 @@ class ProjectDataset(Extractor):
# merge items
subsets = defaultdict(lambda: Subset(self))
for source_name, source in self._sources.items():
log.info("Loading '%s' source contents..." % source_name)
for item in source:
if dataset_filter and not dataset_filter(item):
continue
@ -360,7 +361,7 @@ class ProjectDataset(Extractor):
image = None
if existing_item.has_image:
# TODO: think of image comparison
-                    image = lambda: existing_item.image
+                    image = self._lazy_image(existing_item)
path = existing_item.path
if item.path != path:
@ -386,6 +387,7 @@ class ProjectDataset(Extractor):
# override with our items, fallback to existing images
if own_source is not None:
log.info("Loading own dataset...")
for item in own_source:
if dataset_filter and not dataset_filter(item):
continue
@ -396,7 +398,7 @@ class ProjectDataset(Extractor):
image = None
if existing_item.has_image:
# TODO: think of image comparison
-                    image = lambda: existing_item.image
+                    image = self._lazy_image(existing_item)
item = DatasetItemWrapper(item=item, path=None,
annotations=item.annotations, image=image)
@ -410,6 +412,11 @@ class ProjectDataset(Extractor):
self._length = None
@staticmethod
def _lazy_image(item):
# NOTE: avoid https://docs.python.org/3/faq/programming.html#why-do-lambdas-defined-in-a-loop-with-different-values-all-return-the-same-result
return lambda: item.image
@staticmethod
def _merge_anno(a, b):
from itertools import chain

@ -5,6 +5,7 @@
# pylint: disable=unused-import
from io import BytesIO
import numpy as np
from enum import Enum
@ -42,20 +43,76 @@ def load_image(path):
assert image.shape[2] in [1, 3, 4]
return image
-def save_image(path, image):
+def save_image(path, image, params=None):
if _IMAGE_BACKEND == _IMAGE_BACKENDS.cv2:
import cv2
-        cv2.imwrite(path, image)
+        cv2.imwrite(path, image, params=params)
elif _IMAGE_BACKEND == _IMAGE_BACKENDS.PIL:
from PIL import Image
if not params:
params = {}
image = image.astype(np.uint8)
if len(image.shape) == 3 and image.shape[2] in [3, 4]:
image[:, :, :3] = image[:, :, 2::-1] # BGR to RGB
image = Image.fromarray(image)
image.save(path, **params)
else:
raise NotImplementedError()
def encode_image(image, ext, params=None):
if _IMAGE_BACKEND == _IMAGE_BACKENDS.cv2:
import cv2
if not ext.startswith('.'):
ext = '.' + ext
if ext.upper() == '.JPG':
params = [ int(cv2.IMWRITE_JPEG_QUALITY), 75 ]
success, result = cv2.imencode(ext, image, params=params)
if not success:
raise Exception("Failed to encode image to '%s' format" % (ext))
return result.tobytes()
elif _IMAGE_BACKEND == _IMAGE_BACKENDS.PIL:
from PIL import Image
if ext.startswith('.'):
ext = ext[1:]
if not params:
params = {}
image = image.astype(np.uint8)
if len(image.shape) == 3 and image.shape[2] in [3, 4]:
image[:, :, :3] = image[:, :, 2::-1] # BGR to RGB
image = Image.fromarray(image)
-        image.save(path)
with BytesIO() as buffer:
image.save(buffer, format=ext, **params)
return buffer.getvalue()
else:
raise NotImplementedError()
def decode_image(image_bytes):
if _IMAGE_BACKEND == _IMAGE_BACKENDS.cv2:
import cv2
image = np.frombuffer(image_bytes, dtype=np.uint8)
image = cv2.imdecode(image, cv2.IMREAD_UNCHANGED)
image = image.astype(np.float32)
elif _IMAGE_BACKEND == _IMAGE_BACKENDS.PIL:
from PIL import Image
image = Image.open(BytesIO(image_bytes))
image = np.asarray(image, dtype=np.float32)
if len(image.shape) == 3 and image.shape[2] in [3, 4]:
image[:, :, :3] = image[:, :, 2::-1] # RGB to BGR
else:
raise NotImplementedError()
assert len(image.shape) == 3
assert image.shape[2] in [1, 3, 4]
return image
class lazy_image:
def __init__(self, path, loader=load_image, cache=None):
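A quick round trip through the new helpers (a sketch mirroring the unit test added further below; PNG is used because it is lossless):

import numpy as np
from datumaro.util.image import encode_image, decode_image

image = np.random.randint(0, 256, (2, 4, 3), dtype=np.uint8)  # HWC, BGR
buffer = encode_image(image, '.png')  # raw bytes, suitable for a tfrecord field
restored = decode_image(buffer)       # comes back as a float32 array
assert np.all(image == restored)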

@ -3,7 +3,7 @@ from collections import OrderedDict
_instance = None
-DEFAULT_CAPACITY = 1000
+DEFAULT_CAPACITY = 2
class ImageCache:
@staticmethod

@ -0,0 +1,38 @@
# Copyright (C) 2019 Intel Corporation
#
# SPDX-License-Identifier: MIT
def import_tf():
import sys
tf = sys.modules.get('tensorflow', None)
if tf is not None:
return tf
# Reduce output noise, https://stackoverflow.com/questions/38073432/how-to-suppress-verbose-tensorflow-logging
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
try:
tf.get_logger().setLevel('WARNING')
except AttributeError:
pass
try:
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.WARN)
except AttributeError:
pass
# Enable eager execution in early versions to unlock dataset operations
try:
tf.compat.v1.enable_eager_execution()
except AttributeError:
pass
try:
tf.enable_eager_execution()
except AttributeError:
pass
return tf
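The helper defers the TensorFlow import until a tfrecord operation actually runs; the new converter and extractor call it at function scope, e.g. (a sketch; count_records is hypothetical):

from datumaro.util.tf_util import import_tf as _import_tf

def count_records(filepath):
    # TensorFlow is imported on first use, with logging quieted and eager mode on
    tf = _import_tf()
    return sum(1 for _ in tf.data.TFRecordDataset(filepath))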

@ -7,3 +7,4 @@ Pillow>=6.1.0
pycocotools>=2.0.0
PyYAML>=5.1.1
tensorboardX>=1.8
tensorflow>=1.12.0

@ -57,6 +57,7 @@ setuptools.setup(
'PyYAML',
'pycocotools',
'tensorboardX',
'tensorflow',
],
entry_points={
'console_scripts': [

@ -15,25 +15,32 @@ class ImageTest(TestCase):
def tearDown(self):
image_module._IMAGE_BACKEND = self.default_backend
-    def _test_can_save_and_load(self, src_image, path,
-            save_backend=None, load_backend=None):
-        if save_backend:
-            image_module._IMAGE_BACKEND = save_backend
-        image_module.save_image(path, src_image)
-        if load_backend:
-            image_module._IMAGE_BACKEND = load_backend
-        dst_image = image_module.load_image(path)
-        self.assertTrue(np.all(src_image == dst_image), 'save: %s, load: %s' % \
-            (save_backend, load_backend))
    def test_save_and_load_backends(self):
        backends = image_module._IMAGE_BACKENDS
        for save_backend, load_backend in product(backends, backends):
            with TestDir() as test_dir:
                src_image = np.random.randint(0, 255 + 1, (2, 4, 3))
-                image_path = osp.join(test_dir.path, 'img.png')
+                path = osp.join(test_dir.path, 'img.png') # lossless
+                image_module._IMAGE_BACKEND = save_backend
+                image_module.save_image(path, src_image)
+                image_module._IMAGE_BACKEND = load_backend
+                dst_image = image_module.load_image(path)
+                self.assertTrue(np.all(src_image == dst_image),
+                    'save: %s, load: %s' % (save_backend, load_backend))
+    def test_encode_and_decode_backends(self):
+        backends = image_module._IMAGE_BACKENDS
+        for save_backend, load_backend in product(backends, backends):
+            src_image = np.random.randint(0, 255 + 1, (2, 4, 3))
+            image_module._IMAGE_BACKEND = save_backend
+            buffer = image_module.encode_image(src_image, '.png') # lossless
+            image_module._IMAGE_BACKEND = load_backend
+            dst_image = image_module.decode_image(buffer)
-                self._test_can_save_and_load(src_image, image_path,
-                    save_backend, load_backend)
+            self.assertTrue(np.all(src_image == dst_image),
+                'save: %s, load: %s' % (save_backend, load_backend))

@ -0,0 +1,151 @@
import numpy as np
from unittest import TestCase
from datumaro.components.project import Project
from datumaro.components.extractor import (Extractor, DatasetItem,
AnnotationType, BboxObject, LabelCategories
)
from datumaro.components.extractors.tfrecord import (
DetectionApiExtractor,
)
from datumaro.components.converters.tfrecord import (
DetectionApiConverter,
)
from datumaro.util import find
from datumaro.util.test_utils import TestDir
class TfrecordConverterTest(TestCase):
def _test_can_save_and_load(self, source_dataset, converter, test_dir,
importer_params=None):
converter(source_dataset, test_dir.path)
if not importer_params:
importer_params = {}
project = Project.import_from(test_dir.path, 'tf_detection_api',
**importer_params)
parsed_dataset = project.make_dataset()
self.assertListEqual(
sorted(source_dataset.subsets()),
sorted(parsed_dataset.subsets()),
)
self.assertEqual(len(source_dataset), len(parsed_dataset))
for item_a in source_dataset:
item_b = find(parsed_dataset, lambda x: x.id == item_a.id)
self.assertFalse(item_b is None)
self.assertEqual(len(item_a.annotations), len(item_b.annotations))
for ann_a in item_a.annotations:
ann_b = find(item_b.annotations, lambda x: \
x.id == ann_a.id and \
x.type == ann_a.type and x.group == ann_a.group)
self.assertEqual(ann_a, ann_b, 'id: ' + str(ann_a.id))
def test_can_save_bboxes(self):
class TestExtractor(Extractor):
def __iter__(self):
items = [
DatasetItem(id=1, subset='train',
image=np.ones((16, 16, 3)),
annotations=[
BboxObject(0, 4, 4, 8, label=2, id=0),
BboxObject(0, 4, 4, 4, label=3, id=1),
BboxObject(2, 4, 4, 4, id=2),
]
),
DatasetItem(id=2, subset='val',
image=np.ones((8, 8, 3)),
annotations=[
BboxObject(1, 2, 4, 2, label=3, id=0),
]
),
DatasetItem(id=3, subset='test',
image=np.ones((5, 4, 3)) * 3,
),
]
for item in items:
yield item
def categories(self):
label_cat = LabelCategories()
for label in range(10):
label_cat.add('label_' + str(label))
return {
AnnotationType.label: label_cat,
}
with TestDir() as test_dir:
self._test_can_save_and_load(
TestExtractor(), DetectionApiConverter(save_images=True),
test_dir)
def test_can_save_dataset_with_no_subsets(self):
class TestExtractor(Extractor):
def __iter__(self):
items = [
DatasetItem(id=1,
image=np.ones((16, 16, 3)),
annotations=[
BboxObject(2, 1, 4, 4, label=2, id=0),
BboxObject(4, 2, 8, 4, label=3, id=1),
]
),
DatasetItem(id=2,
image=np.ones((8, 8, 3)) * 2,
annotations=[
BboxObject(4, 4, 4, 4, label=3, id=0),
]
),
DatasetItem(id=3,
image=np.ones((8, 4, 3)) * 3,
),
]
for item in items:
yield item
def categories(self):
label_cat = LabelCategories()
for label in range(10):
label_cat.add('label_' + str(label))
return {
AnnotationType.label: label_cat,
}
with TestDir() as test_dir:
self._test_can_save_and_load(
TestExtractor(), DetectionApiConverter(), test_dir)
def test_labelmap_parsing(self):
text = """
{
id: 4
name: 'qw1'
}
{
id: 5 name: 'qw2'
}
{
name: 'qw3'
id: 6
}
{name:'qw4' id:7}
"""
expected = {
'qw1': 4,
'qw2': 5,
'qw3': 6,
'qw4': 7,
}
parsed = DetectionApiExtractor._parse_labelmap(text)
self.assertEqual(expected, parsed)