diff --git a/datumaro/datumaro/components/algorithms/rise.py b/datumaro/datumaro/components/algorithms/rise.py index 78e93639..8e75f10a 100644 --- a/datumaro/datumaro/components/algorithms/rise.py +++ b/datumaro/datumaro/components/algorithms/rise.py @@ -80,10 +80,10 @@ class RISE: def apply(self, image, progressive=False): import cv2 - assert len(image.shape) == 3, \ + assert len(image.shape) in [2, 3], \ "Expected an input image in (H, W, C) format" - assert image.shape[2] in [3, 4], \ - "Expected BGR or BGRA input" + if len(image.shape) == 3: + assert image.shape[2] in [3, 4], "Expected BGR or BGRA input" image = image[:, :, :3].astype(np.float32) model = self.model diff --git a/datumaro/datumaro/components/converters/ms_coco.py b/datumaro/datumaro/components/converters/ms_coco.py index f629f72d..469184ac 100644 --- a/datumaro/datumaro/components/converters/ms_coco.py +++ b/datumaro/datumaro/components/converters/ms_coco.py @@ -62,7 +62,7 @@ class _TaskConverter: def save_image_info(self, item, filename): if item.has_image: - h, w, _ = item.image.shape + h, w = item.image.shape[:2] else: h = 0 w = 0 @@ -187,7 +187,7 @@ class _InstancesConverter(_TaskConverter): p.label == ann.label] if polygons: segmentation = [p.get_points() for p in polygons] - h, w, _ = item.image.shape + h, w = item.image.shape[:2] rles = mask_utils.frPyObjects(segmentation, h, w) rle = mask_utils.merge(rles) area = mask_utils.area(rle) @@ -211,7 +211,7 @@ class _InstancesConverter(_TaskConverter): area = ann.area() if self._context._merge_polygons: - h, w, _ = item.image.shape + h, w = item.image.shape[:2] rles = mask_utils.frPyObjects(segmentation, h, w) rle = mask_utils.merge(rles) area = mask_utils.area(rle) diff --git a/datumaro/datumaro/components/converters/tfrecord.py b/datumaro/datumaro/components/converters/tfrecord.py index 447a8359..7d6c5c19 100644 --- a/datumaro/datumaro/components/converters/tfrecord.py +++ b/datumaro/datumaro/components/converters/tfrecord.py @@ -48,7 +48,7 @@ def _make_tf_example(item, get_label_id, get_label, save_images=False): if not item.has_image: raise Exception( "Failed to export dataset item '%s': item has no image" % item.id) - height, width, _ = item.image.shape + height, width = item.image.shape[:2] features.update({ 'image/height': int64_feature(height), diff --git a/datumaro/datumaro/components/converters/voc.py b/datumaro/datumaro/components/converters/voc.py index 81003678..c296c351 100644 --- a/datumaro/datumaro/components/converters/voc.py +++ b/datumaro/datumaro/components/converters/voc.py @@ -153,7 +153,9 @@ class _Converter: ET.SubElement(source_elem, 'image').text = 'Unknown' if item.has_image: - h, w, c = item.image.shape + image_shape = item.image.shape + h, w = image_shape[:2] + c = 1 if len(image_shape) == 2 else image_shape[2] size_elem = ET.SubElement(root_elem, 'size') ET.SubElement(size_elem, 'width').text = str(w) ET.SubElement(size_elem, 'height').text = str(h) diff --git a/datumaro/datumaro/components/converters/yolo.py b/datumaro/datumaro/components/converters/yolo.py index 4bf74693..cf0d1db7 100644 --- a/datumaro/datumaro/components/converters/yolo.py +++ b/datumaro/datumaro/components/converters/yolo.py @@ -92,7 +92,7 @@ class YoloConverter(Converter): if not osp.exists(image_path): save_image(image_path, item.image) - height, width, _ = item.image.shape + height, width = item.image.shape[:2] yolo_annotation = '' for bbox in item.annotations: diff --git a/datumaro/datumaro/components/dataset_filter.py b/datumaro/datumaro/components/dataset_filter.py index 157720f3..28339df0 100644 --- a/datumaro/datumaro/components/dataset_filter.py +++ b/datumaro/datumaro/components/dataset_filter.py @@ -43,7 +43,8 @@ class DatasetItemEncoder: def encode_image(cls, image): image_elem = ET.Element('image') - h, w, c = image.shape + h, w = image.shape[:2] + c = 1 if len(image.shape) == 2 else image.shape[2] ET.SubElement(image_elem, 'width').text = str(w) ET.SubElement(image_elem, 'height').text = str(h) ET.SubElement(image_elem, 'depth').text = str(c) diff --git a/datumaro/datumaro/components/importers/cvat.py b/datumaro/datumaro/components/importers/cvat.py index efdeff29..6f831a7b 100644 --- a/datumaro/datumaro/components/importers/cvat.py +++ b/datumaro/datumaro/components/importers/cvat.py @@ -40,7 +40,7 @@ class CvatImporter: project.add_source(subset_name, { 'url': subset_path, 'format': self.EXTRACTOR_NAME, - 'options': extra_params, + 'options': dict(extra_params), }) return project diff --git a/datumaro/datumaro/components/importers/datumaro.py b/datumaro/datumaro/components/importers/datumaro.py index 9c2a162b..828208d8 100644 --- a/datumaro/datumaro/components/importers/datumaro.py +++ b/datumaro/datumaro/components/importers/datumaro.py @@ -40,7 +40,7 @@ class DatumaroImporter: project.add_source(subset_name, { 'url': subset_path, 'format': self.EXTRACTOR_NAME, - 'options': extra_params, + 'options': dict(extra_params), }) return project diff --git a/datumaro/datumaro/components/importers/ms_coco.py b/datumaro/datumaro/components/importers/ms_coco.py index e7a0d26c..cb0fb838 100644 --- a/datumaro/datumaro/components/importers/ms_coco.py +++ b/datumaro/datumaro/components/importers/ms_coco.py @@ -37,7 +37,7 @@ class CocoImporter: project.add_source(source_name, { 'url': ann_file, 'format': self._COCO_EXTRACTORS[ann_type], - 'options': extra_params, + 'options': dict(extra_params), }) return project diff --git a/datumaro/datumaro/components/importers/tfrecord.py b/datumaro/datumaro/components/importers/tfrecord.py index c1506211..368c3d0f 100644 --- a/datumaro/datumaro/components/importers/tfrecord.py +++ b/datumaro/datumaro/components/importers/tfrecord.py @@ -35,7 +35,7 @@ class DetectionApiImporter: project.add_source(subset_name, { 'url': subset_path, 'format': self.EXTRACTOR_NAME, - 'options': extra_params, + 'options': dict(extra_params), }) return project diff --git a/datumaro/datumaro/components/importers/voc.py b/datumaro/datumaro/components/importers/voc.py index e7132789..bc0409df 100644 --- a/datumaro/datumaro/components/importers/voc.py +++ b/datumaro/datumaro/components/importers/voc.py @@ -31,7 +31,7 @@ class VocImporter: project.add_source(task.name, { 'url': path, 'format': extractor_type, - 'options': extra_params, + 'options': dict(extra_params), }) if len(project.config.sources) == 0: @@ -69,7 +69,7 @@ class VocResultsImporter: project.add_source(task_name, { 'url': task_dir, 'format': extractor_type, - 'options': extra_params, + 'options': dict(extra_params), }) if len(project.config.sources) == 0: diff --git a/datumaro/datumaro/components/importers/yolo.py b/datumaro/datumaro/components/importers/yolo.py index 2a22117e..df8f7396 100644 --- a/datumaro/datumaro/components/importers/yolo.py +++ b/datumaro/datumaro/components/importers/yolo.py @@ -28,7 +28,7 @@ class YoloImporter: project.add_source(source_name, { 'url': config_path, 'format': 'yolo', - 'options': extra_params, + 'options': dict(extra_params), }) return project \ No newline at end of file diff --git a/datumaro/datumaro/components/project.py b/datumaro/datumaro/components/project.py index e03aad63..6fc16c15 100644 --- a/datumaro/datumaro/components/project.py +++ b/datumaro/datumaro/components/project.py @@ -97,7 +97,7 @@ class GitWrapper: def __init__(self, config=None): self.repo = None - if config is not None: + if config is not None and osp.isdir(config.project_dir): self.init(config.project_dir) @staticmethod @@ -335,7 +335,7 @@ class ProjectDataset(Extractor): own_source = None own_source_dir = osp.join(config.project_dir, config.dataset_dir) - if osp.isdir(own_source_dir): + if osp.isdir(config.project_dir) and osp.isdir(own_source_dir): log.disable(log.INFO) own_source = env.make_importer(DEFAULT_FORMAT)(own_source_dir) \ .make_dataset() diff --git a/datumaro/tests/test_project.py b/datumaro/tests/test_project.py index 7f67e1d9..a66668fd 100644 --- a/datumaro/tests/test_project.py +++ b/datumaro/tests/test_project.py @@ -273,6 +273,19 @@ class ProjectTest(TestCase): self.assertEqual(5, len(dataset)) + def test_can_save_and_load_own_dataset(self): + with TestDir() as test_dir: + src_project = Project() + src_dataset = src_project.make_dataset() + item = DatasetItem(id=1) + src_dataset.put(item) + src_dataset.save(test_dir.path) + + loaded_project = Project.load(test_dir.path) + loaded_dataset = loaded_project.make_dataset() + + self.assertEqual(list(src_dataset), list(loaded_dataset)) + def test_project_own_dataset_can_be_modified(self): project = Project() dataset = project.make_dataset()