import importlib.util
import os.path as osp
from unittest import TestCase, skipIf

import numpy as np

from datumaro.components.extractor import (Extractor, DatasetItem,
    AnnotationType, Bbox, Mask, LabelCategories
)
from datumaro.components.project import Project
from datumaro.util.image import Image
from datumaro.util.test_utils import TestDir, compare_datasets
from datumaro.util.tf_util import check_import

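# The TF Detection API plugins pull in TensorFlow, so their imports are
# guarded; tests that need them are skipped when the import fails.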
try:
    from datumaro.plugins.tf_detection_api_format.importer import TfDetectionApiImporter
    from datumaro.plugins.tf_detection_api_format.extractor import TfDetectionApiExtractor
    from datumaro.plugins.tf_detection_api_format.converter import TfDetectionApiConverter
    import_failed = False
except ImportError:
    import_failed = True

# True only when the tensorflow package is installed; find_spec() checks
# availability without importing the module.
module_found = importlib.util.find_spec('tensorflow') is not None

@skipIf(not module_found, "Tensorflow package is not found")
class TfImportTest(TestCase):
    def test_raises_when_crashes_on_import(self):
        # Should fire if the import fails for any reason other than
        # module unavailability or an import crash
        with self.assertRaisesRegex(ImportError, 'Test process exit code'):
            check_import()

@skipIf(import_failed, "Failed to import tensorflow")
class TfrecordConverterTest(TestCase):
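    # Round-trip helper: convert the source dataset, re-import the result
    # and compare it with the expected (or, by default, the source) dataset.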
    def _test_save_and_load(self, source_dataset, converter, test_dir,
            target_dataset=None, importer_args=None):
        converter(source_dataset, test_dir)

        if importer_args is None:
            importer_args = {}
        parsed_dataset = TfDetectionApiImporter()(test_dir, **importer_args) \
            .make_dataset()

        if target_dataset is None:
            target_dataset = source_dataset

        compare_datasets(self, expected=target_dataset, actual=parsed_dataset)

    def test_can_save_bboxes(self):
        class TestExtractor(Extractor):
            def __iter__(self):
                return iter([
                    DatasetItem(id=1, subset='train',
                        image=np.ones((16, 16, 3)),
                        annotations=[
                            Bbox(0, 4, 4, 8, label=2),
                            Bbox(0, 4, 4, 4, label=3),
                            Bbox(2, 4, 4, 4),
                        ]
                    ),
                ])

            def categories(self):
                label_cat = LabelCategories()
                for label in range(10):
                    label_cat.add('label_' + str(label))
                return {
                    AnnotationType.label: label_cat,
                }

        with TestDir() as test_dir:
            self._test_save_and_load(
                TestExtractor(), TfDetectionApiConverter(save_images=True),
                test_dir)

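    # save_masks=True asks the converter to serialize instance masks
    # in addition to bounding boxes.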
    def test_can_save_masks(self):
        class TestExtractor(Extractor):
            def __iter__(self):
                return iter([
                    DatasetItem(id=1, subset='train', image=np.ones((4, 5, 3)),
                        annotations=[
                            Mask(image=np.array([
                                [1, 0, 0, 1],
                                [0, 1, 1, 0],
                                [0, 1, 1, 0],
                                [1, 0, 0, 1],
                            ]), label=1),
                        ]
                    ),
                ])

            def categories(self):
                label_cat = LabelCategories()
                for label in range(10):
                    label_cat.add('label_' + str(label))
                return {
                    AnnotationType.label: label_cat,
                }

        with TestDir() as test_dir:
            self._test_save_and_load(
                TestExtractor(), TfDetectionApiConverter(save_masks=True),
                test_dir)

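    # Items with no explicit subset should still survive the round trip.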
    def test_can_save_dataset_with_no_subsets(self):
        class TestExtractor(Extractor):
            def __iter__(self):
                return iter([
                    DatasetItem(id=1,
                        image=np.ones((16, 16, 3)),
                        annotations=[
                            Bbox(2, 1, 4, 4, label=2),
                            Bbox(4, 2, 8, 4, label=3),
                        ]
                    ),

                    DatasetItem(id=2,
                        image=np.ones((8, 8, 3)) * 2,
                        annotations=[
                            Bbox(4, 4, 4, 4, label=3),
                        ]
                    ),

                    DatasetItem(id=3,
                        image=np.ones((8, 4, 3)) * 3,
                    ),
                ])

            def categories(self):
                label_cat = LabelCategories()
                for label in range(10):
                    label_cat.add('label_' + str(label))
                return {
                    AnnotationType.label: label_cat,
                }

        with TestDir() as test_dir:
            self._test_save_and_load(
                TestExtractor(), TfDetectionApiConverter(save_images=True),
                test_dir)

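    # Only image metadata (path and size) is available here; there is no
    # pixel data to save.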
    def test_can_save_dataset_with_image_info(self):
        class TestExtractor(Extractor):
            def __iter__(self):
                return iter([
                    DatasetItem(id=1, image=Image(path='1/q.e', size=(10, 15))),
                ])

            def categories(self):
                return { AnnotationType.label: LabelCategories() }

        with TestDir() as test_dir:
            self._test_save_and_load(TestExtractor(),
                TfDetectionApiConverter(), test_dir)

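    # A label map is a text proto of { id: ... name: ... } entries; the parser
    # is expected to tolerate flexible spacing, field order and entry layout.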
    def test_labelmap_parsing(self):
        text = """
            {
                id: 4
                name: 'qw1'
            }
            {
                id: 5 name: 'qw2'
            }

            {
                name: 'qw3'
                id: 6
            }
            {name:'qw4' id:7}
        """
        expected = {
            'qw1': 4,
            'qw2': 5,
            'qw3': 6,
            'qw4': 7,
        }
        parsed = TfDetectionApiExtractor._parse_labelmap(text)

        self.assertEqual(expected, parsed)

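# A pre-generated TFrecord dataset shipped with the test assets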
DUMMY_DATASET_DIR = osp.join(osp.dirname(__file__),
    'assets', 'tf_detection_api_dataset')

@skipIf(import_failed, "Failed to import tensorflow")
class TfrecordImporterTest(TestCase):
    def test_can_detect(self):
        self.assertTrue(TfDetectionApiImporter.detect(DUMMY_DATASET_DIR))

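    # DstExtractor below mirrors the expected contents of the asset dataset.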
    def test_can_import(self):
        class DstExtractor(Extractor):
            def __iter__(self):
                return iter([
                    DatasetItem(id=1, subset='train',
                        image=np.ones((16, 16, 3)),
                        annotations=[
                            Bbox(0, 4, 4, 8, label=2),
                            Bbox(0, 4, 4, 4, label=3),
                            Bbox(2, 4, 4, 4),
                        ],
                    ),

                    DatasetItem(id=2, subset='val',
                        image=np.ones((8, 8, 3)),
                        annotations=[
                            Bbox(1, 2, 4, 2, label=3),
                        ],
                    ),

                    DatasetItem(id=3, subset='test',
                        image=np.ones((5, 4, 3)) * 3,
                    ),
                ])

            def categories(self):
                label_cat = LabelCategories()
                for label in range(10):
                    label_cat.add('label_' + str(label))
                return {
                    AnnotationType.label: label_cat,
                }

        dataset = Project.import_from(DUMMY_DATASET_DIR, 'tf_detection_api') \
            .make_dataset()

        compare_datasets(self, DstExtractor(), dataset)