[Datumaro] Added tf detection api tfrecord import and export (#894)
* Added tf detection api tfrecord import and export
* Added export button in dashboard
* Add tf to requirements
* Extend test
* Add tf dependency
* Require images in tfrecord export
* Add video task case handling
* Maintain image order in CVAT export
* Fix Task image id-path conversions
* Update tfrecord tests
* Extend image utilities
* Update tfrecord format
* Fix image loading bug
* Add some logs
* Add 'copy' option to project import command
* Reduce default cache size
parent 310ee37b53
commit 0f5cf2064d
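For context, a minimal round-trip sketch of the new format, using only the API added in this commit (the directory paths are hypothetical, not part of the change):

    from datumaro.components.project import Project
    from datumaro.components.converters.tfrecord import DetectionApiConverter

    # Import: registers each *.tfrecord found in the directory as a source
    project = Project.import_from('/path/to/tfrecord_dataset', 'tf_detection_api')
    dataset = project.make_dataset()

    # Export: writes one <subset>.tfrecord per subset plus label_map.pbtxt
    DetectionApiConverter(save_images=True)(dataset, save_dir='/path/to/export')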
@@ -0,0 +1,25 @@
# Copyright (C) 2018 Intel Corporation
#
# SPDX-License-Identifier: MIT

import os.path as osp

from django.test import TestCase
from cvat.apps.engine.models import Task


class TaskModelTest(TestCase):
    def test_frame_id_path_conversions(self):
        task_id = 1
        task = Task(task_id)

        for i in [10 ** p for p in range(6)]:
            src_path_expected = osp.join(
                str(i // 10000), str(i // 100), '%s.jpg' % i)
            src_path = task.get_frame_path(i)

            dst_frame = task.get_image_frame(src_path)

            self.assertTrue(src_path.endswith(src_path_expected),
                '%s vs. %s' % (src_path, src_path_expected))
            self.assertEqual(i, dst_frame)
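As an aside, a small sketch of the two-level frame path layout the test above asserts (illustrative only, not part of the commit):

    import os.path as osp

    def expected_frame_suffix(i):
        # mirrors the layout checked in test_frame_id_path_conversions
        return osp.join(str(i // 10000), str(i // 100), '%s.jpg' % i)

    assert expected_frame_suffix(123456) == osp.join('12', '1234', '123456.jpg')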
@@ -0,0 +1,146 @@
# Copyright (C) 2019 Intel Corporation
#
# SPDX-License-Identifier: MIT

import codecs
from collections import OrderedDict
import os
import os.path as osp
import string

from datumaro.components.extractor import AnnotationType, DEFAULT_SUBSET_NAME
from datumaro.components.formats.tfrecord import DetectionApiPath
from datumaro.util.image import encode_image
from datumaro.util.tf_util import import_tf as _import_tf


# we need it to filter out non-ASCII characters, otherwise training will crash
_printable = set(string.printable)
def _make_printable(s):
    return ''.join(filter(lambda x: x in _printable, s))

def _make_tf_example(item, get_label_id, get_label, save_images=False):
    tf = _import_tf()

    def int64_feature(value):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

    def int64_list_feature(value):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=value))

    def bytes_feature(value):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

    def bytes_list_feature(value):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))

    def float_list_feature(value):
        return tf.train.Feature(float_list=tf.train.FloatList(value=value))


    features = {
        'image/source_id': bytes_feature(str(item.id).encode('utf-8')),
        'image/filename': bytes_feature(
            ('%s%s' % (item.id, DetectionApiPath.IMAGE_EXT)).encode('utf-8')),
    }

    if not item.has_image:
        raise Exception(
            "Failed to export dataset item '%s': item has no image" % item.id)
    height, width, _ = item.image.shape

    features.update({
        'image/height': int64_feature(height),
        'image/width': int64_feature(width),
    })

    if save_images and item.has_image:
        fmt = DetectionApiPath.IMAGE_FORMAT
        buffer = encode_image(item.image, DetectionApiPath.IMAGE_EXT)

        features.update({
            'image/encoded': bytes_feature(buffer),
            'image/format': bytes_feature(fmt.encode('utf-8')),
        })

    xmins = [] # List of normalized left x coordinates in bounding box (1 per box)
    xmaxs = [] # List of normalized right x coordinates in bounding box (1 per box)
    ymins = [] # List of normalized top y coordinates in bounding box (1 per box)
    ymaxs = [] # List of normalized bottom y coordinates in bounding box (1 per box)
    classes_text = [] # List of string class name of bounding box (1 per box)
    classes = [] # List of integer class id of bounding box (1 per box)

    boxes = [ann for ann in item.annotations if ann.type is AnnotationType.bbox]
    for box in boxes:
        box_label = _make_printable(get_label(box.label))

        xmins.append(box.points[0] / width)
        xmaxs.append(box.points[2] / width)
        ymins.append(box.points[1] / height)
        ymaxs.append(box.points[3] / height)
        classes_text.append(box_label.encode('utf-8'))
        classes.append(get_label_id(box.label))

    if boxes:
        features.update({
            'image/object/bbox/xmin': float_list_feature(xmins),
            'image/object/bbox/xmax': float_list_feature(xmaxs),
            'image/object/bbox/ymin': float_list_feature(ymins),
            'image/object/bbox/ymax': float_list_feature(ymaxs),
            'image/object/class/text': bytes_list_feature(classes_text),
            'image/object/class/label': int64_list_feature(classes),
        })

    tf_example = tf.train.Example(
        features=tf.train.Features(feature=features))

    return tf_example

class DetectionApiConverter:
    def __init__(self, save_images=True):
        self.save_images = save_images

    def __call__(self, extractor, save_dir):
        tf = _import_tf()

        os.makedirs(save_dir, exist_ok=True)

        subsets = extractor.subsets()
        if len(subsets) == 0:
            subsets = [ None ]

        for subset_name in subsets:
            if subset_name:
                subset = extractor.get_subset(subset_name)
            else:
                subset_name = DEFAULT_SUBSET_NAME
                subset = extractor

            label_categories = subset.categories()[AnnotationType.label]
            get_label = lambda label_id: label_categories.items[label_id].name \
                if label_id is not None else ''
            label_ids = OrderedDict((label.name, 1 + idx)
                for idx, label in enumerate(label_categories.items))
            map_label_id = lambda label_id: label_ids.get(get_label(label_id), 0)

            labelmap_path = osp.join(save_dir, DetectionApiPath.LABELMAP_FILE)
            with codecs.open(labelmap_path, 'w', encoding='utf8') as f:
                for label, idx in label_ids.items():
                    f.write(
                        'item {\n' +
                        ('\tid: %s\n' % (idx)) +
                        ("\tname: '%s'\n" % (label)) +
                        '}\n\n'
                    )

            anno_path = osp.join(save_dir, '%s.tfrecord' % (subset_name))
            with tf.io.TFRecordWriter(anno_path) as writer:
                for item in subset:
                    tf_example = _make_tf_example(
                        item,
                        get_label=get_label,
                        get_label_id=map_label_id,
                        save_images=self.save_images,
                    )
                    writer.write(tf_example.SerializeToString())
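For reference, label ids are assigned as 1 + category index (0 is reserved for boxes without a label), so with two hypothetical labels 'cat' and 'dog' the label_map.pbtxt written above would look like:

    item {
    	id: 1
    	name: 'cat'
    }

    item {
    	id: 2
    	name: 'dog'
    }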
@@ -0,0 +1,206 @@
# Copyright (C) 2019 Intel Corporation
#
# SPDX-License-Identifier: MIT

from collections import OrderedDict
import numpy as np
import os.path as osp
import re

from datumaro.components.extractor import AnnotationType, DEFAULT_SUBSET_NAME, \
    LabelCategories, BboxObject, DatasetItem, Extractor
from datumaro.components.formats.tfrecord import DetectionApiPath
from datumaro.util.image import lazy_image, decode_image
from datumaro.util.tf_util import import_tf as _import_tf


def clamp(value, _min, _max):
    return max(min(_max, value), _min)

class DetectionApiExtractor(Extractor):
    class Subset(Extractor):
        def __init__(self, name, parent):
            super().__init__()
            self._name = name
            self._parent = parent
            self.items = OrderedDict()

        def __iter__(self):
            for item in self.items.values():
                yield item

        def __len__(self):
            return len(self.items)

        def categories(self):
            return self._parent.categories()

    def __init__(self, path, images_dir=None):
        super().__init__()

        root_dir = osp.dirname(osp.abspath(path))
        if osp.basename(root_dir) == DetectionApiPath.ANNOTATIONS_DIR:
            root_dir = osp.dirname(root_dir)
            images_dir = osp.join(root_dir, DetectionApiPath.IMAGES_DIR)
            if not osp.isdir(images_dir):
                images_dir = None
        self._images_dir = images_dir

        self._subsets = {}

        subset_name = osp.splitext(osp.basename(path))[0]
        if subset_name == DEFAULT_SUBSET_NAME:
            subset_name = None
        subset = DetectionApiExtractor.Subset(subset_name, self)
        items, labels = self._parse_tfrecord_file(path, subset_name, images_dir)
        subset.items = items
        self._subsets[subset_name] = subset

        label_categories = LabelCategories()
        labels = sorted(labels.items(), key=lambda item: item[1])
        for label, _ in labels:
            label_categories.add(label)
        self._categories = {
            AnnotationType.label: label_categories
        }

    @classmethod
    def _parse_labelmap(cls, text):
        id_pattern = r'(?:id\s*:\s*(?P<id>\d+))'
        name_pattern = r'(?:name\s*:\s*[\'\"](?P<name>.*?)[\'\"])'
        entry_pattern = r'(\{(?:[\s\n]*(?:%(id)s|%(name)s)[\s\n]*){2}\})+' % \
            {'id': id_pattern, 'name': name_pattern}
        matches = re.finditer(entry_pattern, text)

        labelmap = {}
        for match in matches:
            label_id = match.group('id')
            label_name = match.group('name')
            if label_id is not None and label_name is not None:
                labelmap[label_name] = int(label_id)

        return labelmap

    @classmethod
    def _parse_tfrecord_file(cls, filepath, subset_name, images_dir):
        tf = _import_tf()

        dataset = tf.data.TFRecordDataset(filepath)
        features = {
            'image/filename': tf.io.FixedLenFeature([], tf.string),
            'image/source_id': tf.io.FixedLenFeature([], tf.string),
            'image/height': tf.io.FixedLenFeature([], tf.int64),
            'image/width': tf.io.FixedLenFeature([], tf.int64),
            'image/encoded': tf.io.FixedLenFeature([], tf.string),
            'image/format': tf.io.FixedLenFeature([], tf.string),
            # Object boxes and classes.
            'image/object/bbox/xmin': tf.io.VarLenFeature(tf.float32),
            'image/object/bbox/xmax': tf.io.VarLenFeature(tf.float32),
            'image/object/bbox/ymin': tf.io.VarLenFeature(tf.float32),
            'image/object/bbox/ymax': tf.io.VarLenFeature(tf.float32),
            'image/object/class/label': tf.io.VarLenFeature(tf.int64),
            'image/object/class/text': tf.io.VarLenFeature(tf.string),
            'image/object/mask': tf.io.VarLenFeature(tf.string),
        }

        dataset_labels = OrderedDict()
        labelmap_path = osp.join(osp.dirname(filepath),
            DetectionApiPath.LABELMAP_FILE)
        if osp.exists(labelmap_path):
            with open(labelmap_path, 'r', encoding='utf-8') as f:
                labelmap_text = f.read()
            dataset_labels.update({ label: id - 1
                for label, id in cls._parse_labelmap(labelmap_text).items()
            })

        dataset_items = OrderedDict()

        for record in dataset:
            parsed_record = tf.io.parse_single_example(record, features)
            frame_id = parsed_record['image/source_id'].numpy().decode('utf-8')
            frame_filename = \
                parsed_record['image/filename'].numpy().decode('utf-8')
            frame_height = tf.cast(
                parsed_record['image/height'], tf.int64).numpy().item()
            frame_width = tf.cast(
                parsed_record['image/width'], tf.int64).numpy().item()
            frame_image = parsed_record['image/encoded'].numpy()
            frame_format = parsed_record['image/format'].numpy().decode('utf-8')
            xmins = tf.sparse.to_dense(
                parsed_record['image/object/bbox/xmin']).numpy()
            ymins = tf.sparse.to_dense(
                parsed_record['image/object/bbox/ymin']).numpy()
            xmaxs = tf.sparse.to_dense(
                parsed_record['image/object/bbox/xmax']).numpy()
            ymaxs = tf.sparse.to_dense(
                parsed_record['image/object/bbox/ymax']).numpy()
            label_ids = tf.sparse.to_dense(
                parsed_record['image/object/class/label']).numpy()
            labels = tf.sparse.to_dense(
                parsed_record['image/object/class/text'],
                default_value=b'').numpy()

            for label, label_id in zip(labels, label_ids):
                label = label.decode('utf-8')
                if not label:
                    continue
                if label_id <= 0:
                    continue
                if label in dataset_labels:
                    continue
                dataset_labels[label] = label_id - 1

            item_id = frame_id
            if not item_id:
                item_id = osp.splitext(frame_filename)[0]

            annotations = []
            for index, shape in enumerate(
                    np.dstack((labels, xmins, ymins, xmaxs, ymaxs))[0]):
                label = shape[0].decode('utf-8')
                x = clamp(shape[1] * frame_width, 0, frame_width)
                y = clamp(shape[2] * frame_height, 0, frame_height)
                w = clamp(shape[3] * frame_width, 0, frame_width) - x
                h = clamp(shape[4] * frame_height, 0, frame_height) - y
                annotations.append(BboxObject(x, y, w, h,
                    label=dataset_labels.get(label, None), id=index
                ))

            image = None
            if image is None and frame_image and frame_format:
                image = lazy_image(frame_image, loader=decode_image)
            if image is None and frame_filename and images_dir:
                image_path = osp.join(images_dir, frame_filename)
                if osp.exists(image_path):
                    image = lazy_image(image_path)

            dataset_items[item_id] = DatasetItem(id=item_id, subset=subset_name,
                image=image, annotations=annotations)

        return dataset_items, dataset_labels

    def categories(self):
        return self._categories

    def __iter__(self):
        for subset in self._subsets.values():
            for item in subset:
                yield item

    def __len__(self):
        length = 0
        for subset in self._subsets.values():
            length += len(subset)
        return length

    def subsets(self):
        return list(self._subsets)

    def get_subset(self, name):
        return self._subsets[name]

    def get(self, item_id, subset=None, path=None):
        if path is not None:
            return None
        return self.get_subset(subset).items.get(item_id, None)
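A minimal sketch of using the extractor directly on a single file (the path is hypothetical; sibling images are only resolved automatically when the file sits in an 'annotations' directory next to an 'images' directory, as handled in __init__ above):

    from datumaro.components.extractors.tfrecord import DetectionApiExtractor

    extractor = DetectionApiExtractor('/path/to/annotations/train.tfrecord')
    for item in extractor:
        print(item.id, len(item.annotations))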
@@ -0,0 +1,13 @@
# Copyright (C) 2019 Intel Corporation
#
# SPDX-License-Identifier: MIT

class DetectionApiPath:
    IMAGES_DIR = 'images'
    ANNOTATIONS_DIR = 'annotations'

    IMAGE_EXT = '.jpg'
    IMAGE_FORMAT = 'jpeg'

    LABELMAP_FILE = 'label_map.pbtxt'
@@ -0,0 +1,35 @@
# Copyright (C) 2019 Intel Corporation
#
# SPDX-License-Identifier: MIT

from glob import glob
import os.path as osp


class DetectionApiImporter:
    EXTRACTOR_NAME = 'tf_detection_api'

    def __call__(self, path):
        from datumaro.components.project import Project # cyclic import
        project = Project()

        subset_paths = glob(osp.join(path, '*.tfrecord'))

        for subset_path in subset_paths:
            if not osp.isfile(subset_path):
                continue

            subset_name = osp.splitext(osp.basename(subset_path))[0]

            project.add_source(subset_name, {
                'url': subset_path,
                'format': self.EXTRACTOR_NAME,
            })

        if len(project.config.sources) == 0:
            raise Exception(
                "Failed to find 'tf_detection_api' dataset at '%s'" % path)

        return project
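A sketch of the importer used in isolation (directory path hypothetical; the module path is assumed to mirror the extractors/converters layout above):

    from datumaro.components.importers.tfrecord import DetectionApiImporter

    project = DetectionApiImporter()('/path/to/exported')
    # One source is registered per *.tfrecord file, e.g. ['train', 'val']
    print(list(project.config.sources))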
@@ -0,0 +1,38 @@
# Copyright (C) 2019 Intel Corporation
#
# SPDX-License-Identifier: MIT

def import_tf():
    import sys

    tf = sys.modules.get('tensorflow', None)
    if tf is not None:
        return tf

    # Reduce output noise, https://stackoverflow.com/questions/38073432/how-to-suppress-verbose-tensorflow-logging
    import os
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

    import tensorflow as tf

    try:
        tf.get_logger().setLevel('WARNING')
    except AttributeError:
        pass
    try:
        tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.WARN)
    except AttributeError:
        pass

    # Enable eager execution in early versions to unlock dataset operations
    try:
        tf.compat.v1.enable_eager_execution()
    except AttributeError:
        pass
    try:
        tf.enable_eager_execution()
    except AttributeError:
        pass

    return tf
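import_tf is meant as a lazy, idempotent entry point; a sketch of the intended use:

    from datumaro.util.tf_util import import_tf

    tf = import_tf()          # first call: quiets logging, enables eager mode
    assert import_tf() is tf  # later calls return the module from sys.modules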
@@ -0,0 +1,151 @@
import numpy as np

from unittest import TestCase

from datumaro.components.project import Project
from datumaro.components.extractor import (Extractor, DatasetItem,
    AnnotationType, BboxObject, LabelCategories
)
from datumaro.components.extractors.tfrecord import (
    DetectionApiExtractor,
)
from datumaro.components.converters.tfrecord import (
    DetectionApiConverter,
)
from datumaro.util import find
from datumaro.util.test_utils import TestDir


class TfrecordConverterTest(TestCase):
    def _test_can_save_and_load(self, source_dataset, converter, test_dir,
            importer_params=None):
        converter(source_dataset, test_dir.path)

        if not importer_params:
            importer_params = {}
        project = Project.import_from(test_dir.path, 'tf_detection_api',
            **importer_params)
        parsed_dataset = project.make_dataset()

        self.assertListEqual(
            sorted(source_dataset.subsets()),
            sorted(parsed_dataset.subsets()),
        )

        self.assertEqual(len(source_dataset), len(parsed_dataset))

        for item_a in source_dataset:
            item_b = find(parsed_dataset, lambda x: x.id == item_a.id)
            self.assertFalse(item_b is None)
            self.assertEqual(len(item_a.annotations), len(item_b.annotations))
            for ann_a in item_a.annotations:
                ann_b = find(item_b.annotations, lambda x: \
                    x.id == ann_a.id and \
                    x.type == ann_a.type and x.group == ann_a.group)
                self.assertEqual(ann_a, ann_b, 'id: ' + str(ann_a.id))

    def test_can_save_bboxes(self):
        class TestExtractor(Extractor):
            def __iter__(self):
                items = [
                    DatasetItem(id=1, subset='train',
                        image=np.ones((16, 16, 3)),
                        annotations=[
                            BboxObject(0, 4, 4, 8, label=2, id=0),
                            BboxObject(0, 4, 4, 4, label=3, id=1),
                            BboxObject(2, 4, 4, 4, id=2),
                        ]
                    ),

                    DatasetItem(id=2, subset='val',
                        image=np.ones((8, 8, 3)),
                        annotations=[
                            BboxObject(1, 2, 4, 2, label=3, id=0),
                        ]
                    ),

                    DatasetItem(id=3, subset='test',
                        image=np.ones((5, 4, 3)) * 3,
                    ),
                ]

                for item in items:
                    yield item

            def categories(self):
                label_cat = LabelCategories()
                for label in range(10):
                    label_cat.add('label_' + str(label))
                return {
                    AnnotationType.label: label_cat,
                }

        with TestDir() as test_dir:
            self._test_can_save_and_load(
                TestExtractor(), DetectionApiConverter(save_images=True),
                test_dir)

    def test_can_save_dataset_with_no_subsets(self):
        class TestExtractor(Extractor):
            def __iter__(self):
                items = [
                    DatasetItem(id=1,
                        image=np.ones((16, 16, 3)),
                        annotations=[
                            BboxObject(2, 1, 4, 4, label=2, id=0),
                            BboxObject(4, 2, 8, 4, label=3, id=1),
                        ]
                    ),

                    DatasetItem(id=2,
                        image=np.ones((8, 8, 3)) * 2,
                        annotations=[
                            BboxObject(4, 4, 4, 4, label=3, id=0),
                        ]
                    ),

                    DatasetItem(id=3,
                        image=np.ones((8, 4, 3)) * 3,
                    ),
                ]

                for item in items:
                    yield item

            def categories(self):
                label_cat = LabelCategories()
                for label in range(10):
                    label_cat.add('label_' + str(label))
                return {
                    AnnotationType.label: label_cat,
                }

        with TestDir() as test_dir:
            self._test_can_save_and_load(
                TestExtractor(), DetectionApiConverter(), test_dir)

    def test_labelmap_parsing(self):
        text = """
            {
                id: 4
                name: 'qw1'
            }
            {
                id: 5 name: 'qw2'
            }

            {
                name: 'qw3'
                id: 6
            }
            {name:'qw4' id:7}
        """
        expected = {
            'qw1': 4,
            'qw2': 5,
            'qw3': 6,
            'qw4': 7,
        }
        parsed = DetectionApiExtractor._parse_labelmap(text)

        self.assertEqual(expected, parsed)