diff --git a/datumaro/datumaro/components/extractor.py b/datumaro/datumaro/components/extractor.py index dc7867d0..d88d7c45 100644 --- a/datumaro/datumaro/components/extractor.py +++ b/datumaro/datumaro/components/extractor.py @@ -741,7 +741,22 @@ DEFAULT_SUBSET_NAME = 'default' class SourceExtractor(Extractor): - pass + def __init__(self, length=None, subset=None): + super().__init__(length=length) + + if subset == DEFAULT_SUBSET_NAME: + subset = None + self._subset = subset + + def subsets(self): + if self._subset: + return [self._subset] + return None + + def get_subset(self, name): + if name != self._subset: + return None + return self class Importer: @classmethod diff --git a/datumaro/datumaro/plugins/coco_format/extractor.py b/datumaro/datumaro/plugins/coco_format/extractor.py index 730c3835..250404c6 100644 --- a/datumaro/datumaro/plugins/coco_format/extractor.py +++ b/datumaro/datumaro/plugins/coco_format/extractor.py @@ -22,19 +22,22 @@ from .format import CocoTask, CocoPath class _CocoExtractor(SourceExtractor): def __init__(self, path, task, merge_instance_polygons=False): - super().__init__() - - assert osp.isfile(path) - rootpath = path.rsplit(CocoPath.ANNOTATIONS_DIR, maxsplit=1)[0] - self._path = rootpath + assert osp.isfile(path), path + + subset = osp.splitext(osp.basename(path))[0].rsplit('_', maxsplit=1)[1] + super().__init__(subset=subset) + + rootpath = '' + if path.endswith(osp.join(CocoPath.ANNOTATIONS_DIR, osp.basename(path))): + rootpath = path.rsplit(CocoPath.ANNOTATIONS_DIR, maxsplit=1)[0] + images_dir = '' + if rootpath and osp.isdir(osp.join(rootpath, CocoPath.IMAGES_DIR)): + images_dir = osp.join(rootpath, CocoPath.IMAGES_DIR) + if osp.isdir(osp.join(images_dir, subset or DEFAULT_SUBSET_NAME)): + images_dir = osp.join(images_dir, subset or DEFAULT_SUBSET_NAME) + self._images_dir = images_dir self._task = task - subset = osp.splitext(osp.basename(path))[0] \ - .rsplit('_', maxsplit=1)[1] - if subset == DEFAULT_SUBSET_NAME: - subset = None - self._subset = subset - self._merge_instance_polygons = merge_instance_polygons loader = self._make_subset_loader(path) @@ -51,16 +54,6 @@ class _CocoExtractor(SourceExtractor): def __len__(self): return len(self._items) - def subsets(self): - if self._subset: - return [self._subset] - return None - - def get_subset(self, name): - if name != self._subset: - return None - return self - @staticmethod def _make_subset_loader(path): # COCO API has an 'unclosed file' warning @@ -117,9 +110,7 @@ class _CocoExtractor(SourceExtractor): for img_id in loader.getImgIds(): image_info = loader.loadImgs(img_id)[0] - image_path = self._find_image(image_info['file_name']) - if not image_path: - image_path = image_info['file_name'] + image_path = osp.join(self._images_dir, image_info['file_name']) image_size = (image_info.get('height'), image_info.get('width')) if all(image_size): image_size = (int(image_size[0]), int(image_size[1])) @@ -232,33 +223,27 @@ class _CocoExtractor(SourceExtractor): return parsed_annotations - def _find_image(self, file_name): - images_dir = osp.join(self._path, CocoPath.IMAGES_DIR) - search_paths = [ - osp.join(images_dir, file_name), - osp.join(images_dir, self._subset or DEFAULT_SUBSET_NAME, file_name), - ] - for image_path in search_paths: - if osp.exists(image_path): - return image_path - return None - class CocoImageInfoExtractor(_CocoExtractor): def __init__(self, path, **kwargs): - super().__init__(path, task=CocoTask.image_info, **kwargs) + kwargs['task'] = CocoTask.image_info + super().__init__(path, **kwargs) class CocoCaptionsExtractor(_CocoExtractor): def __init__(self, path, **kwargs): - super().__init__(path, task=CocoTask.captions, **kwargs) + kwargs['task'] = CocoTask.captions + super().__init__(path, **kwargs) class CocoInstancesExtractor(_CocoExtractor): def __init__(self, path, **kwargs): - super().__init__(path, task=CocoTask.instances, **kwargs) + kwargs['task'] = CocoTask.instances + super().__init__(path, **kwargs) class CocoPersonKeypointsExtractor(_CocoExtractor): def __init__(self, path, **kwargs): - super().__init__(path, task=CocoTask.person_keypoints, **kwargs) + kwargs['task'] = CocoTask.person_keypoints + super().__init__(path, **kwargs) class CocoLabelsExtractor(_CocoExtractor): def __init__(self, path, **kwargs): - super().__init__(path, task=CocoTask.labels, **kwargs) \ No newline at end of file + kwargs['task'] = CocoTask.labels + super().__init__(path, **kwargs) \ No newline at end of file diff --git a/datumaro/datumaro/plugins/cvat_format/extractor.py b/datumaro/datumaro/plugins/cvat_format/extractor.py index 014aa90f..0a181d74 100644 --- a/datumaro/datumaro/plugins/cvat_format/extractor.py +++ b/datumaro/datumaro/plugins/cvat_format/extractor.py @@ -7,8 +7,7 @@ from collections import OrderedDict import os.path as osp from defusedxml import ElementTree -from datumaro.components.extractor import (SourceExtractor, - DEFAULT_SUBSET_NAME, DatasetItem, +from datumaro.components.extractor import (SourceExtractor, DatasetItem, AnnotationType, Points, Polygon, PolyLine, Bbox, Label, LabelCategories ) @@ -21,9 +20,7 @@ class CvatExtractor(SourceExtractor): _SUPPORTED_SHAPES = ('box', 'polygon', 'polyline', 'points') def __init__(self, path): - super().__init__() - - assert osp.isfile(path) + assert osp.isfile(path), path rootpath = '' if path.endswith(osp.join(CvatPath.ANNOTATIONS_DIR, osp.basename(path))): rootpath = path.rsplit(CvatPath.ANNOTATIONS_DIR, maxsplit=1)[0] @@ -33,10 +30,7 @@ class CvatExtractor(SourceExtractor): self._images_dir = images_dir self._path = path - subset = osp.splitext(osp.basename(path))[0] - if subset == DEFAULT_SUBSET_NAME: - subset = None - self._subset = subset + super().__init__(subset=osp.splitext(osp.basename(path))[0]) items, categories = self._parse(path) self._items = self._load_items(items) @@ -52,16 +46,6 @@ class CvatExtractor(SourceExtractor): def __len__(self): return len(self._items) - def subsets(self): - if self._subset: - return [self._subset] - return None - - def get_subset(self, name): - if name != self._subset: - return None - return self - @classmethod def _parse(cls, path): context = ElementTree.iterparse(path, events=("start", "end")) @@ -342,14 +326,8 @@ class CvatExtractor(SourceExtractor): def _find_image(self, file_name): search_paths = [] if self._images_dir: - search_paths += [ - osp.join(self._images_dir, file_name), - osp.join(self._images_dir, self._subset or DEFAULT_SUBSET_NAME, - file_name), - ] - search_paths += [ - osp.join(osp.dirname(self._path), file_name) - ] + search_paths += [ osp.join(self._images_dir, file_name) ] + search_paths += [ osp.join(osp.dirname(self._path), file_name) ] for image_path in search_paths: if osp.isfile(image_path): return image_path diff --git a/datumaro/datumaro/plugins/datumaro_format/extractor.py b/datumaro/datumaro/plugins/datumaro_format/extractor.py index 4be7a778..4a19565c 100644 --- a/datumaro/datumaro/plugins/datumaro_format/extractor.py +++ b/datumaro/datumaro/plugins/datumaro_format/extractor.py @@ -6,8 +6,7 @@ import json import os.path as osp -from datumaro.components.extractor import (SourceExtractor, - DEFAULT_SUBSET_NAME, DatasetItem, +from datumaro.components.extractor import (SourceExtractor, DatasetItem, AnnotationType, Label, RleMask, Points, Polygon, PolyLine, Bbox, Caption, LabelCategories, MaskCategories, PointsCategories ) @@ -18,16 +17,16 @@ from .format import DatumaroPath class DatumaroExtractor(SourceExtractor): def __init__(self, path): - super().__init__() + assert osp.isfile(path), path + rootpath = '' + if path.endswith(osp.join(DatumaroPath.ANNOTATIONS_DIR, osp.basename(path))): + rootpath = path.rsplit(DatumaroPath.ANNOTATIONS_DIR, maxsplit=1)[0] + images_dir = '' + if rootpath and osp.isdir(osp.join(rootpath, DatumaroPath.IMAGES_DIR)): + images_dir = osp.join(rootpath, DatumaroPath.IMAGES_DIR) + self._images_dir = images_dir - assert osp.isfile(path) - rootpath = path.rsplit(DatumaroPath.ANNOTATIONS_DIR, maxsplit=1)[0] - self._path = rootpath - - subset_name = osp.splitext(osp.basename(path))[0] - if subset_name == DEFAULT_SUBSET_NAME: - subset_name = None - self._subset_name = subset_name + super().__init__(subset=osp.splitext(osp.basename(path))[0]) with open(path, 'r') as f: parsed_anns = json.load(f) @@ -44,16 +43,6 @@ class DatumaroExtractor(SourceExtractor): def __len__(self): return len(self._items) - def subsets(self): - if self._subset_name: - return [self._subset_name] - return None - - def get_subset(self, name): - if name != self._subset_name: - return None - return self - @staticmethod def _load_categories(parsed): categories = {} @@ -95,13 +84,13 @@ class DatumaroExtractor(SourceExtractor): image = None image_info = item_desc.get('image', {}) if image_info: - image_path = osp.join(self._path, DatumaroPath.IMAGES_DIR, + image_path = osp.join(self._images_dir, image_info.get('path', '')) # relative or absolute fits image = Image(path=image_path, size=image_info.get('size')) annotations = self._load_annotations(item_desc) - item = DatasetItem(id=item_id, subset=self._subset_name, + item = DatasetItem(id=item_id, subset=self._subset, annotations=annotations, image=image) items.append(item) diff --git a/datumaro/datumaro/plugins/image_dir.py b/datumaro/datumaro/plugins/image_dir.py index c719c546..5f3a1884 100644 --- a/datumaro/datumaro/plugins/image_dir.py +++ b/datumaro/datumaro/plugins/image_dir.py @@ -38,7 +38,7 @@ class ImageDirExtractor(SourceExtractor): def __init__(self, url): super().__init__() - assert osp.isdir(url) + assert osp.isdir(url), url items = [] for name in os.listdir(url): @@ -52,8 +52,6 @@ class ImageDirExtractor(SourceExtractor): items = OrderedDict(items) self._items = items - self._subsets = None - def __iter__(self): for item in self._items.values(): yield item @@ -61,9 +59,6 @@ class ImageDirExtractor(SourceExtractor): def __len__(self): return len(self._items) - def subsets(self): - return self._subsets - def get(self, item_id, subset=None, path=None): if path or subset: raise KeyError() diff --git a/datumaro/datumaro/plugins/labelme_format.py b/datumaro/datumaro/plugins/labelme_format.py index 41069da9..d2512ad7 100644 --- a/datumaro/datumaro/plugins/labelme_format.py +++ b/datumaro/datumaro/plugins/labelme_format.py @@ -26,12 +26,8 @@ class LabelMePath: class LabelMeExtractor(SourceExtractor): def __init__(self, path, subset_name=None): - super().__init__() - - assert osp.isdir(path) - self._rootdir = path - - self._subset = subset_name + assert osp.isdir(path), path + super().__init__(subset=subset_name) items, categories = self._parse(path) self._categories = categories @@ -47,16 +43,6 @@ class LabelMeExtractor(SourceExtractor): def __len__(self): return len(self._items) - def subsets(self): - if self._subset: - return [self._subset] - return None - - def get_subset(self, name): - if name != self._subset: - return None - return self - def _parse(self, path): categories = { AnnotationType.label: LabelCategories(attributes={ diff --git a/datumaro/datumaro/plugins/mot_format.py b/datumaro/datumaro/plugins/mot_format.py index 18d3695b..a586645d 100644 --- a/datumaro/datumaro/plugins/mot_format.py +++ b/datumaro/datumaro/plugins/mot_format.py @@ -66,9 +66,7 @@ class MotSeqExtractor(SourceExtractor): super().__init__() assert osp.isfile(path) - self._path = path seq_root = osp.dirname(osp.dirname(path)) - self._image_dir = '' if osp.isdir(osp.join(seq_root, MotPath.IMAGE_DIR)): self._image_dir = osp.join(seq_root, MotPath.IMAGE_DIR) @@ -91,8 +89,6 @@ class MotSeqExtractor(SourceExtractor): is_gt = True self._is_gt = is_gt - self._subset = None - if labels is None: if osp.isfile(osp.join(seq_root, MotPath.LABELS_FILE)): labels = osp.join(seq_root, MotPath.LABELS_FILE) @@ -117,16 +113,6 @@ class MotSeqExtractor(SourceExtractor): def __len__(self): return len(self._items) - def subsets(self): - if self._subset: - return [self._subset] - return None - - def get_subset(self, name): - if name != self._subset: - return None - return self - @staticmethod def _parse_labels(path): with open(path, encoding='utf-8') as labels_file: diff --git a/datumaro/datumaro/plugins/tf_detection_api_format/extractor.py b/datumaro/datumaro/plugins/tf_detection_api_format/extractor.py index 567392dd..0f4c474b 100644 --- a/datumaro/datumaro/plugins/tf_detection_api_format/extractor.py +++ b/datumaro/datumaro/plugins/tf_detection_api_format/extractor.py @@ -8,8 +8,7 @@ import numpy as np import os.path as osp import re -from datumaro.components.extractor import (SourceExtractor, - DEFAULT_SUBSET_NAME, DatasetItem, +from datumaro.components.extractor import (SourceExtractor, DatasetItem, AnnotationType, Bbox, Mask, LabelCategories ) from datumaro.util.image import Image, decode_image, lazy_image @@ -24,9 +23,7 @@ def clamp(value, _min, _max): class TfDetectionApiExtractor(SourceExtractor): def __init__(self, path): - super().__init__() - - assert osp.isfile(path) + assert osp.isfile(path), path images_dir = '' root_dir = osp.dirname(osp.abspath(path)) if osp.basename(root_dir) == DetectionApiPath.ANNOTATIONS_DIR: @@ -35,12 +32,9 @@ class TfDetectionApiExtractor(SourceExtractor): if not osp.isdir(images_dir): images_dir = '' - subset_name = osp.splitext(osp.basename(path))[0] - if subset_name == DEFAULT_SUBSET_NAME: - subset_name = None - self._subset_name = subset_name + super().__init__(subset=osp.splitext(osp.basename(path))[0]) - items, labels = self._parse_tfrecord_file(path, subset_name, images_dir) + items, labels = self._parse_tfrecord_file(path, self._subset, images_dir) self._items = items self._categories = self._load_categories(labels) @@ -54,16 +48,6 @@ class TfDetectionApiExtractor(SourceExtractor): def __len__(self): return len(self._items) - def subsets(self): - if self._subset_name: - return [self._subset_name] - return None - - def get_subset(self, name): - if name != self._subset_name: - return None - return self - @staticmethod def _load_categories(labels): label_categories = LabelCategories() @@ -92,7 +76,7 @@ class TfDetectionApiExtractor(SourceExtractor): return labelmap @classmethod - def _parse_tfrecord_file(cls, filepath, subset_name, images_dir): + def _parse_tfrecord_file(cls, filepath, subset, images_dir): dataset = tf.data.TFRecordDataset(filepath) features = { 'image/filename': tf.io.FixedLenFeature([], tf.string), @@ -203,7 +187,7 @@ class TfDetectionApiExtractor(SourceExtractor): if image_params: image = Image(**image_params, size=image_size) - dataset_items.append(DatasetItem(id=item_id, subset=subset_name, + dataset_items.append(DatasetItem(id=item_id, subset=subset, image=image, annotations=annotations)) return dataset_items, dataset_labels diff --git a/datumaro/datumaro/plugins/voc_format/extractor.py b/datumaro/datumaro/plugins/voc_format/extractor.py index 29aaad6d..96312185 100644 --- a/datumaro/datumaro/plugins/voc_format/extractor.py +++ b/datumaro/datumaro/plugins/voc_format/extractor.py @@ -9,8 +9,7 @@ import numpy as np import os.path as osp from defusedxml import ElementTree -from datumaro.components.extractor import (SourceExtractor, - DEFAULT_SUBSET_NAME, DatasetItem, +from datumaro.components.extractor import (SourceExtractor, DatasetItem, AnnotationType, Label, Mask, Bbox, CompiledMask ) from datumaro.util import dir_items @@ -26,16 +25,11 @@ _inverse_inst_colormap = invert_colormap(VocInstColormap) class _VocExtractor(SourceExtractor): def __init__(self, path): - super().__init__() - assert osp.isfile(path), path self._path = path self._dataset_dir = osp.dirname(osp.dirname(osp.dirname(path))) - subset = osp.splitext(osp.basename(path))[0] - if subset == DEFAULT_SUBSET_NAME: - subset = None - self._subset = subset + super().__init__(subset=osp.splitext(osp.basename(path))[0]) self._categories = self._load_categories(self._dataset_dir) log.debug("Loaded labels: %s", ', '.join("'%s'" % l.name @@ -48,16 +42,6 @@ class _VocExtractor(SourceExtractor): def __len__(self): return len(self._items) - def subsets(self): - if self._subset: - return [self._subset] - return None - - def get_subset(self, name): - if name != self._subset: - return None - return self - def _get_label_id(self, label): label_id, _ = self._categories[AnnotationType.label].find(label) assert label_id is not None, label