[Datumaro] Add DatasetItem attributes (#1639)

* Add DatasetItem attributes

* Update tests

* Update datumaro format
main
zhiltsov-max 6 years ago committed by GitHub
parent b2503c620f
commit ba309c882f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -576,7 +576,7 @@ class Caption(Annotation):
class DatasetItem:
# pylint: disable=redefined-builtin
def __init__(self, id=None, annotations=None,
subset=None, path=None, image=None):
subset=None, path=None, image=None, attributes=None):
assert id is not None
self._id = str(id)
@ -604,6 +604,12 @@ class DatasetItem:
image = Image(path=image)
assert image is None or isinstance(image, Image)
self._image = image
if attributes is None:
attributes = {}
else:
attributes = dict(attributes)
self._attributes = attributes
# pylint: enable=redefined-builtin
@property
@ -630,6 +636,10 @@ class DatasetItem:
def has_image(self):
return self._image is not None
@property
def attributes(self):
return self._attributes
def __eq__(self, other):
if not isinstance(other, __class__):
return False
@ -638,10 +648,12 @@ class DatasetItem:
(self.subset == other.subset) and \
(self.path == other.path) and \
(self.annotations == other.annotations) and \
(self.image == other.image)
(self.image == other.image) and \
(self.attributes == other.attributes)
def wrap(item, **kwargs):
expected_args = {'id', 'annotations', 'subset', 'path', 'image'}
expected_args = {'id', 'annotations', 'subset',
'path', 'image', 'attributes'}
for k in expected_args:
if k not in kwargs:
kwargs[k] = getattr(item, k)

@ -49,6 +49,8 @@ class _SubsetWriter:
'id': item.id,
'annotations': annotations,
}
if item.attributes:
item_desc['attr'] = item.attributes
if item.path:
item_desc['path'] = item.path
if item.has_image:

@ -91,13 +91,15 @@ class DatumaroExtractor(SourceExtractor):
annotations = self._load_annotations(item_desc)
item = DatasetItem(id=item_id, subset=self._subset,
annotations=annotations, image=image)
annotations=annotations, image=image,
attributes=item_desc.get('attr'))
items.append(item)
return items
def _load_annotations(self, item):
@staticmethod
def _load_annotations(item):
parsed = item['annotations']
loaded = []

@ -97,4 +97,21 @@ def compare_datasets(test, expected, actual):
ann_b = find(ann_b_matches, lambda x: x == ann_a)
test.assertEqual(ann_a, ann_b, 'ann: %s' % ann_to_str(ann_a))
item_b.annotations.remove(ann_b) # avoid repeats
item_b.annotations.remove(ann_b) # avoid repeats
def compare_datasets_strict(test, expected, actual):
# Compares datasets for strong equality
test.assertEqual(expected.categories(), actual.categories())
test.assertListEqual(sorted(expected.subsets()), sorted(actual.subsets()))
test.assertEqual(len(expected), len(actual))
for subset_name in expected.subsets():
e_subset = expected.get_subset(subset_name)
a_subset = actual.get_subset(subset_name)
test.assertEqual(len(e_subset), len(a_subset))
for idx, (item_a, item_b) in enumerate(zip(e_subset, a_subset)):
test.assertEqual(item_a, item_b,
'%s:\n%s\nvs.\n%s\n' % \
(idx, item_to_str(item_a), item_to_str(item_b)))

@ -12,10 +12,24 @@ from datumaro.plugins.datumaro_format.importer import DatumaroImporter
from datumaro.plugins.datumaro_format.converter import DatumaroConverter
from datumaro.util.mask_tools import generate_colormap
from datumaro.util.image import Image
from datumaro.util.test_utils import TestDir, item_to_str
from datumaro.util.test_utils import TestDir, compare_datasets_strict
class DatumaroConverterTest(TestCase):
def _test_save_and_load(self, source_dataset, converter, test_dir,
target_dataset=None, importer_args=None):
converter(source_dataset, test_dir)
if importer_args is None:
importer_args = {}
parsed_dataset = Project.import_from(
test_dir, 'datumaro', **importer_args).make_dataset()
if target_dataset is None:
target_dataset = source_dataset
compare_datasets_strict(self,
expected=target_dataset, actual=parsed_dataset)
class TestExtractor(Extractor):
def __iter__(self):
return iter([
@ -47,7 +61,8 @@ class DatumaroConverterTest(TestCase):
Polygon([1, 2, 3, 4, 5, 6, 7, 8], id=12, z_order=4),
]),
DatasetItem(id=42, subset='test'),
DatasetItem(id=42, subset='test',
attributes={'a1': 5, 'a2': '42'}),
DatasetItem(id=42),
DatasetItem(id=43, image=Image(path='1/b/c.qq', size=(2, 4))),
@ -73,36 +88,11 @@ class DatumaroConverterTest(TestCase):
def test_can_save_and_load(self):
with TestDir() as test_dir:
source_dataset = self.TestExtractor()
converter = DatumaroConverter(save_images=True)
converter(source_dataset, test_dir)
project = Project.import_from(test_dir, 'datumaro')
parsed_dataset = project.make_dataset()
self.assertListEqual(
sorted(source_dataset.subsets()),
sorted(parsed_dataset.subsets()),
)
self.assertEqual(len(source_dataset), len(parsed_dataset))
for subset_name in source_dataset.subsets():
source_subset = source_dataset.get_subset(subset_name)
parsed_subset = parsed_dataset.get_subset(subset_name)
self.assertEqual(len(source_subset), len(parsed_subset))
for idx, (item_a, item_b) in enumerate(
zip(source_subset, parsed_subset)):
self.assertEqual(item_a, item_b, '%s:\n%s\nvs.\n%s\n' % \
(idx, item_to_str(item_a), item_to_str(item_b)))
self.assertEqual(
source_dataset.categories(),
parsed_dataset.categories())
self._test_save_and_load(self.TestExtractor(),
DatumaroConverter(save_images=True), test_dir)
def test_can_detect(self):
with TestDir() as test_dir:
DatumaroConverter()(self.TestExtractor(), save_dir=test_dir)
self.assertTrue(DatumaroImporter.detect(test_dir))
self.assertTrue(DatumaroImporter.detect(test_dir))

Loading…
Cancel
Save