[Datumaro] Extract common extractor functionality (#1319)

* Extract common extractor functionality

* Simplify coco extractor

* Fix tfrecord
main
zhiltsov-max 6 years ago committed by GitHub
parent 3f4d6fc8d3
commit 6a4ccdea1f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -741,7 +741,22 @@ DEFAULT_SUBSET_NAME = 'default'
class SourceExtractor(Extractor): class SourceExtractor(Extractor):
pass def __init__(self, length=None, subset=None):
super().__init__(length=length)
if subset == DEFAULT_SUBSET_NAME:
subset = None
self._subset = subset
def subsets(self):
if self._subset:
return [self._subset]
return None
def get_subset(self, name):
if name != self._subset:
return None
return self
class Importer: class Importer:
@classmethod @classmethod

@ -22,19 +22,22 @@ from .format import CocoTask, CocoPath
class _CocoExtractor(SourceExtractor): class _CocoExtractor(SourceExtractor):
def __init__(self, path, task, merge_instance_polygons=False): def __init__(self, path, task, merge_instance_polygons=False):
super().__init__() assert osp.isfile(path), path
assert osp.isfile(path) subset = osp.splitext(osp.basename(path))[0].rsplit('_', maxsplit=1)[1]
rootpath = path.rsplit(CocoPath.ANNOTATIONS_DIR, maxsplit=1)[0] super().__init__(subset=subset)
self._path = rootpath
rootpath = ''
if path.endswith(osp.join(CocoPath.ANNOTATIONS_DIR, osp.basename(path))):
rootpath = path.rsplit(CocoPath.ANNOTATIONS_DIR, maxsplit=1)[0]
images_dir = ''
if rootpath and osp.isdir(osp.join(rootpath, CocoPath.IMAGES_DIR)):
images_dir = osp.join(rootpath, CocoPath.IMAGES_DIR)
if osp.isdir(osp.join(images_dir, subset or DEFAULT_SUBSET_NAME)):
images_dir = osp.join(images_dir, subset or DEFAULT_SUBSET_NAME)
self._images_dir = images_dir
self._task = task self._task = task
subset = osp.splitext(osp.basename(path))[0] \
.rsplit('_', maxsplit=1)[1]
if subset == DEFAULT_SUBSET_NAME:
subset = None
self._subset = subset
self._merge_instance_polygons = merge_instance_polygons self._merge_instance_polygons = merge_instance_polygons
loader = self._make_subset_loader(path) loader = self._make_subset_loader(path)
@ -51,16 +54,6 @@ class _CocoExtractor(SourceExtractor):
def __len__(self): def __len__(self):
return len(self._items) return len(self._items)
def subsets(self):
if self._subset:
return [self._subset]
return None
def get_subset(self, name):
if name != self._subset:
return None
return self
@staticmethod @staticmethod
def _make_subset_loader(path): def _make_subset_loader(path):
# COCO API has an 'unclosed file' warning # COCO API has an 'unclosed file' warning
@ -117,9 +110,7 @@ class _CocoExtractor(SourceExtractor):
for img_id in loader.getImgIds(): for img_id in loader.getImgIds():
image_info = loader.loadImgs(img_id)[0] image_info = loader.loadImgs(img_id)[0]
image_path = self._find_image(image_info['file_name']) image_path = osp.join(self._images_dir, image_info['file_name'])
if not image_path:
image_path = image_info['file_name']
image_size = (image_info.get('height'), image_info.get('width')) image_size = (image_info.get('height'), image_info.get('width'))
if all(image_size): if all(image_size):
image_size = (int(image_size[0]), int(image_size[1])) image_size = (int(image_size[0]), int(image_size[1]))
@ -232,33 +223,27 @@ class _CocoExtractor(SourceExtractor):
return parsed_annotations return parsed_annotations
def _find_image(self, file_name):
images_dir = osp.join(self._path, CocoPath.IMAGES_DIR)
search_paths = [
osp.join(images_dir, file_name),
osp.join(images_dir, self._subset or DEFAULT_SUBSET_NAME, file_name),
]
for image_path in search_paths:
if osp.exists(image_path):
return image_path
return None
class CocoImageInfoExtractor(_CocoExtractor): class CocoImageInfoExtractor(_CocoExtractor):
def __init__(self, path, **kwargs): def __init__(self, path, **kwargs):
super().__init__(path, task=CocoTask.image_info, **kwargs) kwargs['task'] = CocoTask.image_info
super().__init__(path, **kwargs)
class CocoCaptionsExtractor(_CocoExtractor): class CocoCaptionsExtractor(_CocoExtractor):
def __init__(self, path, **kwargs): def __init__(self, path, **kwargs):
super().__init__(path, task=CocoTask.captions, **kwargs) kwargs['task'] = CocoTask.captions
super().__init__(path, **kwargs)
class CocoInstancesExtractor(_CocoExtractor): class CocoInstancesExtractor(_CocoExtractor):
def __init__(self, path, **kwargs): def __init__(self, path, **kwargs):
super().__init__(path, task=CocoTask.instances, **kwargs) kwargs['task'] = CocoTask.instances
super().__init__(path, **kwargs)
class CocoPersonKeypointsExtractor(_CocoExtractor): class CocoPersonKeypointsExtractor(_CocoExtractor):
def __init__(self, path, **kwargs): def __init__(self, path, **kwargs):
super().__init__(path, task=CocoTask.person_keypoints, **kwargs) kwargs['task'] = CocoTask.person_keypoints
super().__init__(path, **kwargs)
class CocoLabelsExtractor(_CocoExtractor): class CocoLabelsExtractor(_CocoExtractor):
def __init__(self, path, **kwargs): def __init__(self, path, **kwargs):
super().__init__(path, task=CocoTask.labels, **kwargs) kwargs['task'] = CocoTask.labels
super().__init__(path, **kwargs)

@ -7,8 +7,7 @@ from collections import OrderedDict
import os.path as osp import os.path as osp
from defusedxml import ElementTree from defusedxml import ElementTree
from datumaro.components.extractor import (SourceExtractor, from datumaro.components.extractor import (SourceExtractor, DatasetItem,
DEFAULT_SUBSET_NAME, DatasetItem,
AnnotationType, Points, Polygon, PolyLine, Bbox, Label, AnnotationType, Points, Polygon, PolyLine, Bbox, Label,
LabelCategories LabelCategories
) )
@ -21,9 +20,7 @@ class CvatExtractor(SourceExtractor):
_SUPPORTED_SHAPES = ('box', 'polygon', 'polyline', 'points') _SUPPORTED_SHAPES = ('box', 'polygon', 'polyline', 'points')
def __init__(self, path): def __init__(self, path):
super().__init__() assert osp.isfile(path), path
assert osp.isfile(path)
rootpath = '' rootpath = ''
if path.endswith(osp.join(CvatPath.ANNOTATIONS_DIR, osp.basename(path))): if path.endswith(osp.join(CvatPath.ANNOTATIONS_DIR, osp.basename(path))):
rootpath = path.rsplit(CvatPath.ANNOTATIONS_DIR, maxsplit=1)[0] rootpath = path.rsplit(CvatPath.ANNOTATIONS_DIR, maxsplit=1)[0]
@ -33,10 +30,7 @@ class CvatExtractor(SourceExtractor):
self._images_dir = images_dir self._images_dir = images_dir
self._path = path self._path = path
subset = osp.splitext(osp.basename(path))[0] super().__init__(subset=osp.splitext(osp.basename(path))[0])
if subset == DEFAULT_SUBSET_NAME:
subset = None
self._subset = subset
items, categories = self._parse(path) items, categories = self._parse(path)
self._items = self._load_items(items) self._items = self._load_items(items)
@ -52,16 +46,6 @@ class CvatExtractor(SourceExtractor):
def __len__(self): def __len__(self):
return len(self._items) return len(self._items)
def subsets(self):
if self._subset:
return [self._subset]
return None
def get_subset(self, name):
if name != self._subset:
return None
return self
@classmethod @classmethod
def _parse(cls, path): def _parse(cls, path):
context = ElementTree.iterparse(path, events=("start", "end")) context = ElementTree.iterparse(path, events=("start", "end"))
@ -342,14 +326,8 @@ class CvatExtractor(SourceExtractor):
def _find_image(self, file_name): def _find_image(self, file_name):
search_paths = [] search_paths = []
if self._images_dir: if self._images_dir:
search_paths += [ search_paths += [ osp.join(self._images_dir, file_name) ]
osp.join(self._images_dir, file_name), search_paths += [ osp.join(osp.dirname(self._path), file_name) ]
osp.join(self._images_dir, self._subset or DEFAULT_SUBSET_NAME,
file_name),
]
search_paths += [
osp.join(osp.dirname(self._path), file_name)
]
for image_path in search_paths: for image_path in search_paths:
if osp.isfile(image_path): if osp.isfile(image_path):
return image_path return image_path

@ -6,8 +6,7 @@
import json import json
import os.path as osp import os.path as osp
from datumaro.components.extractor import (SourceExtractor, from datumaro.components.extractor import (SourceExtractor, DatasetItem,
DEFAULT_SUBSET_NAME, DatasetItem,
AnnotationType, Label, RleMask, Points, Polygon, PolyLine, Bbox, Caption, AnnotationType, Label, RleMask, Points, Polygon, PolyLine, Bbox, Caption,
LabelCategories, MaskCategories, PointsCategories LabelCategories, MaskCategories, PointsCategories
) )
@ -18,16 +17,16 @@ from .format import DatumaroPath
class DatumaroExtractor(SourceExtractor): class DatumaroExtractor(SourceExtractor):
def __init__(self, path): def __init__(self, path):
super().__init__() assert osp.isfile(path), path
rootpath = ''
if path.endswith(osp.join(DatumaroPath.ANNOTATIONS_DIR, osp.basename(path))):
rootpath = path.rsplit(DatumaroPath.ANNOTATIONS_DIR, maxsplit=1)[0]
images_dir = ''
if rootpath and osp.isdir(osp.join(rootpath, DatumaroPath.IMAGES_DIR)):
images_dir = osp.join(rootpath, DatumaroPath.IMAGES_DIR)
self._images_dir = images_dir
assert osp.isfile(path) super().__init__(subset=osp.splitext(osp.basename(path))[0])
rootpath = path.rsplit(DatumaroPath.ANNOTATIONS_DIR, maxsplit=1)[0]
self._path = rootpath
subset_name = osp.splitext(osp.basename(path))[0]
if subset_name == DEFAULT_SUBSET_NAME:
subset_name = None
self._subset_name = subset_name
with open(path, 'r') as f: with open(path, 'r') as f:
parsed_anns = json.load(f) parsed_anns = json.load(f)
@ -44,16 +43,6 @@ class DatumaroExtractor(SourceExtractor):
def __len__(self): def __len__(self):
return len(self._items) return len(self._items)
def subsets(self):
if self._subset_name:
return [self._subset_name]
return None
def get_subset(self, name):
if name != self._subset_name:
return None
return self
@staticmethod @staticmethod
def _load_categories(parsed): def _load_categories(parsed):
categories = {} categories = {}
@ -95,13 +84,13 @@ class DatumaroExtractor(SourceExtractor):
image = None image = None
image_info = item_desc.get('image', {}) image_info = item_desc.get('image', {})
if image_info: if image_info:
image_path = osp.join(self._path, DatumaroPath.IMAGES_DIR, image_path = osp.join(self._images_dir,
image_info.get('path', '')) # relative or absolute fits image_info.get('path', '')) # relative or absolute fits
image = Image(path=image_path, size=image_info.get('size')) image = Image(path=image_path, size=image_info.get('size'))
annotations = self._load_annotations(item_desc) annotations = self._load_annotations(item_desc)
item = DatasetItem(id=item_id, subset=self._subset_name, item = DatasetItem(id=item_id, subset=self._subset,
annotations=annotations, image=image) annotations=annotations, image=image)
items.append(item) items.append(item)

@ -38,7 +38,7 @@ class ImageDirExtractor(SourceExtractor):
def __init__(self, url): def __init__(self, url):
super().__init__() super().__init__()
assert osp.isdir(url) assert osp.isdir(url), url
items = [] items = []
for name in os.listdir(url): for name in os.listdir(url):
@ -52,8 +52,6 @@ class ImageDirExtractor(SourceExtractor):
items = OrderedDict(items) items = OrderedDict(items)
self._items = items self._items = items
self._subsets = None
def __iter__(self): def __iter__(self):
for item in self._items.values(): for item in self._items.values():
yield item yield item
@ -61,9 +59,6 @@ class ImageDirExtractor(SourceExtractor):
def __len__(self): def __len__(self):
return len(self._items) return len(self._items)
def subsets(self):
return self._subsets
def get(self, item_id, subset=None, path=None): def get(self, item_id, subset=None, path=None):
if path or subset: if path or subset:
raise KeyError() raise KeyError()

@ -26,12 +26,8 @@ class LabelMePath:
class LabelMeExtractor(SourceExtractor): class LabelMeExtractor(SourceExtractor):
def __init__(self, path, subset_name=None): def __init__(self, path, subset_name=None):
super().__init__() assert osp.isdir(path), path
super().__init__(subset=subset_name)
assert osp.isdir(path)
self._rootdir = path
self._subset = subset_name
items, categories = self._parse(path) items, categories = self._parse(path)
self._categories = categories self._categories = categories
@ -47,16 +43,6 @@ class LabelMeExtractor(SourceExtractor):
def __len__(self): def __len__(self):
return len(self._items) return len(self._items)
def subsets(self):
if self._subset:
return [self._subset]
return None
def get_subset(self, name):
if name != self._subset:
return None
return self
def _parse(self, path): def _parse(self, path):
categories = { categories = {
AnnotationType.label: LabelCategories(attributes={ AnnotationType.label: LabelCategories(attributes={

@ -66,9 +66,7 @@ class MotSeqExtractor(SourceExtractor):
super().__init__() super().__init__()
assert osp.isfile(path) assert osp.isfile(path)
self._path = path
seq_root = osp.dirname(osp.dirname(path)) seq_root = osp.dirname(osp.dirname(path))
self._image_dir = '' self._image_dir = ''
if osp.isdir(osp.join(seq_root, MotPath.IMAGE_DIR)): if osp.isdir(osp.join(seq_root, MotPath.IMAGE_DIR)):
self._image_dir = osp.join(seq_root, MotPath.IMAGE_DIR) self._image_dir = osp.join(seq_root, MotPath.IMAGE_DIR)
@ -91,8 +89,6 @@ class MotSeqExtractor(SourceExtractor):
is_gt = True is_gt = True
self._is_gt = is_gt self._is_gt = is_gt
self._subset = None
if labels is None: if labels is None:
if osp.isfile(osp.join(seq_root, MotPath.LABELS_FILE)): if osp.isfile(osp.join(seq_root, MotPath.LABELS_FILE)):
labels = osp.join(seq_root, MotPath.LABELS_FILE) labels = osp.join(seq_root, MotPath.LABELS_FILE)
@ -117,16 +113,6 @@ class MotSeqExtractor(SourceExtractor):
def __len__(self): def __len__(self):
return len(self._items) return len(self._items)
def subsets(self):
if self._subset:
return [self._subset]
return None
def get_subset(self, name):
if name != self._subset:
return None
return self
@staticmethod @staticmethod
def _parse_labels(path): def _parse_labels(path):
with open(path, encoding='utf-8') as labels_file: with open(path, encoding='utf-8') as labels_file:

@ -8,8 +8,7 @@ import numpy as np
import os.path as osp import os.path as osp
import re import re
from datumaro.components.extractor import (SourceExtractor, from datumaro.components.extractor import (SourceExtractor, DatasetItem,
DEFAULT_SUBSET_NAME, DatasetItem,
AnnotationType, Bbox, Mask, LabelCategories AnnotationType, Bbox, Mask, LabelCategories
) )
from datumaro.util.image import Image, decode_image, lazy_image from datumaro.util.image import Image, decode_image, lazy_image
@ -24,9 +23,7 @@ def clamp(value, _min, _max):
class TfDetectionApiExtractor(SourceExtractor): class TfDetectionApiExtractor(SourceExtractor):
def __init__(self, path): def __init__(self, path):
super().__init__() assert osp.isfile(path), path
assert osp.isfile(path)
images_dir = '' images_dir = ''
root_dir = osp.dirname(osp.abspath(path)) root_dir = osp.dirname(osp.abspath(path))
if osp.basename(root_dir) == DetectionApiPath.ANNOTATIONS_DIR: if osp.basename(root_dir) == DetectionApiPath.ANNOTATIONS_DIR:
@ -35,12 +32,9 @@ class TfDetectionApiExtractor(SourceExtractor):
if not osp.isdir(images_dir): if not osp.isdir(images_dir):
images_dir = '' images_dir = ''
subset_name = osp.splitext(osp.basename(path))[0] super().__init__(subset=osp.splitext(osp.basename(path))[0])
if subset_name == DEFAULT_SUBSET_NAME:
subset_name = None
self._subset_name = subset_name
items, labels = self._parse_tfrecord_file(path, subset_name, images_dir) items, labels = self._parse_tfrecord_file(path, self._subset, images_dir)
self._items = items self._items = items
self._categories = self._load_categories(labels) self._categories = self._load_categories(labels)
@ -54,16 +48,6 @@ class TfDetectionApiExtractor(SourceExtractor):
def __len__(self): def __len__(self):
return len(self._items) return len(self._items)
def subsets(self):
if self._subset_name:
return [self._subset_name]
return None
def get_subset(self, name):
if name != self._subset_name:
return None
return self
@staticmethod @staticmethod
def _load_categories(labels): def _load_categories(labels):
label_categories = LabelCategories() label_categories = LabelCategories()
@ -92,7 +76,7 @@ class TfDetectionApiExtractor(SourceExtractor):
return labelmap return labelmap
@classmethod @classmethod
def _parse_tfrecord_file(cls, filepath, subset_name, images_dir): def _parse_tfrecord_file(cls, filepath, subset, images_dir):
dataset = tf.data.TFRecordDataset(filepath) dataset = tf.data.TFRecordDataset(filepath)
features = { features = {
'image/filename': tf.io.FixedLenFeature([], tf.string), 'image/filename': tf.io.FixedLenFeature([], tf.string),
@ -203,7 +187,7 @@ class TfDetectionApiExtractor(SourceExtractor):
if image_params: if image_params:
image = Image(**image_params, size=image_size) image = Image(**image_params, size=image_size)
dataset_items.append(DatasetItem(id=item_id, subset=subset_name, dataset_items.append(DatasetItem(id=item_id, subset=subset,
image=image, annotations=annotations)) image=image, annotations=annotations))
return dataset_items, dataset_labels return dataset_items, dataset_labels

@ -9,8 +9,7 @@ import numpy as np
import os.path as osp import os.path as osp
from defusedxml import ElementTree from defusedxml import ElementTree
from datumaro.components.extractor import (SourceExtractor, from datumaro.components.extractor import (SourceExtractor, DatasetItem,
DEFAULT_SUBSET_NAME, DatasetItem,
AnnotationType, Label, Mask, Bbox, CompiledMask AnnotationType, Label, Mask, Bbox, CompiledMask
) )
from datumaro.util import dir_items from datumaro.util import dir_items
@ -26,16 +25,11 @@ _inverse_inst_colormap = invert_colormap(VocInstColormap)
class _VocExtractor(SourceExtractor): class _VocExtractor(SourceExtractor):
def __init__(self, path): def __init__(self, path):
super().__init__()
assert osp.isfile(path), path assert osp.isfile(path), path
self._path = path self._path = path
self._dataset_dir = osp.dirname(osp.dirname(osp.dirname(path))) self._dataset_dir = osp.dirname(osp.dirname(osp.dirname(path)))
subset = osp.splitext(osp.basename(path))[0] super().__init__(subset=osp.splitext(osp.basename(path))[0])
if subset == DEFAULT_SUBSET_NAME:
subset = None
self._subset = subset
self._categories = self._load_categories(self._dataset_dir) self._categories = self._load_categories(self._dataset_dir)
log.debug("Loaded labels: %s", ', '.join("'%s'" % l.name log.debug("Loaded labels: %s", ', '.join("'%s'" % l.name
@ -48,16 +42,6 @@ class _VocExtractor(SourceExtractor):
def __len__(self): def __len__(self):
return len(self._items) return len(self._items)
def subsets(self):
if self._subset:
return [self._subset]
return None
def get_subset(self, name):
if name != self._subset:
return None
return self
def _get_label_id(self, label): def _get_label_id(self, label):
label_id, _ = self._categories[AnnotationType.label].find(label) label_id, _ = self._categories[AnnotationType.label].find(label)
assert label_id is not None, label assert label_id is not None, label

Loading…
Cancel
Save