[Datumaro] Added tf detection api tfrecord import and export (#894)
* Added tf detection api tfrecord import and export * Added export button in dashboard * Add tf to requirements * Extend test * Add tf dependency * Require images in tfrecord export * Add video task case handling * Maintain image order in CVAT export * Fix Task image id-path conversions * Update tfrecord tests * Extend image utilities * Update tfrecord format * Fix image loading bug * Add some logs * Add 'copy' option to project import command * Reduce default cache size
parent
310ee37b53
commit
0f5cf2064d
@ -0,0 +1,25 @@
|
|||||||
|
# Copyright (C) 2018 Intel Corporation
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
import os.path as osp
|
||||||
|
|
||||||
|
from django.test import TestCase
|
||||||
|
from cvat.apps.engine.models import Task
|
||||||
|
|
||||||
|
|
||||||
|
class TaskModelTest(TestCase):
    """Checks the frame-id <-> image-path mapping of the Task model."""

    def test_frame_id_path_conversions(self):
        """get_image_frame() must invert get_frame_path() for any frame id."""
        task_id = 1
        task = Task(task_id)

        # Probe several orders of magnitude (1, 10, ..., 100000) to cover
        # the directory-bucketing scheme boundaries.
        for i in [10 ** p for p in range(6)]:
            # Frames are sharded into sub-directories by i // 10000 and
            # i // 100 — presumably to keep directory sizes bounded
            # (TODO confirm against Task.get_frame_path implementation).
            src_path_expected = osp.join(
                str(i // 10000), str(i // 100), '%s.jpg' % i)
            src_path = task.get_frame_path(i)

            dst_frame = task.get_image_frame(src_path)

            # Only the suffix is checked: the prefix depends on the
            # task's data directory.
            self.assertTrue(src_path.endswith(src_path_expected),
                '%s vs. %s' % (src_path, src_path_expected))
            self.assertEqual(i, dst_frame)
||||||
@ -0,0 +1,146 @@
|
|||||||
|
|
||||||
|
# Copyright (C) 2019 Intel Corporation
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
import codecs
|
||||||
|
from collections import OrderedDict
|
||||||
|
import os
|
||||||
|
import os.path as osp
|
||||||
|
import string
|
||||||
|
|
||||||
|
from datumaro.components.extractor import AnnotationType, DEFAULT_SUBSET_NAME
|
||||||
|
from datumaro.components.formats.tfrecord import DetectionApiPath
|
||||||
|
from datumaro.util.image import encode_image
|
||||||
|
from datumaro.util.tf_util import import_tf as _import_tf
|
||||||
|
|
||||||
|
|
||||||
|
# we need it to filter out non-ASCII characters, otherwise training will crash
|
||||||
|
_printable = set(string.printable)
|
||||||
|
def _make_printable(s):
|
||||||
|
return ''.join(filter(lambda x: x in _printable, s))
|
||||||
|
|
||||||
|
def _make_tf_example(item, get_label_id, get_label, save_images=False):
    """Builds a tf.train.Example for one dataset item.

    item: a Datumaro DatasetItem with an image and bbox annotations.
    get_label_id: callable mapping an item label index to the exported
        (1-based) class id; 0 means "unknown".
    get_label: callable mapping a label index to the label name.
    save_images: when True, the encoded image bytes are embedded in the
        record ('image/encoded' / 'image/format').

    Raises Exception when the item has no image — image size is required
    to normalize the box coordinates.
    """
    tf = _import_tf()

    # Small wrappers over the tf.train.Feature protobuf constructors.
    def int64_feature(value):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

    def int64_list_feature(value):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=value))

    def bytes_feature(value):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

    def bytes_list_feature(value):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))

    def float_list_feature(value):
        return tf.train.Feature(float_list=tf.train.FloatList(value=value))

    features = {
        'image/source_id': bytes_feature(str(item.id).encode('utf-8')),
        'image/filename': bytes_feature(
            ('%s%s' % (item.id, DetectionApiPath.IMAGE_EXT)).encode('utf-8')),
    }

    if not item.has_image:
        raise Exception(
            "Failed to export dataset item '%s': item has no image" % item.id)
    height, width, _ = item.image.shape

    features.update({
        'image/height': int64_feature(height),
        'image/width': int64_feature(width),
    })

    # NOTE(review): item.has_image is guaranteed True here by the raise
    # above, so the second condition is redundant.
    if save_images and item.has_image:
        fmt = DetectionApiPath.IMAGE_FORMAT
        buffer = encode_image(item.image, DetectionApiPath.IMAGE_EXT)

        features.update({
            'image/encoded': bytes_feature(buffer),
            'image/format': bytes_feature(fmt.encode('utf-8')),
        })

    xmins = [] # List of normalized left x coordinates in bounding box (1 per box)
    xmaxs = [] # List of normalized right x coordinates in bounding box (1 per box)
    ymins = [] # List of normalized top y coordinates in bounding box (1 per box)
    ymaxs = [] # List of normalized bottom y coordinates in bounding box (1 per box)
    classes_text = [] # List of string class name of bounding box (1 per box)
    classes = [] # List of integer class id of bounding box (1 per box)

    boxes = [ann for ann in item.annotations if ann.type is AnnotationType.bbox]
    for box in boxes:
        # Label text is ASCII-filtered — non-printable characters would
        # crash training downstream.
        box_label = _make_printable(get_label(box.label))

        # box.points is [x1, y1, x2, y2] in pixels; normalize to [0, 1].
        xmins.append(box.points[0] / width)
        xmaxs.append(box.points[2] / width)
        ymins.append(box.points[1] / height)
        ymaxs.append(box.points[3] / height)
        classes_text.append(box_label.encode('utf-8'))
        classes.append(get_label_id(box.label))

    if boxes:
        features.update({
            'image/object/bbox/xmin': float_list_feature(xmins),
            'image/object/bbox/xmax': float_list_feature(xmaxs),
            'image/object/bbox/ymin': float_list_feature(ymins),
            'image/object/bbox/ymax': float_list_feature(ymaxs),
            'image/object/class/text': bytes_list_feature(classes_text),
            'image/object/class/label': int64_list_feature(classes),
        })

    tf_example = tf.train.Example(
        features=tf.train.Features(feature=features))

    return tf_example
