import os
import os.path as osp

from unittest import TestCase

from datumaro.components.project import Project, Environment, Source, Model
from datumaro.components.launcher import Launcher, InferenceWrapper
from datumaro.components.converter import Converter
from datumaro.components.extractor import Extractor, DatasetItem, LabelObject
from datumaro.components.config import Config, DefaultConfig, SchemaBuilder
from datumaro.components.dataset_filter import XPathDatasetFilter
from datumaro.util.test_utils import TestDir


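# Tests for the Project abstraction: generating and loading projects,
# registering sources and models, and building datasets from them.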
class ProjectTest(TestCase):
    def test_project_generate(self):
        src_config = Config({
            'project_name': 'test_project',
            'format_version': 1,
        })

        with TestDir() as test_dir:
            project_path = test_dir.path
            Project.generate(project_path, src_config)

            self.assertTrue(osp.isdir(project_path))

            result_config = Project.load(project_path).config
            self.assertEqual(
                src_config.project_name, result_config.project_name)
            self.assertEqual(
                src_config.format_version, result_config.format_version)

    @staticmethod
    def test_default_ctor_is_ok():
        Project()

    @staticmethod
    def test_empty_config_is_ok():
        Project(Config())

    def test_add_source(self):
        source_name = 'source'
        origin = Source({
            'url': 'path',
            'format': 'ext'
        })
        project = Project()

        project.add_source(source_name, origin)

        added = project.get_source(source_name)
        self.assertIsNotNone(added)
        self.assertEqual(added, origin)

    def test_added_source_can_be_saved(self):
        source_name = 'source'
        origin = Source({
            'url': 'path',
        })
        project = Project()
        project.add_source(source_name, origin)

        saved = project.config

        self.assertEqual(origin, saved.sources[source_name])

    def test_added_source_can_be_dumped(self):
        source_name = 'source'
        origin = Source({
            'url': 'path',
        })
        project = Project()
        project.add_source(source_name, origin)

        with TestDir() as test_dir:
            project.save(test_dir.path)

            loaded = Project.load(test_dir.path)
            loaded = loaded.get_source(source_name)
            self.assertEqual(origin, loaded)

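    # An importer is a plain callable registered in an Environment; extra
    # keyword arguments of Project.import_from() (here, subset='train') are
    # forwarded to it, as the assertions below confirm.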
    def test_can_import_with_custom_importer(self):
        class TestImporter:
            def __call__(self, path, subset=None):
                return Project({
                    'project_filename': path,
                    'subsets': [ subset ]
                })

        path = 'path'
        importer_name = 'test_importer'

        env = Environment()
        env.importers.register(importer_name, TestImporter)

        project = Project.import_from(path, importer_name, env,
            subset='train')

        self.assertEqual(path, project.config.project_filename)
        self.assertListEqual(['train'], project.config.subsets)

    def test_can_dump_added_model(self):
        model_name = 'model'

        project = Project()
        saved = Model({ 'launcher': 'name' })
        project.add_model(model_name, saved)

        with TestDir() as test_dir:
            project.save(test_dir.path)

            loaded = Project.load(test_dir.path)
            loaded = loaded.get_model(model_name)
            self.assertEqual(saved, loaded)

    def test_can_have_project_source(self):
        with TestDir() as test_dir:
            Project.generate(test_dir.path)

            project2 = Project()
            project2.add_source('project1', {
                'url': test_dir.path,
            })
            dataset = project2.make_dataset()

            self.assertIn('project1', dataset.sources)

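    # InferenceWrapper feeds items to the model in batches of batch_size, so
    # the launcher sees each input's position within its batch. For sequential
    # ids that position is id % batch_size, which the assertions below verify.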
    def test_can_batch_launch_custom_model(self):
        class TestExtractor(Extractor):
            def __init__(self, url, n=0):
                super().__init__(length=n)
                self.n = n

            def __iter__(self):
                for i in range(self.n):
                    yield DatasetItem(id=i, subset='train', image=i)

        class TestLauncher(Launcher):
            def __init__(self, **kwargs):
                pass

            def launch(self, inputs):
                for i, inp in enumerate(inputs):
                    yield [ LabelObject(attributes={'idx': i, 'data': inp}) ]

        model_name = 'model'
        launcher_name = 'custom_launcher'

        project = Project()
        project.env.launchers.register(launcher_name, TestLauncher)
        project.add_model(model_name, { 'launcher': launcher_name })
        model = project.make_executable_model(model_name)
        extractor = TestExtractor('', n=5)

        batch_size = 3
        executor = InferenceWrapper(extractor, model, batch_size=batch_size)

        for item in executor:
            self.assertEqual(1, len(item.annotations))
            self.assertEqual(int(item.id) % batch_size,
                item.annotations[0].attributes['idx'])
            self.assertEqual(int(item.id),
                item.annotations[0].attributes['data'])

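    # End-to-end transform: a source extractor feeds items to the model, the
    # model's outputs are written to disk by a converter, and a second
    # extractor reads the written files back so the labels can be checked.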
    def test_can_do_transform_with_custom_model(self):
        class TestExtractorSrc(Extractor):
            def __init__(self, url, n=2):
                super().__init__(length=n)
                self.n = n

            def __iter__(self):
                for i in range(self.n):
                    yield DatasetItem(id=i, subset='train', image=i,
                        annotations=[ LabelObject(i) ])

        class TestLauncher(Launcher):
            def __init__(self, **kwargs):
                pass

            def launch(self, inputs):
                for inp in inputs:
                    yield [ LabelObject(inp) ]

        class TestConverter(Converter):
            def __call__(self, extractor, save_dir):
                for item in extractor:
                    with open(osp.join(save_dir, '%s.txt' % item.id), 'w+') as f:
                        f.write(str(item.subset) + '\n')
                        f.write(str(item.annotations[0].label) + '\n')

        class TestExtractorDst(Extractor):
            def __init__(self, url):
                super().__init__()
                self.items = [osp.join(url, p) for p in sorted(os.listdir(url))]

            def __iter__(self):
                for path in self.items:
                    with open(path, 'r') as f:
                        index = osp.splitext(osp.basename(path))[0]
                        subset = f.readline()[:-1]
                        label = int(f.readline()[:-1])
                        assert subset == 'train'
                        yield DatasetItem(id=index, subset=subset,
                            annotations=[ LabelObject(label) ])

        model_name = 'model'
        launcher_name = 'custom_launcher'
        extractor_name = 'custom_extractor'

        project = Project()
        project.env.launchers.register(launcher_name, TestLauncher)
        project.env.extractors.register(extractor_name, TestExtractorSrc)
        project.env.converters.register(extractor_name, TestConverter)
        project.add_model(model_name, { 'launcher': launcher_name })
        project.add_source('source', { 'format': extractor_name })

        with TestDir() as test_dir:
            project.make_dataset().transform(model_name, test_dir.path)

            result = Project.load(test_dir.path)
            result.env.extractors.register(extractor_name, TestExtractorDst)
            it = iter(result.make_dataset())
            item1 = next(it)
            item2 = next(it)
            self.assertEqual(0, item1.annotations[0].label)
            self.assertEqual(1, item2.annotations[0].label)

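    # Two sources with disjoint item ids: the 's' parameter offsets the second
    # range (0..n1-1 and n1..n1+n2-1), so the merged dataset should contain
    # every item from both sources.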
    def test_source_datasets_can_be_merged(self):
        class TestExtractor(Extractor):
            def __init__(self, url, n=0, s=0):
                super().__init__(length=n)
                self.n = n
                self.s = s

            def __iter__(self):
                for i in range(self.n):
                    yield DatasetItem(id=self.s + i, subset='train')

        e_name1 = 'e1'
        e_name2 = 'e2'
        n1 = 2
        n2 = 4

        project = Project()
        project.env.extractors.register(e_name1,
            lambda p: TestExtractor(p, n=n1))
        project.env.extractors.register(e_name2,
            lambda p: TestExtractor(p, n=n2, s=n1))
        project.add_source('source1', { 'format': e_name1 })
        project.add_source('source2', { 'format': e_name2 })

        dataset = project.make_dataset()

        self.assertEqual(n1 + n2, len(dataset))

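    # The project-level filter string is an XPath expression (as the
    # XPathDatasetFilter name suggests, evaluated over an XML view of each
    # item); '/item[id < 5]' keeps the five items with ids 0..4.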
    def test_project_filter_can_be_applied(self):
        class TestExtractor(Extractor):
            def __init__(self, url, n=10):
                super().__init__(length=n)
                self.n = n

            def __iter__(self):
                for i in range(self.n):
                    yield DatasetItem(id=i, subset='train')

        e_type = 'type'
        project = Project()
        project.env.extractors.register(e_type, TestExtractor)
        project.add_source('source', { 'format': e_type })
        project.set_filter('/item[id < 5]')

        dataset = project.make_dataset()

        self.assertEqual(5, len(dataset))

    def test_can_save_and_load_own_dataset(self):
        with TestDir() as test_dir:
            src_project = Project()
            src_dataset = src_project.make_dataset()
            item = DatasetItem(id=1)
            src_dataset.put(item)
            src_dataset.save(test_dir.path)

            loaded_project = Project.load(test_dir.path)
            loaded_dataset = loaded_project.make_dataset()

            self.assertEqual(list(src_dataset), list(loaded_dataset))

    def test_project_own_dataset_can_be_modified(self):
        project = Project()
        dataset = project.make_dataset()

        item = DatasetItem(id=1)
        dataset.put(item)

        self.assertEqual(item, next(iter(dataset)))

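    # An item's 'path' names the child project (source) it belongs to, so a
    # put() on the parent dataset is routed down to that child, as the
    # per-source length checks below show.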
    def test_project_compound_child_can_be_modified_recursively(self):
        with TestDir() as test_dir:
            child1 = Project({
                'project_dir': osp.join(test_dir.path, 'child1'),
            })
            child1.save()

            child2 = Project({
                'project_dir': osp.join(test_dir.path, 'child2'),
            })
            child2.save()

            parent = Project()
            parent.add_source('child1', {
                'url': child1.config.project_dir
            })
            parent.add_source('child2', {
                'url': child2.config.project_dir
            })
            dataset = parent.make_dataset()

            item1 = DatasetItem(id='ch1', path=['child1'])
            item2 = DatasetItem(id='ch2', path=['child2'])
            dataset.put(item1)
            dataset.put(item2)

            self.assertEqual(2, len(dataset))
            self.assertEqual(1, len(dataset.sources['child1']))
            self.assertEqual(1, len(dataset.sources['child2']))

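    # When two sources yield an item with the same id and subset, the items
    # are merged and duplicate annotations collapse: the LabelObject(3, ...)
    # shared by both versions is kept once, leaving three annotations total.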
    def test_project_can_merge_item_annotations(self):
        class TestExtractor(Extractor):
            def __init__(self, url, v=None):
                super().__init__()
                self.v = v

            def __iter__(self):
                v1_item = DatasetItem(id=1, subset='train', annotations=[
                    LabelObject(2, id=3),
                    LabelObject(3, attributes={ 'x': 1 }),
                ])

                v2_item = DatasetItem(id=1, subset='train', annotations=[
                    LabelObject(3, attributes={ 'x': 1 }),
                    LabelObject(4, id=4),
                ])

                if self.v == 1:
                    yield v1_item
                else:
                    yield v2_item

        project = Project()
        project.env.extractors.register('t1', lambda p: TestExtractor(p, v=1))
        project.env.extractors.register('t2', lambda p: TestExtractor(p, v=2))
        project.add_source('source1', { 'format': 't1' })
        project.add_source('source2', { 'format': 't2' })

        merged = project.make_dataset()

        self.assertEqual(1, len(merged))

        item = next(iter(merged))
        self.assertEqual(3, len(item.annotations))


class DatasetFilterTest(TestCase):
    class TestExtractor(Extractor):
        def __init__(self, url, n=0):
            super().__init__(length=n)
            self.n = n

        def __iter__(self):
            for i in range(self.n):
                yield DatasetItem(id=i, subset='train')

    def test_xpathfilter_can_be_applied(self):
        extractor = self.TestExtractor('', n=4)
        dataset_filter = XPathDatasetFilter('/item[id > 1]')

        filtered = extractor.select(dataset_filter)

        self.assertEqual(2, len(filtered))


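# Config objects can nest: SchemaBuilder declares the keys of each layer,
# while DefaultConfig supplies a factory for keys not declared up front
# (the 'elem' entry in the test below).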
class ConfigTest(TestCase):
    def test_can_produce_multilayer_config_from_dict(self):
        schema_low = SchemaBuilder() \
            .add('options', dict) \
            .build()
        schema_mid = SchemaBuilder() \
            .add('desc', lambda: Config(schema=schema_low)) \
            .build()
        schema_top = SchemaBuilder() \
            .add('container', lambda: DefaultConfig(
                lambda v: Config(v, schema=schema_mid))) \
            .build()

        value = 1
        source = Config({
            'container': {
                'elem': {
                    'desc': {
                        'options': {
                            'k': value
                        }
                    }
                }
            }
        }, schema=schema_top)

        self.assertEqual(value, source.container['elem'].desc.options['k'])


class ExtractorTest(TestCase):
    def test_custom_extractor_can_be_created(self):
        class CustomExtractor(Extractor):
            def __init__(self, url):
                super().__init__()

            def __iter__(self):
                return iter([
                    DatasetItem(id=0, subset='train'),
                    DatasetItem(id=1, subset='train'),
                    DatasetItem(id=2, subset='train'),

                    DatasetItem(id=3, subset='test'),
                ])

        extractor_name = 'ext1'
        project = Project()
        project.env.extractors.register(extractor_name, CustomExtractor)
        project.add_source('src1', {
            'url': 'path',
            'format': extractor_name,
        })
        project.set_subsets(['train'])

        dataset = project.make_dataset()

        self.assertEqual(3, len(dataset))