You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
99 lines
3.8 KiB
Python
99 lines
3.8 KiB
Python
import numpy as np
|
|
|
|
from unittest import TestCase
|
|
|
|
from datumaro.components.project import Project
|
|
from datumaro.components.extractor import (Extractor, DatasetItem,
|
|
AnnotationType, Label, Mask, Points, Polygon,
|
|
PolyLine, Bbox, Caption,
|
|
LabelCategories, MaskCategories, PointsCategories
|
|
)
|
|
from datumaro.plugins.datumaro_format.importer import DatumaroImporter
|
|
from datumaro.plugins.datumaro_format.converter import DatumaroConverter
|
|
from datumaro.util.mask_tools import generate_colormap
|
|
from datumaro.util.image import Image
|
|
from datumaro.util.test_utils import TestDir, compare_datasets_strict
|
|
|
|
class DatumaroConverterTest(TestCase):
    """Save/load and detection tests for the Datumaro dataset format."""

    def _test_save_and_load(self, source_dataset, converter, test_dir,
                            target_dataset=None, importer_args=None):
        """Write a dataset with ``converter``, import it back and compare.

        ``converter`` is called with ``source_dataset`` and ``test_dir``.
        The directory is then imported as a ``'datumaro'`` project
        (``importer_args``, if given, are forwarded as keyword arguments)
        and the resulting dataset is strictly compared against
        ``target_dataset``, or against ``source_dataset`` when no target
        is supplied.
        """
        converter(source_dataset, test_dir)

        import_kwargs = {} if importer_args is None else importer_args
        project = Project.import_from(test_dir, 'datumaro', **import_kwargs)
        reloaded = project.make_dataset()

        expected = source_dataset if target_dataset is None else target_dataset
        compare_datasets_strict(self, expected=expected, actual=reloaded)

    class TestExtractor(Extractor):
        """Fixture extractor with six hand-written items.

        The items use the 'train', 'val' and 'test' subsets (two items
        pass no subset at all) and together contain captions, labels,
        bounding boxes, points, a mask, a polyline and a polygon.
        """

        def __iter__(self):
            """Return an iterator over the fixed list of fixture items."""
            # NOTE(review): the second Bbox, the Points and the Mask below
            # all carry id=5 - presumably intentional; confirm before
            # "fixing" the duplicate ids.
            train_with_image = DatasetItem(
                id=100, subset='train', image=np.ones((10, 6, 3)),
                annotations=[
                    Caption('hello', id=1),
                    Caption('world', id=2, group=5),
                    Label(2, id=3, attributes={'x': 1, 'y': '2'}),
                    Bbox(1, 2, 3, 4, label=4, id=4, z_order=1,
                         attributes={'score': 1.0}),
                    Bbox(5, 6, 7, 8, id=5, group=5),
                    Points([1, 2, 2, 0, 1, 1], label=0, id=5, z_order=4),
                    Mask(label=3, id=5, z_order=2, image=np.ones((2, 3))),
                ],
            )

            # Annotations created mostly with default arguments.
            train_without_image = DatasetItem(
                id=21, subset='train',
                annotations=[
                    Caption('test'),
                    Label(2),
                    Bbox(1, 2, 3, 4, 5, id=42, group=42),
                ],
            )

            val_shapes = DatasetItem(
                id=2, subset='val',
                annotations=[
                    PolyLine([1, 2, 3, 4, 5, 6, 7, 8], id=11, z_order=1),
                    Polygon([1, 2, 3, 4, 5, 6, 7, 8], id=12, z_order=4),
                ],
            )

            # Item-level attributes only, no annotations.
            test_with_attributes = DatasetItem(
                id=42, subset='test', attributes={'a1': 5, 'a2': '42'})

            # Same id as the 'test' item above, but no subset is passed.
            bare_item = DatasetItem(id=42)

            # Image described by a path and a size instead of pixel data.
            path_image_item = DatasetItem(
                id=43, image=Image(path='1/b/c.qq', size=(2, 4)))

            return iter([
                train_with_image,
                train_without_image,
                val_shapes,
                test_with_attributes,
                bare_item,
                path_image_item,
            ])

        def categories(self):
            """Build label, mask and points categories for five labels.

            Labels are named 'cat0' through 'cat4'. The mask categories
            are built from a colormap generated for the number of labels,
            and each label index gets the same points category (two point
            labels joined by one joint).
            """
            labels = LabelCategories()
            for number in range(5):
                labels.add('cat' + str(number))

            label_count = len(labels.items)
            masks = MaskCategories(generate_colormap(label_count))

            points = PointsCategories()
            for label_index in range(label_count):
                points.add(label_index, ['cat1', 'cat2'], joints=[[0, 1]])

            return {
                AnnotationType.label: labels,
                AnnotationType.mask: masks,
                AnnotationType.points: points,
            }

    def test_can_save_and_load(self):
        # Round trip of the fixture dataset, with image saving enabled.
        with TestDir() as test_dir:
            extractor = self.TestExtractor()
            converter = DatumaroConverter(save_images=True)
            self._test_save_and_load(extractor, converter, test_dir)

    def test_can_detect(self):
        # A directory written by the converter must be detected by the
        # importer.
        with TestDir() as test_dir:
            converter = DatumaroConverter()
            converter(self.TestExtractor(), save_dir=test_dir)

            detected = DatumaroImporter.detect(test_dir)
            self.assertTrue(detected)
|