|
||||||
|
|
||||||
|
class DetectionApiConverter:
    """Writes a Datumaro extractor as TF Detection API .tfrecord files.

    One .tfrecord file is produced per subset, plus a label_map.pbtxt.
    """

    def __init__(self, save_images=True):
        # When True, encoded image bytes are embedded in each record.
        self.save_images = save_images

    def __call__(self, extractor, save_dir):
        """Exports *extractor* into *save_dir* (created if missing)."""
        tf = _import_tf()

        os.makedirs(save_dir, exist_ok=True)

        subsets = extractor.subsets()
        if len(subsets) == 0:
            # No named subsets: export the whole extractor as the default one.
            subsets = [ None ]

        for subset_name in subsets:
            if subset_name:
                subset = extractor.get_subset(subset_name)
            else:
                subset_name = DEFAULT_SUBSET_NAME
                subset = extractor

            label_categories = subset.categories()[AnnotationType.label]
            get_label = lambda label_id: label_categories.items[label_id].name \
                if label_id is not None else ''
            # Exported class ids are 1-based; 0 is reserved for "unknown".
            label_ids = OrderedDict((label.name, 1 + idx)
                for idx, label in enumerate(label_categories.items))
            map_label_id = lambda label_id: label_ids.get(get_label(label_id), 0)

            # NOTE(review): the labelmap is rewritten once per subset —
            # harmless while all subsets share the same categories.
            labelmap_path = osp.join(save_dir, DetectionApiPath.LABELMAP_FILE)
            with codecs.open(labelmap_path, 'w', encoding='utf8') as f:
                for label, idx in label_ids.items():
                    f.write(
                        'item {\n' +
                        ('\tid: %s\n' % (idx)) +
                        ("\tname: '%s'\n" % (label)) +
                        '}\n\n'
                    )

            anno_path = osp.join(save_dir, '%s.tfrecord' % (subset_name))
            with tf.io.TFRecordWriter(anno_path) as writer:
                for item in subset:
                    tf_example = _make_tf_example(
                        item,
                        get_label=get_label,
                        get_label_id=map_label_id,
                        save_images=self.save_images,
                    )
                    writer.write(tf_example.SerializeToString())
