[Datumaro] Extract common extractor functionality (#1319)

* Extract common extractor functionality

* Simplify coco extractor

* Fix tfrecord
main
zhiltsov-max 6 years ago committed by GitHub
parent 3f4d6fc8d3
commit 6a4ccdea1f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -741,7 +741,22 @@ DEFAULT_SUBSET_NAME = 'default'
class SourceExtractor(Extractor):
pass
def __init__(self, length=None, subset=None):
super().__init__(length=length)
if subset == DEFAULT_SUBSET_NAME:
subset = None
self._subset = subset
def subsets(self):
if self._subset:
return [self._subset]
return None
def get_subset(self, name):
if name != self._subset:
return None
return self
class Importer:
@classmethod

@ -22,19 +22,22 @@ from .format import CocoTask, CocoPath
class _CocoExtractor(SourceExtractor):
def __init__(self, path, task, merge_instance_polygons=False):
super().__init__()
assert osp.isfile(path)
rootpath = path.rsplit(CocoPath.ANNOTATIONS_DIR, maxsplit=1)[0]
self._path = rootpath
assert osp.isfile(path), path
subset = osp.splitext(osp.basename(path))[0].rsplit('_', maxsplit=1)[1]
super().__init__(subset=subset)
rootpath = ''
if path.endswith(osp.join(CocoPath.ANNOTATIONS_DIR, osp.basename(path))):
rootpath = path.rsplit(CocoPath.ANNOTATIONS_DIR, maxsplit=1)[0]
images_dir = ''
if rootpath and osp.isdir(osp.join(rootpath, CocoPath.IMAGES_DIR)):
images_dir = osp.join(rootpath, CocoPath.IMAGES_DIR)
if osp.isdir(osp.join(images_dir, subset or DEFAULT_SUBSET_NAME)):
images_dir = osp.join(images_dir, subset or DEFAULT_SUBSET_NAME)
self._images_dir = images_dir
self._task = task
subset = osp.splitext(osp.basename(path))[0] \
.rsplit('_', maxsplit=1)[1]
if subset == DEFAULT_SUBSET_NAME:
subset = None
self._subset = subset
self._merge_instance_polygons = merge_instance_polygons
loader = self._make_subset_loader(path)
@ -51,16 +54,6 @@ class _CocoExtractor(SourceExtractor):
def __len__(self):
return len(self._items)
def subsets(self):
if self._subset:
return [self._subset]
return None
def get_subset(self, name):
if name != self._subset:
return None
return self
@staticmethod
def _make_subset_loader(path):
# COCO API has an 'unclosed file' warning
@ -117,9 +110,7 @@ class _CocoExtractor(SourceExtractor):
for img_id in loader.getImgIds():
image_info = loader.loadImgs(img_id)[0]
image_path = self._find_image(image_info['file_name'])
if not image_path:
image_path = image_info['file_name']
image_path = osp.join(self._images_dir, image_info['file_name'])
image_size = (image_info.get('height'), image_info.get('width'))
if all(image_size):
image_size = (int(image_size[0]), int(image_size[1]))
@ -232,33 +223,27 @@ class _CocoExtractor(SourceExtractor):
return parsed_annotations
def _find_image(self, file_name):
images_dir = osp.join(self._path, CocoPath.IMAGES_DIR)
search_paths = [
osp.join(images_dir, file_name),
osp.join(images_dir, self._subset or DEFAULT_SUBSET_NAME, file_name),
]
for image_path in search_paths:
if osp.exists(image_path):
return image_path
return None
class CocoImageInfoExtractor(_CocoExtractor):
def __init__(self, path, **kwargs):
super().__init__(path, task=CocoTask.image_info, **kwargs)
kwargs['task'] = CocoTask.image_info
super().__init__(path, **kwargs)
class CocoCaptionsExtractor(_CocoExtractor):
def __init__(self, path, **kwargs):
super().__init__(path, task=CocoTask.captions, **kwargs)
kwargs['task'] = CocoTask.captions
super().__init__(path, **kwargs)
class CocoInstancesExtractor(_CocoExtractor):
def __init__(self, path, **kwargs):
super().__init__(path, task=CocoTask.instances, **kwargs)
kwargs['task'] = CocoTask.instances
super().__init__(path, **kwargs)
class CocoPersonKeypointsExtractor(_CocoExtractor):
def __init__(self, path, **kwargs):
super().__init__(path, task=CocoTask.person_keypoints, **kwargs)
kwargs['task'] = CocoTask.person_keypoints
super().__init__(path, **kwargs)
class CocoLabelsExtractor(_CocoExtractor):
def __init__(self, path, **kwargs):
super().__init__(path, task=CocoTask.labels, **kwargs)
kwargs['task'] = CocoTask.labels
super().__init__(path, **kwargs)

@ -7,8 +7,7 @@ from collections import OrderedDict
import os.path as osp
from defusedxml import ElementTree
from datumaro.components.extractor import (SourceExtractor,
DEFAULT_SUBSET_NAME, DatasetItem,
from datumaro.components.extractor import (SourceExtractor, DatasetItem,
AnnotationType, Points, Polygon, PolyLine, Bbox, Label,
LabelCategories
)
@ -21,9 +20,7 @@ class CvatExtractor(SourceExtractor):
_SUPPORTED_SHAPES = ('box', 'polygon', 'polyline', 'points')
def __init__(self, path):
super().__init__()
assert osp.isfile(path)
assert osp.isfile(path), path
rootpath = ''
if path.endswith(osp.join(CvatPath.ANNOTATIONS_DIR, osp.basename(path))):
rootpath = path.rsplit(CvatPath.ANNOTATIONS_DIR, maxsplit=1)[0]
@ -33,10 +30,7 @@ class CvatExtractor(SourceExtractor):
self._images_dir = images_dir
self._path = path
subset = osp.splitext(osp.basename(path))[0]
if subset == DEFAULT_SUBSET_NAME:
subset = None
self._subset = subset
super().__init__(subset=osp.splitext(osp.basename(path))[0])
items, categories = self._parse(path)
self._items = self._load_items(items)
@ -52,16 +46,6 @@ class CvatExtractor(SourceExtractor):
def __len__(self):
return len(self._items)
def subsets(self):
if self._subset:
return [self._subset]
return None
def get_subset(self, name):
if name != self._subset:
return None
return self
@classmethod
def _parse(cls, path):
context = ElementTree.iterparse(path, events=("start", "end"))
@ -342,14 +326,8 @@ class CvatExtractor(SourceExtractor):
def _find_image(self, file_name):
search_paths = []
if self._images_dir:
search_paths += [
osp.join(self._images_dir, file_name),
osp.join(self._images_dir, self._subset or DEFAULT_SUBSET_NAME,
file_name),
]
search_paths += [
osp.join(osp.dirname(self._path), file_name)
]
search_paths += [ osp.join(self._images_dir, file_name) ]
search_paths += [ osp.join(osp.dirname(self._path), file_name) ]
for image_path in search_paths:
if osp.isfile(image_path):
return image_path

@ -6,8 +6,7 @@
import json
import os.path as osp
from datumaro.components.extractor import (SourceExtractor,
DEFAULT_SUBSET_NAME, DatasetItem,
from datumaro.components.extractor import (SourceExtractor, DatasetItem,
AnnotationType, Label, RleMask, Points, Polygon, PolyLine, Bbox, Caption,
LabelCategories, MaskCategories, PointsCategories
)
@ -18,16 +17,16 @@ from .format import DatumaroPath
class DatumaroExtractor(SourceExtractor):
def __init__(self, path):
super().__init__()
assert osp.isfile(path), path
rootpath = ''
if path.endswith(osp.join(DatumaroPath.ANNOTATIONS_DIR, osp.basename(path))):
rootpath = path.rsplit(DatumaroPath.ANNOTATIONS_DIR, maxsplit=1)[0]
images_dir = ''
if rootpath and osp.isdir(osp.join(rootpath, DatumaroPath.IMAGES_DIR)):
images_dir = osp.join(rootpath, DatumaroPath.IMAGES_DIR)
self._images_dir = images_dir
assert osp.isfile(path)
rootpath = path.rsplit(DatumaroPath.ANNOTATIONS_DIR, maxsplit=1)[0]
self._path = rootpath
subset_name = osp.splitext(osp.basename(path))[0]
if subset_name == DEFAULT_SUBSET_NAME:
subset_name = None
self._subset_name = subset_name
super().__init__(subset=osp.splitext(osp.basename(path))[0])
with open(path, 'r') as f:
parsed_anns = json.load(f)
@ -44,16 +43,6 @@ class DatumaroExtractor(SourceExtractor):
def __len__(self):
return len(self._items)
def subsets(self):
if self._subset_name:
return [self._subset_name]
return None
def get_subset(self, name):
if name != self._subset_name:
return None
return self
@staticmethod
def _load_categories(parsed):
categories = {}
@ -95,13 +84,13 @@ class DatumaroExtractor(SourceExtractor):
image = None
image_info = item_desc.get('image', {})
if image_info:
image_path = osp.join(self._path, DatumaroPath.IMAGES_DIR,
image_path = osp.join(self._images_dir,
image_info.get('path', '')) # relative or absolute fits
image = Image(path=image_path, size=image_info.get('size'))
annotations = self._load_annotations(item_desc)
item = DatasetItem(id=item_id, subset=self._subset_name,
item = DatasetItem(id=item_id, subset=self._subset,
annotations=annotations, image=image)
items.append(item)

@ -38,7 +38,7 @@ class ImageDirExtractor(SourceExtractor):
def __init__(self, url):
super().__init__()
assert osp.isdir(url)
assert osp.isdir(url), url
items = []
for name in os.listdir(url):
@ -52,8 +52,6 @@ class ImageDirExtractor(SourceExtractor):
items = OrderedDict(items)
self._items = items
self._subsets = None
def __iter__(self):
for item in self._items.values():
yield item
@ -61,9 +59,6 @@ class ImageDirExtractor(SourceExtractor):
def __len__(self):
return len(self._items)
def subsets(self):
return self._subsets
def get(self, item_id, subset=None, path=None):
if path or subset:
raise KeyError()

@ -26,12 +26,8 @@ class LabelMePath:
class LabelMeExtractor(SourceExtractor):
def __init__(self, path, subset_name=None):
super().__init__()
assert osp.isdir(path)
self._rootdir = path
self._subset = subset_name
assert osp.isdir(path), path
super().__init__(subset=subset_name)
items, categories = self._parse(path)
self._categories = categories
@ -47,16 +43,6 @@ class LabelMeExtractor(SourceExtractor):
def __len__(self):
return len(self._items)
def subsets(self):
if self._subset:
return [self._subset]
return None
def get_subset(self, name):
if name != self._subset:
return None
return self
def _parse(self, path):
categories = {
AnnotationType.label: LabelCategories(attributes={

@ -66,9 +66,7 @@ class MotSeqExtractor(SourceExtractor):
super().__init__()
assert osp.isfile(path)
self._path = path
seq_root = osp.dirname(osp.dirname(path))
self._image_dir = ''
if osp.isdir(osp.join(seq_root, MotPath.IMAGE_DIR)):
self._image_dir = osp.join(seq_root, MotPath.IMAGE_DIR)
@ -91,8 +89,6 @@ class MotSeqExtractor(SourceExtractor):
is_gt = True
self._is_gt = is_gt
self._subset = None
if labels is None:
if osp.isfile(osp.join(seq_root, MotPath.LABELS_FILE)):
labels = osp.join(seq_root, MotPath.LABELS_FILE)
@ -117,16 +113,6 @@ class MotSeqExtractor(SourceExtractor):
def __len__(self):
return len(self._items)
def subsets(self):
if self._subset:
return [self._subset]
return None
def get_subset(self, name):
if name != self._subset:
return None
return self
@staticmethod
def _parse_labels(path):
with open(path, encoding='utf-8') as labels_file:

@ -8,8 +8,7 @@ import numpy as np
import os.path as osp
import re
from datumaro.components.extractor import (SourceExtractor,
DEFAULT_SUBSET_NAME, DatasetItem,
from datumaro.components.extractor import (SourceExtractor, DatasetItem,
AnnotationType, Bbox, Mask, LabelCategories
)
from datumaro.util.image import Image, decode_image, lazy_image
@ -24,9 +23,7 @@ def clamp(value, _min, _max):
class TfDetectionApiExtractor(SourceExtractor):
def __init__(self, path):
super().__init__()
assert osp.isfile(path)
assert osp.isfile(path), path
images_dir = ''
root_dir = osp.dirname(osp.abspath(path))
if osp.basename(root_dir) == DetectionApiPath.ANNOTATIONS_DIR:
@ -35,12 +32,9 @@ class TfDetectionApiExtractor(SourceExtractor):
if not osp.isdir(images_dir):
images_dir = ''
subset_name = osp.splitext(osp.basename(path))[0]
if subset_name == DEFAULT_SUBSET_NAME:
subset_name = None
self._subset_name = subset_name
super().__init__(subset=osp.splitext(osp.basename(path))[0])
items, labels = self._parse_tfrecord_file(path, subset_name, images_dir)
items, labels = self._parse_tfrecord_file(path, self._subset, images_dir)
self._items = items
self._categories = self._load_categories(labels)
@ -54,16 +48,6 @@ class TfDetectionApiExtractor(SourceExtractor):
def __len__(self):
return len(self._items)
def subsets(self):
if self._subset_name:
return [self._subset_name]
return None
def get_subset(self, name):
if name != self._subset_name:
return None
return self
@staticmethod
def _load_categories(labels):
label_categories = LabelCategories()
@ -92,7 +76,7 @@ class TfDetectionApiExtractor(SourceExtractor):
return labelmap
@classmethod
def _parse_tfrecord_file(cls, filepath, subset_name, images_dir):
def _parse_tfrecord_file(cls, filepath, subset, images_dir):
dataset = tf.data.TFRecordDataset(filepath)
features = {
'image/filename': tf.io.FixedLenFeature([], tf.string),
@ -203,7 +187,7 @@ class TfDetectionApiExtractor(SourceExtractor):
if image_params:
image = Image(**image_params, size=image_size)
dataset_items.append(DatasetItem(id=item_id, subset=subset_name,
dataset_items.append(DatasetItem(id=item_id, subset=subset,
image=image, annotations=annotations))
return dataset_items, dataset_labels

@ -9,8 +9,7 @@ import numpy as np
import os.path as osp
from defusedxml import ElementTree
from datumaro.components.extractor import (SourceExtractor,
DEFAULT_SUBSET_NAME, DatasetItem,
from datumaro.components.extractor import (SourceExtractor, DatasetItem,
AnnotationType, Label, Mask, Bbox, CompiledMask
)
from datumaro.util import dir_items
@ -26,16 +25,11 @@ _inverse_inst_colormap = invert_colormap(VocInstColormap)
class _VocExtractor(SourceExtractor):
def __init__(self, path):
super().__init__()
assert osp.isfile(path), path
self._path = path
self._dataset_dir = osp.dirname(osp.dirname(osp.dirname(path)))
subset = osp.splitext(osp.basename(path))[0]
if subset == DEFAULT_SUBSET_NAME:
subset = None
self._subset = subset
super().__init__(subset=osp.splitext(osp.basename(path))[0])
self._categories = self._load_categories(self._dataset_dir)
log.debug("Loaded labels: %s", ', '.join("'%s'" % l.name
@ -48,16 +42,6 @@ class _VocExtractor(SourceExtractor):
def __len__(self):
return len(self._items)
def subsets(self):
if self._subset:
return [self._subset]
return None
def get_subset(self, name):
if name != self._subset:
return None
return self
def _get_label_id(self, label):
label_id, _ = self._categories[AnnotationType.label].find(label)
assert label_id is not None, label

Loading…
Cancel
Save