diff --git a/datumaro/datumaro/components/extractor.py b/datumaro/datumaro/components/extractor.py index 1f18af0f..7bb99a4f 100644 --- a/datumaro/datumaro/components/extractor.py +++ b/datumaro/datumaro/components/extractor.py @@ -576,7 +576,7 @@ class Caption(Annotation): class DatasetItem: # pylint: disable=redefined-builtin def __init__(self, id=None, annotations=None, - subset=None, path=None, image=None): + subset=None, path=None, image=None, attributes=None): assert id is not None self._id = str(id) @@ -604,6 +604,12 @@ class DatasetItem: image = Image(path=image) assert image is None or isinstance(image, Image) self._image = image + + if attributes is None: + attributes = {} + else: + attributes = dict(attributes) + self._attributes = attributes # pylint: enable=redefined-builtin @property @@ -630,6 +636,10 @@ class DatasetItem: def has_image(self): return self._image is not None + @property + def attributes(self): + return self._attributes + def __eq__(self, other): if not isinstance(other, __class__): return False @@ -638,10 +648,12 @@ class DatasetItem: (self.subset == other.subset) and \ (self.path == other.path) and \ (self.annotations == other.annotations) and \ - (self.image == other.image) + (self.image == other.image) and \ + (self.attributes == other.attributes) def wrap(item, **kwargs): - expected_args = {'id', 'annotations', 'subset', 'path', 'image'} + expected_args = {'id', 'annotations', 'subset', + 'path', 'image', 'attributes'} for k in expected_args: if k not in kwargs: kwargs[k] = getattr(item, k) diff --git a/datumaro/datumaro/plugins/datumaro_format/converter.py b/datumaro/datumaro/plugins/datumaro_format/converter.py index a1b7a1e7..b1067b69 100644 --- a/datumaro/datumaro/plugins/datumaro_format/converter.py +++ b/datumaro/datumaro/plugins/datumaro_format/converter.py @@ -49,6 +49,8 @@ class _SubsetWriter: 'id': item.id, 'annotations': annotations, } + if item.attributes: + item_desc['attr'] = item.attributes if item.path: item_desc['path'] = item.path if item.has_image: diff --git a/datumaro/datumaro/plugins/datumaro_format/extractor.py b/datumaro/datumaro/plugins/datumaro_format/extractor.py index 7be72b30..ed8813e6 100644 --- a/datumaro/datumaro/plugins/datumaro_format/extractor.py +++ b/datumaro/datumaro/plugins/datumaro_format/extractor.py @@ -91,13 +91,15 @@ class DatumaroExtractor(SourceExtractor): annotations = self._load_annotations(item_desc) item = DatasetItem(id=item_id, subset=self._subset, - annotations=annotations, image=image) + annotations=annotations, image=image, + attributes=item_desc.get('attr')) items.append(item) return items - def _load_annotations(self, item): + @staticmethod + def _load_annotations(item): parsed = item['annotations'] loaded = [] diff --git a/datumaro/datumaro/util/test_utils.py b/datumaro/datumaro/util/test_utils.py index 8600b621..2fb25c54 100644 --- a/datumaro/datumaro/util/test_utils.py +++ b/datumaro/datumaro/util/test_utils.py @@ -97,4 +97,21 @@ def compare_datasets(test, expected, actual): ann_b = find(ann_b_matches, lambda x: x == ann_a) test.assertEqual(ann_a, ann_b, 'ann: %s' % ann_to_str(ann_a)) - item_b.annotations.remove(ann_b) # avoid repeats \ No newline at end of file + item_b.annotations.remove(ann_b) # avoid repeats + +def compare_datasets_strict(test, expected, actual): + # Compares datasets for strong equality + + test.assertEqual(expected.categories(), actual.categories()) + + test.assertListEqual(sorted(expected.subsets()), sorted(actual.subsets())) + test.assertEqual(len(expected), len(actual)) + + for subset_name in expected.subsets(): + e_subset = expected.get_subset(subset_name) + a_subset = actual.get_subset(subset_name) + test.assertEqual(len(e_subset), len(a_subset)) + for idx, (item_a, item_b) in enumerate(zip(e_subset, a_subset)): + test.assertEqual(item_a, item_b, + '%s:\n%s\nvs.\n%s\n' % \ + (idx, item_to_str(item_a), item_to_str(item_b))) \ No newline at end of file diff --git a/datumaro/tests/test_datumaro_format.py b/datumaro/tests/test_datumaro_format.py index d6a81a53..f617465e 100644 --- a/datumaro/tests/test_datumaro_format.py +++ b/datumaro/tests/test_datumaro_format.py @@ -12,10 +12,24 @@ from datumaro.plugins.datumaro_format.importer import DatumaroImporter from datumaro.plugins.datumaro_format.converter import DatumaroConverter from datumaro.util.mask_tools import generate_colormap from datumaro.util.image import Image -from datumaro.util.test_utils import TestDir, item_to_str - +from datumaro.util.test_utils import TestDir, compare_datasets_strict class DatumaroConverterTest(TestCase): + def _test_save_and_load(self, source_dataset, converter, test_dir, + target_dataset=None, importer_args=None): + converter(source_dataset, test_dir) + + if importer_args is None: + importer_args = {} + parsed_dataset = Project.import_from( + test_dir, 'datumaro', **importer_args).make_dataset() + + if target_dataset is None: + target_dataset = source_dataset + + compare_datasets_strict(self, + expected=target_dataset, actual=parsed_dataset) + class TestExtractor(Extractor): def __iter__(self): return iter([ @@ -47,7 +61,8 @@ class DatumaroConverterTest(TestCase): Polygon([1, 2, 3, 4, 5, 6, 7, 8], id=12, z_order=4), ]), - DatasetItem(id=42, subset='test'), + DatasetItem(id=42, subset='test', + attributes={'a1': 5, 'a2': '42'}), DatasetItem(id=42), DatasetItem(id=43, image=Image(path='1/b/c.qq', size=(2, 4))), @@ -73,36 +88,11 @@ class DatumaroConverterTest(TestCase): def test_can_save_and_load(self): with TestDir() as test_dir: - source_dataset = self.TestExtractor() - - converter = DatumaroConverter(save_images=True) - converter(source_dataset, test_dir) - - project = Project.import_from(test_dir, 'datumaro') - parsed_dataset = project.make_dataset() - - self.assertListEqual( - sorted(source_dataset.subsets()), - sorted(parsed_dataset.subsets()), - ) - - self.assertEqual(len(source_dataset), len(parsed_dataset)) - - for subset_name in source_dataset.subsets(): - source_subset = source_dataset.get_subset(subset_name) - parsed_subset = parsed_dataset.get_subset(subset_name) - self.assertEqual(len(source_subset), len(parsed_subset)) - for idx, (item_a, item_b) in enumerate( - zip(source_subset, parsed_subset)): - self.assertEqual(item_a, item_b, '%s:\n%s\nvs.\n%s\n' % \ - (idx, item_to_str(item_a), item_to_str(item_b))) - - self.assertEqual( - source_dataset.categories(), - parsed_dataset.categories()) + self._test_save_and_load(self.TestExtractor(), + DatumaroConverter(save_images=True), test_dir) def test_can_detect(self): with TestDir() as test_dir: DatumaroConverter()(self.TestExtractor(), save_dir=test_dir) - self.assertTrue(DatumaroImporter.detect(test_dir)) \ No newline at end of file + self.assertTrue(DatumaroImporter.detect(test_dir))