|
||||||
@ -0,0 +1,206 @@
|
|||||||
|
|
||||||
|
# Copyright (C) 2019 Intel Corporation
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
from collections import OrderedDict
|
||||||
|
import numpy as np
|
||||||
|
import os.path as osp
|
||||||
|
import re
|
||||||
|
|
||||||
|
from datumaro.components.extractor import AnnotationType, DEFAULT_SUBSET_NAME, \
|
||||||
|
LabelCategories, BboxObject, DatasetItem, Extractor
|
||||||
|
from datumaro.components.formats.tfrecord import DetectionApiPath
|
||||||
|
from datumaro.util.image import lazy_image, decode_image
|
||||||
|
from datumaro.util.tf_util import import_tf as _import_tf
|
||||||
|
|
||||||
|
|
||||||
|
def clamp(value, _min, _max):
    """Restrict *value* to the closed interval [_min, _max]."""
    if value < _min:
        return _min
    if value > _max:
        return _max
    return value
|
||||||
|
|
||||||
|
class DetectionApiExtractor(Extractor):
    """Reads one TF Detection API .tfrecord file as a Datumaro dataset.

    The subset name is derived from the annotation file name; label names
    come from a sibling label_map.pbtxt when present, otherwise from the
    'image/object/class/text' record fields.
    """

    class Subset(Extractor):
        # Thin container for the items of one subset; category info is
        # delegated to the parent extractor.
        def __init__(self, name, parent):
            super().__init__()
            self._name = name
            self._parent = parent
            # item id -> DatasetItem, in file order
            self.items = OrderedDict()

        def __iter__(self):
            for item in self.items.values():
                yield item

        def __len__(self):
            return len(self.items)

        def categories(self):
            return self._parent.categories()

    def __init__(self, path, images_dir=None):
        """
        path: path to a .tfrecord annotation file.
        images_dir: optional directory with images. If not given and the
            file lies inside an 'annotations' directory, a sibling
            'images' directory is used when it exists.
        """
        super().__init__()

        root_dir = osp.dirname(osp.abspath(path))
        if osp.basename(root_dir) == DetectionApiPath.ANNOTATIONS_DIR:
            root_dir = osp.dirname(root_dir)
            images_dir = osp.join(root_dir, DetectionApiPath.IMAGES_DIR)
            if not osp.isdir(images_dir):
                images_dir = None
        self._images_dir = images_dir

        # subset name (None for the default subset) -> Subset
        self._subsets = {}

        subset_name = osp.splitext(osp.basename(path))[0]
        if subset_name == DEFAULT_SUBSET_NAME:
            # The default subset is keyed by None.
            subset_name = None
        subset = DetectionApiExtractor.Subset(subset_name, self)
        items, labels = self._parse_tfrecord_file(path, subset_name, images_dir)
        subset.items = items
        self._subsets[subset_name] = subset

        # Register label names ordered by their (0-based) source index.
        label_categories = LabelCategories()
        labels = sorted(labels.items(), key=lambda item: item[1])
        for label, _ in labels:
            label_categories.add(label)
        self._categories = {
            AnnotationType.label: label_categories
        }

    @classmethod
    def _parse_labelmap(cls, text):
        """Parses label_map.pbtxt text into {label name: integer id}.

        'id' and 'name' may appear in either order inside an entry;
        entries missing either field are skipped.
        """
        id_pattern = r'(?:id\s*:\s*(?P<id>\d+))'
        name_pattern = r'(?:name\s*:\s*[\'\"](?P<name>.*?)[\'\"])'
        # Each entry is a brace block with exactly two of the above fields.
        entry_pattern = r'(\{(?:[\s\n]*(?:%(id)s|%(name)s)[\s\n]*){2}\})+' % \
            {'id': id_pattern, 'name': name_pattern}
        matches = re.finditer(entry_pattern, text)

        labelmap = {}
        for match in matches:
            label_id = match.group('id')
            label_name = match.group('name')
            if label_id is not None and label_name is not None:
                labelmap[label_name] = int(label_id)

        return labelmap

    @classmethod
    def _parse_tfrecord_file(cls, filepath, subset_name, images_dir):
        """Reads all records of *filepath*.

        Returns (items, labels):
            items: OrderedDict of item id -> DatasetItem
            labels: OrderedDict of label name -> 0-based label index
        """
        tf = _import_tf()

        dataset = tf.data.TFRecordDataset(filepath)
        features = {
            'image/filename': tf.io.FixedLenFeature([], tf.string),
            'image/source_id': tf.io.FixedLenFeature([], tf.string),
            'image/height': tf.io.FixedLenFeature([], tf.int64),
            'image/width': tf.io.FixedLenFeature([], tf.int64),
            'image/encoded': tf.io.FixedLenFeature([], tf.string),
            'image/format': tf.io.FixedLenFeature([], tf.string),
            # Object boxes and classes.
            'image/object/bbox/xmin': tf.io.VarLenFeature(tf.float32),
            'image/object/bbox/xmax': tf.io.VarLenFeature(tf.float32),
            'image/object/bbox/ymin': tf.io.VarLenFeature(tf.float32),
            'image/object/bbox/ymax': tf.io.VarLenFeature(tf.float32),
            'image/object/class/label': tf.io.VarLenFeature(tf.int64),
            'image/object/class/text': tf.io.VarLenFeature(tf.string),
            'image/object/mask': tf.io.VarLenFeature(tf.string),
        }

        # Seed the label mapping from label_map.pbtxt when it lies next to
        # the annotation file. File ids are 1-based, ours are 0-based.
        dataset_labels = OrderedDict()
        labelmap_path = osp.join(osp.dirname(filepath),
            DetectionApiPath.LABELMAP_FILE)
        if osp.exists(labelmap_path):
            with open(labelmap_path, 'r', encoding='utf-8') as f:
                labelmap_text = f.read()
            dataset_labels.update({ label: id - 1
                for label, id in cls._parse_labelmap(labelmap_text).items()
            })

        dataset_items = OrderedDict()

        for record in dataset:
            parsed_record = tf.io.parse_single_example(record, features)
            frame_id = parsed_record['image/source_id'].numpy().decode('utf-8')
            frame_filename = \
                parsed_record['image/filename'].numpy().decode('utf-8')
            frame_height = tf.cast(
                parsed_record['image/height'], tf.int64).numpy().item()
            frame_width = tf.cast(
                parsed_record['image/width'], tf.int64).numpy().item()
            frame_image = parsed_record['image/encoded'].numpy()
            frame_format = parsed_record['image/format'].numpy().decode('utf-8')
            xmins = tf.sparse.to_dense(
                parsed_record['image/object/bbox/xmin']).numpy()
            ymins = tf.sparse.to_dense(
                parsed_record['image/object/bbox/ymin']).numpy()
            xmaxs = tf.sparse.to_dense(
                parsed_record['image/object/bbox/xmax']).numpy()
            ymaxs = tf.sparse.to_dense(
                parsed_record['image/object/bbox/ymax']).numpy()
            label_ids = tf.sparse.to_dense(
                parsed_record['image/object/class/label']).numpy()
            labels = tf.sparse.to_dense(
                parsed_record['image/object/class/text'],
                default_value=b'').numpy()

            # Collect labels that were not listed in the labelmap file.
            for label, label_id in zip(labels, label_ids):
                label = label.decode('utf-8')
                if not label:
                    continue
                if label_id <= 0:
                    # 0 / negative ids mean "unknown class".
                    continue
                if label in dataset_labels:
                    continue
                dataset_labels[label] = label_id - 1

            item_id = frame_id
            if not item_id:
                # Fall back to the file name (without extension) as the id.
                item_id = osp.splitext(frame_filename)[0]

            annotations = []
            # NOTE(review): np.dstack promotes the mixed bytes/float columns
            # to a common dtype; confirm the coordinate columns still behave
            # numerically for non-empty records.
            for index, shape in enumerate(
                    np.dstack((labels, xmins, ymins, xmaxs, ymaxs))[0]):
                label = shape[0].decode('utf-8')
                # Coordinates are normalized; convert to pixels and clamp
                # to the image bounds.
                x = clamp(shape[1] * frame_width, 0, frame_width)
                y = clamp(shape[2] * frame_height, 0, frame_height)
                w = clamp(shape[3] * frame_width, 0, frame_width) - x
                h = clamp(shape[4] * frame_height, 0, frame_height) - y
                annotations.append(BboxObject(x, y, w, h,
                    label=dataset_labels.get(label, None), id=index
                ))

            # Prefer the embedded image bytes; fall back to a file on disk.
            image = None
            if image is None and frame_image and frame_format:
                image = lazy_image(frame_image, loader=decode_image)
            if image is None and frame_filename and images_dir:
                image_path = osp.join(images_dir, frame_filename)
                if osp.exists(image_path):
                    image = lazy_image(image_path)

            dataset_items[item_id] = DatasetItem(id=item_id, subset=subset_name,
                image=image, annotations=annotations)

        return dataset_items, dataset_labels

    def categories(self):
        return self._categories

    def __iter__(self):
        # Iterates all items across all subsets.
        for subset in self._subsets.values():
            for item in subset:
                yield item

    def __len__(self):
        length = 0
        for subset in self._subsets.values():
            length += len(subset)
        return length

    def subsets(self):
        # Returns subset names (None for the default subset).
        return list(self._subsets)

    def get_subset(self, name):
        return self._subsets[name]

    def get(self, item_id, subset=None, path=None):
        """Looks up an item by id in the given subset; None if absent.

        *path* is part of the common Extractor interface and is not
        supported here — any non-None path yields None.
        """
        if path is not None:
            return None
        return self.get_subset(subset).items.get(item_id, None)
