import numpy as np
import os
import os.path as osp

from unittest import TestCase

from datumaro.components.project import Project, Environment
from datumaro.components.project import Source, Model
from datumaro.components.launcher import Launcher, InferenceWrapper
from datumaro.components.converter import Converter
from datumaro.components.extractor import (Extractor, DatasetItem,
    LabelObject, MaskObject, PointsObject, PolygonObject,
    PolyLineObject, BboxObject, CaptionObject,
)
from datumaro.components.config import Config, DefaultConfig, SchemaBuilder
from datumaro.components.dataset_filter import \
    XPathDatasetFilter, XPathAnnotationsFilter, DatasetItemEncoder
from datumaro.util.test_utils import TestDir


class ProjectTest(TestCase):
    def test_project_generate(self):
        src_config = Config({
            'project_name': 'test_project',
            'format_version': 1,
        })

        with TestDir() as test_dir:
            project_path = test_dir.path
            Project.generate(project_path, src_config)

            self.assertTrue(osp.isdir(project_path))

            result_config = Project.load(project_path).config
            self.assertEqual(
                src_config.project_name, result_config.project_name)
            self.assertEqual(
                src_config.format_version, result_config.format_version)

    @staticmethod
    def test_default_ctor_is_ok():
        Project()

    @staticmethod
    def test_empty_config_is_ok():
        Project(Config())

    def test_add_source(self):
        source_name = 'source'
        origin = Source({
            'url': 'path',
            'format': 'ext'
        })
        project = Project()

        project.add_source(source_name, origin)

        added = project.get_source(source_name)
        self.assertIsNotNone(added)
        self.assertEqual(added, origin)

    def test_added_source_can_be_saved(self):
        source_name = 'source'
        origin = Source({
            'url': 'path',
        })
        project = Project()
        project.add_source(source_name, origin)

        saved = project.config

        self.assertEqual(origin, saved.sources[source_name])

    def test_added_source_can_be_dumped(self):
        source_name = 'source'
        origin = Source({
            'url': 'path',
        })
        project = Project()
        project.add_source(source_name, origin)

        with TestDir() as test_dir:
            project.save(test_dir.path)

            loaded = Project.load(test_dir.path)
            loaded = loaded.get_source(source_name)
            self.assertEqual(origin, loaded)

    def test_can_import_with_custom_importer(self):
        class TestImporter:
            def __call__(self, path, subset=None):
                return Project({
                    'project_filename': path,
                    'subsets': [ subset ]
                })

        path = 'path'
        importer_name = 'test_importer'

        env = Environment()
        env.importers.register(importer_name, TestImporter)

        project = Project.import_from(path, importer_name, env,
            subset='train')

        self.assertEqual(path, project.config.project_filename)
        self.assertListEqual(['train'], project.config.subsets)

    def test_can_dump_added_model(self):
        model_name = 'model'

        project = Project()
        saved = Model({ 'launcher': 'name' })
        project.add_model(model_name, saved)

        with TestDir() as test_dir:
            project.save(test_dir.path)

            loaded = Project.load(test_dir.path)
            loaded = loaded.get_model(model_name)
            self.assertEqual(saved, loaded)

    def test_can_have_project_source(self):
        with TestDir() as test_dir:
            Project.generate(test_dir.path)

            project2 = Project()
            project2.add_source('project1', {
                'url': test_dir.path,
            })
            dataset = project2.make_dataset()

            self.assertTrue('project1' in dataset.sources)
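    # The next two tests exercise model integration: a custom Launcher is
    # registered in the project environment, then used for batched
    # inference (InferenceWrapper) and for a model-driven dataset
    # transformation (apply_model).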
    def test_can_batch_launch_custom_model(self):
        class TestExtractor(Extractor):
            def __iter__(self):
                for i in range(5):
                    yield DatasetItem(id=i, subset='train', image=i)

        class TestLauncher(Launcher):
            def launch(self, inputs):
                for i, inp in enumerate(inputs):
                    yield [ LabelObject(attributes={'idx': i, 'data': inp}) ]

        model_name = 'model'
        launcher_name = 'custom_launcher'

        project = Project()
        project.env.launchers.register(launcher_name, TestLauncher)
        project.add_model(model_name, { 'launcher': launcher_name })
        model = project.make_executable_model(model_name)
        extractor = TestExtractor()

        batch_size = 3
        executor = InferenceWrapper(extractor, model, batch_size=batch_size)

        for item in executor:
            self.assertEqual(1, len(item.annotations))
            self.assertEqual(int(item.id) % batch_size,
                item.annotations[0].attributes['idx'])
            self.assertEqual(int(item.id),
                item.annotations[0].attributes['data'])

    def test_can_do_transform_with_custom_model(self):
        class TestExtractorSrc(Extractor):
            def __iter__(self):
                for i in range(2):
                    yield DatasetItem(id=i, subset='train', image=i,
                        annotations=[ LabelObject(i) ])

        class TestLauncher(Launcher):
            def launch(self, inputs):
                for inp in inputs:
                    yield [ LabelObject(inp) ]

        class TestConverter(Converter):
            def __call__(self, extractor, save_dir):
                for item in extractor:
                    with open(osp.join(save_dir, '%s.txt' % item.id), 'w') as f:
                        f.write(str(item.subset) + '\n')
                        f.write(str(item.annotations[0].label) + '\n')

        class TestExtractorDst(Extractor):
            def __init__(self, url):
                super().__init__()
                self.items = [osp.join(url, p) for p in sorted(os.listdir(url))]

            def __iter__(self):
                for path in self.items:
                    with open(path, 'r') as f:
                        index = osp.splitext(osp.basename(path))[0]
                        subset = f.readline().strip()
                        label = int(f.readline().strip())
                        assert subset == 'train'
                        yield DatasetItem(id=index, subset=subset,
                            annotations=[ LabelObject(label) ])

        model_name = 'model'
        launcher_name = 'custom_launcher'
        extractor_name = 'custom_extractor'

        project = Project()
        project.env.launchers.register(launcher_name, TestLauncher)
        project.env.extractors.register(extractor_name, TestExtractorSrc)
        project.env.converters.register(extractor_name, TestConverter)
        project.add_model(model_name, { 'launcher': launcher_name })
        project.add_source('source', { 'format': extractor_name })

        with TestDir() as test_dir:
            project.make_dataset().apply_model(model_name=model_name,
                save_dir=test_dir.path)

            result = Project.load(test_dir.path)
            result.env.extractors.register(extractor_name, TestExtractorDst)
            it = iter(result.make_dataset())

            item1 = next(it)
            item2 = next(it)
            self.assertEqual(0, item1.annotations[0].label)
            self.assertEqual(1, item2.annotations[0].label)

    def test_source_datasets_can_be_merged(self):
        class TestExtractor(Extractor):
            def __init__(self, url, n=0, s=0):
                super().__init__(length=n)
                self.n = n
                self.s = s

            def __iter__(self):
                for i in range(self.n):
                    yield DatasetItem(id=self.s + i, subset='train')

        e_name1 = 'e1'
        e_name2 = 'e2'
        n1 = 2
        n2 = 4

        project = Project()
        project.env.extractors.register(e_name1,
            lambda p: TestExtractor(p, n=n1))
        project.env.extractors.register(e_name2,
            lambda p: TestExtractor(p, n=n2, s=n1))
        project.add_source('source1', { 'format': e_name1 })
        project.add_source('source2', { 'format': e_name2 })

        dataset = project.make_dataset()

        self.assertEqual(n1 + n2, len(dataset))

    def test_project_filter_can_be_applied(self):
        class TestExtractor(Extractor):
            def __iter__(self):
                for i in range(10):
                    yield DatasetItem(id=i, subset='train')

        e_type = 'type'
        project = Project()
        project.env.extractors.register(e_type, TestExtractor)
        project.add_source('source', { 'format': e_type })

        dataset = project.make_dataset().extract('/item[id < 5]')

        self.assertEqual(5, len(dataset))
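    # Besides attached sources, a project has its own dataset, which can be
    # modified with put() and saved or loaded independently of the sources.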
    def test_can_save_and_load_own_dataset(self):
        with TestDir() as test_dir:
            src_project = Project()
            src_dataset = src_project.make_dataset()
            item = DatasetItem(id=1)
            src_dataset.put(item)
            src_dataset.save(test_dir.path)

            loaded_project = Project.load(test_dir.path)
            loaded_dataset = loaded_project.make_dataset()

            self.assertEqual(list(src_dataset), list(loaded_dataset))

    def test_project_own_dataset_can_be_modified(self):
        project = Project()
        dataset = project.make_dataset()
        item = DatasetItem(id=1)

        dataset.put(item)

        self.assertEqual(item, next(iter(dataset)))

    def test_project_compound_child_can_be_modified_recursively(self):
        with TestDir() as test_dir:
            child1 = Project({
                'project_dir': osp.join(test_dir.path, 'child1'),
            })
            child1.save()

            child2 = Project({
                'project_dir': osp.join(test_dir.path, 'child2'),
            })
            child2.save()

            parent = Project()
            parent.add_source('child1', {
                'url': child1.config.project_dir
            })
            parent.add_source('child2', {
                'url': child2.config.project_dir
            })
            dataset = parent.make_dataset()

            item1 = DatasetItem(id='ch1', path=['child1'])
            item2 = DatasetItem(id='ch2', path=['child2'])
            dataset.put(item1)
            dataset.put(item2)

            self.assertEqual(2, len(dataset))
            self.assertEqual(1, len(dataset.sources['child1']))
            self.assertEqual(1, len(dataset.sources['child2']))

    def test_project_can_merge_item_annotations(self):
        class TestExtractor1(Extractor):
            def __iter__(self):
                yield DatasetItem(id=1, subset='train', annotations=[
                    LabelObject(2, id=3),
                    LabelObject(3, attributes={ 'x': 1 }),
                ])

        class TestExtractor2(Extractor):
            def __iter__(self):
                yield DatasetItem(id=1, subset='train', annotations=[
                    LabelObject(3, attributes={ 'x': 1 }),
                    LabelObject(4, id=4),
                ])

        project = Project()
        project.env.extractors.register('t1', TestExtractor1)
        project.env.extractors.register('t2', TestExtractor2)
        project.add_source('source1', { 'format': 't1' })
        project.add_source('source2', { 'format': 't2' })

        merged = project.make_dataset()

        self.assertEqual(1, len(merged))

        item = next(iter(merged))
        self.assertEqual(3, len(item.annotations))


class DatasetFilterTest(TestCase):
    @staticmethod
    def test_item_representations():
        item = DatasetItem(id=1, subset='subset', path=['a', 'b'],
            image=np.ones((5, 4, 3)),
            annotations=[
                LabelObject(0, attributes={'a1': 1, 'a2': '2'},
                    id=1, group=2),
                CaptionObject('hello', id=1),
                CaptionObject('world', group=5),
                LabelObject(2, id=3, attributes={ 'x': 1, 'y': '2' }),
                BboxObject(1, 2, 3, 4, label=4, id=4,
                    attributes={ 'a': 1.0 }),
                BboxObject(5, 6, 7, 8, id=5, group=5),
                PointsObject([1, 2, 2, 0, 1, 1], label=0, id=5),
                MaskObject(id=5, image=np.ones((3, 2))),
                MaskObject(label=3, id=5, image=np.ones((2, 3))),
                PolyLineObject([1, 2, 3, 4, 5, 6, 7, 8], id=11),
                PolygonObject([1, 2, 3, 4, 5, 6, 7, 8]),
            ]
        )

        encoded = DatasetItemEncoder.encode(item)
        DatasetItemEncoder.to_string(encoded)

    def test_item_filter_can_be_applied(self):
        class TestExtractor(Extractor):
            def __iter__(self):
                for i in range(4):
                    yield DatasetItem(id=i, subset='train')

        extractor = TestExtractor()

        filtered = XPathDatasetFilter(extractor, '/item[id > 1]')

        self.assertEqual(2, len(filtered))

    def test_annotations_filter_can_be_applied(self):
        class SrcTestExtractor(Extractor):
            def __iter__(self):
                return iter([
                    DatasetItem(id=0),
                    DatasetItem(id=1, annotations=[
                        LabelObject(0),
                        LabelObject(1),
                    ]),
                    DatasetItem(id=2, annotations=[
                        LabelObject(0),
                        LabelObject(2),
                    ]),
                ])

        class DstTestExtractor(Extractor):
            def __iter__(self):
                return iter([
                    DatasetItem(id=0),
                    DatasetItem(id=1, annotations=[
                        LabelObject(0),
                    ]),
                    DatasetItem(id=2, annotations=[
                        LabelObject(0),
                    ]),
                ])

        extractor = SrcTestExtractor()

        filtered = XPathAnnotationsFilter(extractor,
            '/item/annotation[label_id = 0]')

        self.assertListEqual(list(filtered), list(DstTestExtractor()))
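    # With remove_empty=True, items left without annotations after
    # filtering (ids 0 and 1 here) are dropped from the result.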
    def test_annotations_filter_can_remove_empty_items(self):
        class SrcTestExtractor(Extractor):
            def __iter__(self):
                return iter([
                    DatasetItem(id=0),
                    DatasetItem(id=1, annotations=[
                        LabelObject(0),
                        LabelObject(1),
                    ]),
                    DatasetItem(id=2, annotations=[
                        LabelObject(0),
                        LabelObject(2),
                    ]),
                ])

        class DstTestExtractor(Extractor):
            def __iter__(self):
                return iter([
                    DatasetItem(id=2, annotations=[
                        LabelObject(2),
                    ]),
                ])

        extractor = SrcTestExtractor()

        filtered = XPathAnnotationsFilter(extractor,
            '/item/annotation[label_id = 2]', remove_empty=True)

        self.assertListEqual(list(filtered), list(DstTestExtractor()))


class ConfigTest(TestCase):
    def test_can_produce_multilayer_config_from_dict(self):
        schema_low = SchemaBuilder() \
            .add('options', dict) \
            .build()
        schema_mid = SchemaBuilder() \
            .add('desc', lambda: Config(schema=schema_low)) \
            .build()
        schema_top = SchemaBuilder() \
            .add('container', lambda: DefaultConfig(
                lambda v: Config(v, schema=schema_mid))) \
            .build()

        value = 1
        source = Config({
            'container': {
                'elem': {
                    'desc': {
                        'options': {
                            'k': value
                        }
                    }
                }
            }
        }, schema=schema_top)

        self.assertEqual(value, source.container['elem'].desc.options['k'])


class ExtractorTest(TestCase):
    def test_custom_extractor_can_be_created(self):
        class CustomExtractor(Extractor):
            def __iter__(self):
                return iter([
                    DatasetItem(id=0, subset='train'),
                    DatasetItem(id=1, subset='train'),
                    DatasetItem(id=2, subset='train'),
                    DatasetItem(id=3, subset='test'),
                ])

        extractor_name = 'ext1'
        project = Project()
        project.env.extractors.register(extractor_name, CustomExtractor)
        project.add_source('src1', {
            'url': 'path',
            'format': extractor_name,
        })
        project.set_subsets(['train'])

        dataset = project.make_dataset()

        self.assertEqual(3, len(dataset))