[Datumaro] Fix project loading (#1013)

* Fix occasional infinite loop in project loading

* Fix project import source options saving

* Fix project import .git dir placement

* Make code aware of grayscale images
main
zhiltsov-max 6 years ago committed by Nikita Manovich
parent 4d730c756d
commit 337f7a0018

@ -80,10 +80,10 @@ class RISE:
def apply(self, image, progressive=False):
import cv2
assert len(image.shape) == 3, \
assert len(image.shape) in [2, 3], \
"Expected an input image in (H, W, C) format"
assert image.shape[2] in [3, 4], \
"Expected BGR or BGRA input"
if len(image.shape) == 3:
assert image.shape[2] in [3, 4], "Expected BGR or BGRA input"
image = image[:, :, :3].astype(np.float32)
model = self.model

@ -62,7 +62,7 @@ class _TaskConverter:
def save_image_info(self, item, filename):
if item.has_image:
h, w, _ = item.image.shape
h, w = item.image.shape[:2]
else:
h = 0
w = 0
@ -187,7 +187,7 @@ class _InstancesConverter(_TaskConverter):
p.label == ann.label]
if polygons:
segmentation = [p.get_points() for p in polygons]
h, w, _ = item.image.shape
h, w = item.image.shape[:2]
rles = mask_utils.frPyObjects(segmentation, h, w)
rle = mask_utils.merge(rles)
area = mask_utils.area(rle)
@ -211,7 +211,7 @@ class _InstancesConverter(_TaskConverter):
area = ann.area()
if self._context._merge_polygons:
h, w, _ = item.image.shape
h, w = item.image.shape[:2]
rles = mask_utils.frPyObjects(segmentation, h, w)
rle = mask_utils.merge(rles)
area = mask_utils.area(rle)

@ -48,7 +48,7 @@ def _make_tf_example(item, get_label_id, get_label, save_images=False):
if not item.has_image:
raise Exception(
"Failed to export dataset item '%s': item has no image" % item.id)
height, width, _ = item.image.shape
height, width = item.image.shape[:2]
features.update({
'image/height': int64_feature(height),

@ -153,7 +153,9 @@ class _Converter:
ET.SubElement(source_elem, 'image').text = 'Unknown'
if item.has_image:
h, w, c = item.image.shape
image_shape = item.image.shape
h, w = image_shape[:2]
c = 1 if len(image_shape) == 2 else image_shape[2]
size_elem = ET.SubElement(root_elem, 'size')
ET.SubElement(size_elem, 'width').text = str(w)
ET.SubElement(size_elem, 'height').text = str(h)

@ -92,7 +92,7 @@ class YoloConverter(Converter):
if not osp.exists(image_path):
save_image(image_path, item.image)
height, width, _ = item.image.shape
height, width = item.image.shape[:2]
yolo_annotation = ''
for bbox in item.annotations:

@ -43,7 +43,8 @@ class DatasetItemEncoder:
def encode_image(cls, image):
image_elem = ET.Element('image')
h, w, c = image.shape
h, w = image.shape[:2]
c = 1 if len(image.shape) == 2 else image.shape[2]
ET.SubElement(image_elem, 'width').text = str(w)
ET.SubElement(image_elem, 'height').text = str(h)
ET.SubElement(image_elem, 'depth').text = str(c)

@ -40,7 +40,7 @@ class CvatImporter:
project.add_source(subset_name, {
'url': subset_path,
'format': self.EXTRACTOR_NAME,
'options': extra_params,
'options': dict(extra_params),
})
return project

@ -40,7 +40,7 @@ class DatumaroImporter:
project.add_source(subset_name, {
'url': subset_path,
'format': self.EXTRACTOR_NAME,
'options': extra_params,
'options': dict(extra_params),
})
return project

@ -37,7 +37,7 @@ class CocoImporter:
project.add_source(source_name, {
'url': ann_file,
'format': self._COCO_EXTRACTORS[ann_type],
'options': extra_params,
'options': dict(extra_params),
})
return project

@ -35,7 +35,7 @@ class DetectionApiImporter:
project.add_source(subset_name, {
'url': subset_path,
'format': self.EXTRACTOR_NAME,
'options': extra_params,
'options': dict(extra_params),
})
return project

@ -31,7 +31,7 @@ class VocImporter:
project.add_source(task.name, {
'url': path,
'format': extractor_type,
'options': extra_params,
'options': dict(extra_params),
})
if len(project.config.sources) == 0:
@ -69,7 +69,7 @@ class VocResultsImporter:
project.add_source(task_name, {
'url': task_dir,
'format': extractor_type,
'options': extra_params,
'options': dict(extra_params),
})
if len(project.config.sources) == 0:

@ -28,7 +28,7 @@ class YoloImporter:
project.add_source(source_name, {
'url': config_path,
'format': 'yolo',
'options': extra_params,
'options': dict(extra_params),
})
return project

@ -97,7 +97,7 @@ class GitWrapper:
def __init__(self, config=None):
self.repo = None
if config is not None:
if config is not None and osp.isdir(config.project_dir):
self.init(config.project_dir)
@staticmethod
@ -335,7 +335,7 @@ class ProjectDataset(Extractor):
own_source = None
own_source_dir = osp.join(config.project_dir, config.dataset_dir)
if osp.isdir(own_source_dir):
if osp.isdir(config.project_dir) and osp.isdir(own_source_dir):
log.disable(log.INFO)
own_source = env.make_importer(DEFAULT_FORMAT)(own_source_dir) \
.make_dataset()

@ -273,6 +273,19 @@ class ProjectTest(TestCase):
self.assertEqual(5, len(dataset))
def test_can_save_and_load_own_dataset(self):
with TestDir() as test_dir:
src_project = Project()
src_dataset = src_project.make_dataset()
item = DatasetItem(id=1)
src_dataset.put(item)
src_dataset.save(test_dir.path)
loaded_project = Project.load(test_dir.path)
loaded_dataset = loaded_project.make_dataset()
self.assertEqual(list(src_dataset), list(loaded_dataset))
def test_project_own_dataset_can_be_modified(self):
project = Project()
dataset = project.make_dataset()

Loading…
Cancel
Save