|
||||||
@ -0,0 +1,13 @@
|
|||||||
|
|
||||||
|
# Copyright (C) 2019 Intel Corporation
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
class DetectionApiPath:
    """On-disk layout constants for the TF Detection API dataset format."""

    # Dataset sub-directories.
    IMAGES_DIR = 'images'
    ANNOTATIONS_DIR = 'annotations'

    # Format used for images embedded in .tfrecord files.
    IMAGE_EXT = '.jpg'
    IMAGE_FORMAT = 'jpeg'

    # Class-id mapping file, in TF Object Detection API pbtxt syntax.
    LABELMAP_FILE = 'label_map.pbtxt'
|
||||||
@ -0,0 +1,35 @@
|
|||||||
|
|
||||||
|
# Copyright (C) 2019 Intel Corporation
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
from glob import glob
|
||||||
|
import os.path as osp
|
||||||
|
|
||||||
|
|
||||||
|
class DetectionApiImporter:
    """Builds a Datumaro project from a directory of .tfrecord files.

    Each .tfrecord file becomes one source (one subset) read by the
    'tf_detection_api' extractor.
    """

    EXTRACTOR_NAME = 'tf_detection_api'

    def __call__(self, path):
        """Returns a Project with one source per .tfrecord file in *path*.

        Raises Exception when no .tfrecord files are found.
        """
        from datumaro.components.project import Project # cyclic import
        project = Project()

        subset_paths = glob(osp.join(path, '*.tfrecord'))

        for subset_path in subset_paths:
            if not osp.isfile(subset_path):
                continue

            # Source (subset) name is the file name without extension.
            subset_name = osp.splitext(osp.basename(subset_path))[0]

            project.add_source(subset_name, {
                'url': subset_path,
                'format': self.EXTRACTOR_NAME,
            })

        if len(project.config.sources) == 0:
            raise Exception(
                "Failed to find 'tf_detection_api' dataset at '%s'" % path)

        return project
|
||||||
|
|
||||||
@ -0,0 +1,38 @@
|
|||||||
|
|
||||||
|
# Copyright (C) 2019 Intel Corporation
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
def import_tf():
    """Imports TensorFlow lazily with quiet logging and eager execution.

    Returns the tensorflow module. The import is deferred so the rest of
    the package stays usable without TF installed; repeated calls reuse
    the already-imported module.
    """
    import sys

    # Already imported — return as-is without touching logging/eager
    # settings again.
    tf = sys.modules.get('tensorflow', None)
    if tf is not None:
        return tf

    # Reduce output noise, https://stackoverflow.com/questions/38073432/how-to-suppress-verbose-tensorflow-logging
    import os
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

    import tensorflow as tf

    # The available logging API differs between TF versions — try both.
    try:
        tf.get_logger().setLevel('WARNING')
    except AttributeError:
        pass
    try:
        tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.WARN)
    except AttributeError:
        pass

    # Enable eager execution in early versions to unlock dataset operations
    try:
        tf.compat.v1.enable_eager_execution()
    except AttributeError:
        pass
    try:
        tf.enable_eager_execution()
    except AttributeError:
        pass

    return tf
|
||||||
@ -0,0 +1,151 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from unittest import TestCase
|
||||||
|
|
||||||
|
from datumaro.components.project import Project
|
||||||
|
from datumaro.components.extractor import (Extractor, DatasetItem,
|
||||||
|
AnnotationType, BboxObject, LabelCategories
|
||||||
|
)
|
||||||
|
from datumaro.components.extractors.tfrecord import (
|
||||||
|
DetectionApiExtractor,
|
||||||
|
)
|
||||||
|
from datumaro.components.converters.tfrecord import (
|
||||||
|
DetectionApiConverter,
|
||||||
|
)
|
||||||
|
from datumaro.util import find
|
||||||
|
from datumaro.util.test_utils import TestDir
|
||||||
|
|
||||||
|
|
||||||
|
class TfrecordConverterTest(TestCase):
    """Round-trip tests for the TF Detection API converter/extractor pair."""

    def _test_can_save_and_load(self, source_dataset, converter, test_dir,
            importer_params=None):
        """Exports *source_dataset*, re-imports it and compares item-by-item.

        Subset names, dataset length, and each annotation (matched by id,
        type and group) must survive the round trip.
        """
        converter(source_dataset, test_dir.path)

        if not importer_params:
            importer_params = {}
        project = Project.import_from(test_dir.path, 'tf_detection_api',
            **importer_params)
        parsed_dataset = project.make_dataset()

        self.assertListEqual(
            sorted(source_dataset.subsets()),
            sorted(parsed_dataset.subsets()),
        )

        self.assertEqual(len(source_dataset), len(parsed_dataset))

        for item_a in source_dataset:
            item_b = find(parsed_dataset, lambda x: x.id == item_a.id)
            self.assertFalse(item_b is None)
            self.assertEqual(len(item_a.annotations), len(item_b.annotations))
            for ann_a in item_a.annotations:
                # Match annotations on id + type + group, then require full
                # equality (including coordinates and label).
                ann_b = find(item_b.annotations, lambda x: \
                    x.id == ann_a.id and \
                    x.type == ann_a.type and x.group == ann_a.group)
                self.assertEqual(ann_a, ann_b, 'id: ' + str(ann_a.id))

    def test_can_save_bboxes(self):
        """Round trip with named subsets, labeled and unlabeled boxes."""
        class TestExtractor(Extractor):
            def __iter__(self):
                items = [
                    DatasetItem(id=1, subset='train',
                        image=np.ones((16, 16, 3)),
                        annotations=[
                            BboxObject(0, 4, 4, 8, label=2, id=0),
                            BboxObject(0, 4, 4, 4, label=3, id=1),
                            # A box with no label must also survive.
                            BboxObject(2, 4, 4, 4, id=2),
                        ]
                    ),

                    DatasetItem(id=2, subset='val',
                        image=np.ones((8, 8, 3)),
                        annotations=[
                            BboxObject(1, 2, 4, 2, label=3, id=0),
                        ]
                    ),

                    # An item with no annotations at all.
                    DatasetItem(id=3, subset='test',
                        image=np.ones((5, 4, 3)) * 3,
                    ),
                ]

                for item in items:
                    yield item

            def categories(self):
                label_cat = LabelCategories()
                for label in range(10):
                    label_cat.add('label_' + str(label))
                return {
                    AnnotationType.label: label_cat,
                }

        with TestDir() as test_dir:
            self._test_can_save_and_load(
                TestExtractor(), DetectionApiConverter(save_images=True),
                test_dir)

    def test_can_save_dataset_with_no_subsets(self):
        """Round trip when items are not assigned to any subset."""
        class TestExtractor(Extractor):
            def __iter__(self):
                items = [
                    DatasetItem(id=1,
                        image=np.ones((16, 16, 3)),
                        annotations=[
                            BboxObject(2, 1, 4, 4, label=2, id=0),
                            BboxObject(4, 2, 8, 4, label=3, id=1),
                        ]
                    ),

                    DatasetItem(id=2,
                        image=np.ones((8, 8, 3)) * 2,
                        annotations=[
                            BboxObject(4, 4, 4, 4, label=3, id=0),
                        ]
                    ),

                    DatasetItem(id=3,
                        image=np.ones((8, 4, 3)) * 3,
                    ),
                ]

                for item in items:
                    yield item

            def categories(self):
                label_cat = LabelCategories()
                for label in range(10):
                    label_cat.add('label_' + str(label))
                return {
                    AnnotationType.label: label_cat,
                }

        with TestDir() as test_dir:
            self._test_can_save_and_load(
                TestExtractor(), DetectionApiConverter(), test_dir)

    def test_labelmap_parsing(self):
        """_parse_labelmap must accept varied field order and spacing."""
        text = """
            {
            id: 4
            name: 'qw1'
            }
            {
            id: 5 name: 'qw2'
            }

            {
            name: 'qw3'
            id: 6
            }
            {name:'qw4' id:7}
        """
        expected = {
            'qw1': 4,
            'qw2': 5,
            'qw3': 6,
            'qw4': 7,
        }
        parsed = DetectionApiExtractor._parse_labelmap(text)

        self.assertEqual(expected, parsed)
|
||||||
Loading…
Reference in New Issue