[Datumaro] Dataset annotations filter (#1053)

* Fix deprecation message

* Update launcher interface

* Add dataset entity, anno filter, remove filter from project config, update transform

* Update project and source cli

* Fix help message

* Refactor tests
Branch: main
zhiltsov-max authored 6 years ago, committed by Nikita Manovich
parent e0bcc4652b
commit e5b4c194bb
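
In short: XPath filtering moves out of the project config and becomes a dataset operation, with a new annotation-level mode. A minimal sketch of the resulting API, assuming an existing project on disk (names are taken from the diff below; the project path and the label name "person" are illustrative assumptions):

```python
# Sketch of the API introduced by this commit; the project path and the
# label name "person" are assumptions for illustration.
from datumaro.components.project import Project

dataset = Project.load('.').make_dataset()

# Item-level filter: keep only portrait-oriented images.
portraits = dataset.extract('/item[image/width < image/height]')

# Annotation-level filter: keep only "person" annotations and drop
# items that end up with no annotations at all.
persons = dataset.extract('/item/annotation[label = "person"]',
    filter_annotations=True, remove_empty=True)
```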

@@ -11,6 +11,7 @@ import shutil
 from datumaro.components.project import Project
 from datumaro.components.comparator import Comparator
+from datumaro.components.dataset_filter import DatasetItemEncoder
 from .diff import DiffVisualizer
 from ..util.project import make_project_path, load_project
@ -131,7 +132,12 @@ def build_export_parser(parser):
"'/item[image/width < image/height]'; " "'/item[image/width < image/height]'; "
"extract images with large-area bboxes: " "extract images with large-area bboxes: "
"'/item[annotation/type=\"bbox\" and annotation/area>2000]'" "'/item[annotation/type=\"bbox\" and annotation/area>2000]'"
"filter out irrelevant annotations from items: "
"'/item/annotation[label = \"person\"]'"
) )
parser.add_argument('-a', '--filter-annotations', action='store_true',
help="Filter annotations instead of dataset "
"items (default: %(default)s)")
parser.add_argument('-d', '--dest', dest='dst_dir', required=True, parser.add_argument('-d', '--dest', dest='dst_dir', required=True,
help="Directory to save output") help="Directory to save output")
parser.add_argument('-f', '--output-format', required=True, parser.add_argument('-f', '--output-format', required=True,
@@ -158,10 +164,11 @@ def export_command(args):
     dataset = project.make_dataset()

     log.info("Exporting the project...")
-    dataset.export(
+    dataset.export_project(
         save_dir=dst_dir,
         output_format=args.output_format,
         filter_expr=args.filter,
+        filter_annotations=args.filter_annotations,
         cmdline_args=args.extra_args)
     log.info("Project exported to '%s' as '%s'" % \
         (dst_dir, args.output_format))
@@ -177,12 +184,21 @@ def build_docs_parser(parser):

 def build_extract_parser(parser):
     parser.add_argument('-e', '--filter', default=None,
-        help="Filter expression for dataset items. Examples: "
+        help="XML XPath filter expression for dataset items. Examples: "
             "extract images with width < height: "
             "'/item[image/width < image/height]'; "
             "extract images with large-area bboxes: "
-            "'/item[annotation/type=\"bbox\" and annotation/area>2000]'"
+            "'/item[annotation/type=\"bbox\" and annotation/area>2000]' "
+            "filter out irrelevant annotations from items: "
+            "'/item/annotation[label = \"person\"]'"
         )
+    parser.add_argument('-a', '--filter-annotations', action='store_true',
+        help="Filter annotations instead of dataset "
+            "items (default: %(default)s)")
+    parser.add_argument('--remove-empty', action='store_true',
+        help="Remove an item if there are no annotations left after filtration")
+    parser.add_argument('--dry-run', action='store_true',
+        help="Print XML representations to be filtered and exit")
     parser.add_argument('-d', '--dest', dest='dst_dir', required=True,
         help="Output directory")
     parser.add_argument('-p', '--project', dest='project_dir', default='.',
@@ -193,9 +209,27 @@ def extract_command(args):
     project = load_project(args.project_dir)
     dst_dir = osp.abspath(args.dst_dir)
-    os.makedirs(dst_dir, exist_ok=False)
-
-    project.make_dataset().extract(filter_expr=args.filter, save_dir=dst_dir)
+    if not args.dry_run:
+        os.makedirs(dst_dir, exist_ok=False)
+
+    dataset = project.make_dataset()
+
+    kwargs = {}
+    if args.filter_annotations:
+        kwargs['remove_empty'] = args.remove_empty
+
+    if args.dry_run:
+        dataset = dataset.extract(filter_expr=args.filter,
+            filter_annotations=args.filter_annotations, **kwargs)
+        for item in dataset:
+            encoded_item = DatasetItemEncoder.encode(item, dataset.categories())
+            xml_item = DatasetItemEncoder.to_string(encoded_item)
+            print(xml_item)
+        return 0
+
+    dataset.extract_project(save_dir=dst_dir, filter_expr=args.filter,
+        filter_annotations=args.filter_annotations, **kwargs)

     log.info("Subproject extracted to '%s'" % (dst_dir))
     return 0
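
For reference, a rough plain-code equivalent of the new `--dry-run` branch above, which only serializes the filtered items instead of saving them (assumption: `project` is an already loaded Project instance, the XPath is an example):

```python
# Rough equivalent of the --dry-run branch above; `project` is assumed
# to be a loaded Project instance, the XPath string is an example.
from datumaro.components.dataset_filter import DatasetItemEncoder

dataset = project.make_dataset().extract(
    filter_expr='/item/annotation[label = "person"]',
    filter_annotations=True, remove_empty=True)
for item in dataset:
    encoded_item = DatasetItemEncoder.encode(item, dataset.categories())
    print(DatasetItemEncoder.to_string(encoded_item))
```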
@@ -279,7 +313,7 @@ def transform_command(args):
     dst_dir = osp.abspath(args.dst_dir)
     os.makedirs(dst_dir, exist_ok=False)

-    project.make_dataset().transform(
+    project.make_dataset().apply_model(
         save_dir=dst_dir,
         model_name=args.model_name)

@@ -188,6 +188,9 @@ def build_export_parser(parser):
             "extract images with large-area bboxes: "
             "'/item[annotation/type=\"bbox\" and annotation/area>2000]'"
         )
+    parser.add_argument('-a', '--filter-annotations', action='store_true',
+        help="Filter annotations instead of dataset "
+            "items (default: %(default)s)")
     parser.add_argument('-d', '--dest', dest='dst_dir', required=True,
         help="Directory to save output")
     parser.add_argument('-f', '--output-format', required=True,
@@ -215,10 +218,11 @@ def export_command(args):
     dataset = source_project.make_dataset()

     log.info("Exporting the project...")
-    dataset.export(
+    dataset.export_project(
         save_dir=dst_dir,
         output_format=args.output_format,
         filter_expr=args.filter,
+        filter_annotations=args.filter_annotations,
         cmdline_args=args.extra_args)
     log.info("Source '%s' exported to '%s' as '%s'" % \
         (args.name, dst_dir, args.output_format))

@@ -60,7 +60,6 @@ PROJECT_SCHEMA = _SchemaBuilder() \
     .add('subsets', list) \
     .add('sources', lambda: _DefaultConfig(
         lambda v=None: Source(v))) \
-    .add('filter', str) \
     \
     .add('project_filename', str, internal=True) \
     .add('project_dir', str, internal=True) \

@@ -462,7 +462,7 @@ class _Converter:
         void_labels = [src_label for src_id, src_label in source_labels.items()
             if src_label not in target_labels]
         if void_labels:
-            log.warn("The following labels are remapped to background: %s" %
+            log.warning("The following labels are remapped to background: %s" %
                 ', '.join(void_labels))

     def map_id(src_id):

@@ -4,38 +4,27 @@
 # SPDX-License-Identifier: MIT

 from lxml import etree as ET # NOTE: lxml has proper XPath implementation

-from datumaro.components.extractor import (DatasetItem, Annotation,
+from datumaro.components.extractor import (DatasetItem, Extractor,
+    Annotation, AnnotationType,
     LabelObject, MaskObject, PointsObject, PolygonObject,
     PolyLineObject, BboxObject, CaptionObject,
 )

-def _cast(value, type_conv, default=None):
-    if value is None:
-        return default
-    try:
-        return type_conv(value)
-    except Exception:
-        return default
-
 class DatasetItemEncoder:
-    def encode_item(self, item):
+    @classmethod
+    def encode(cls, item, categories=None):
         item_elem = ET.Element('item')
         ET.SubElement(item_elem, 'id').text = str(item.id)
         ET.SubElement(item_elem, 'subset').text = str(item.subset)
-        ET.SubElement(item_elem, 'path').text = str('/'.join(item.path))
-
-        # Dataset wrapper-specific
-        ET.SubElement(item_elem, 'source').text = \
-            str(getattr(item, 'source', None))
-        ET.SubElement(item_elem, 'extractor').text = \
-            str(getattr(item, 'extractor', None))

         image = item.image
         if image is not None:
-            item_elem.append(self.encode_image(image))
+            item_elem.append(cls.encode_image(image))

         for ann in item.annotations:
-            item_elem.append(self.encode_object(ann))
+            item_elem.append(cls.encode_annotation(ann, categories))

         return item_elem
@@ -52,7 +41,7 @@ class DatasetItemEncoder:
         return image_elem

     @classmethod
-    def encode_annotation(cls, annotation):
+    def encode_annotation_base(cls, annotation):
         assert isinstance(annotation, Annotation)
         ann_elem = ET.Element('annotation')
         ET.SubElement(ann_elem, 'id').text = str(annotation.id)
@@ -65,18 +54,31 @@ class DatasetItemEncoder:
         return ann_elem

+    @staticmethod
+    def _get_label(label_id, categories):
+        label = ''
+        if categories is not None:
+            label_cat = categories.get(AnnotationType.label)
+            if label_cat is not None:
+                label = label_cat.items[label_id].name
+        return label
+
     @classmethod
-    def encode_label_object(cls, obj):
-        ann_elem = cls.encode_annotation(obj)
+    def encode_label_object(cls, obj, categories):
+        ann_elem = cls.encode_annotation_base(obj)

+        ET.SubElement(ann_elem, 'label').text = \
+            str(cls._get_label(obj.label, categories))
         ET.SubElement(ann_elem, 'label_id').text = str(obj.label)

         return ann_elem

     @classmethod
-    def encode_mask_object(cls, obj):
-        ann_elem = cls.encode_annotation(obj)
+    def encode_mask_object(cls, obj, categories):
+        ann_elem = cls.encode_annotation_base(obj)

+        ET.SubElement(ann_elem, 'label').text = \
+            str(cls._get_label(obj.label, categories))
         ET.SubElement(ann_elem, 'label_id').text = str(obj.label)

         mask = obj.image
@@ -86,9 +88,11 @@ class DatasetItemEncoder:
         return ann_elem

     @classmethod
-    def encode_bbox_object(cls, obj):
-        ann_elem = cls.encode_annotation(obj)
+    def encode_bbox_object(cls, obj, categories):
+        ann_elem = cls.encode_annotation_base(obj)

+        ET.SubElement(ann_elem, 'label').text = \
+            str(cls._get_label(obj.label, categories))
         ET.SubElement(ann_elem, 'label_id').text = str(obj.label)
         ET.SubElement(ann_elem, 'x').text = str(obj.x)
         ET.SubElement(ann_elem, 'y').text = str(obj.y)
@@ -99,9 +103,11 @@ class DatasetItemEncoder:
         return ann_elem

     @classmethod
-    def encode_points_object(cls, obj):
-        ann_elem = cls.encode_annotation(obj)
+    def encode_points_object(cls, obj, categories):
+        ann_elem = cls.encode_annotation_base(obj)

+        ET.SubElement(ann_elem, 'label').text = \
+            str(cls._get_label(obj.label, categories))
         ET.SubElement(ann_elem, 'label_id').text = str(obj.label)

         x, y, w, h = obj.get_bbox()
@@ -113,20 +119,22 @@ class DatasetItemEncoder:
         ET.SubElement(bbox_elem, 'h').text = str(h)
         ET.SubElement(bbox_elem, 'area').text = str(area)

-        points = ann_elem.points
+        points = obj.points
         for i in range(0, len(points), 2):
             point_elem = ET.SubElement(ann_elem, 'point')
-            ET.SubElement(point_elem, 'x').text = str(points[i * 2])
-            ET.SubElement(point_elem, 'y').text = str(points[i * 2 + 1])
+            ET.SubElement(point_elem, 'x').text = str(points[i])
+            ET.SubElement(point_elem, 'y').text = str(points[i + 1])
             ET.SubElement(point_elem, 'visible').text = \
-                str(ann_elem.visibility[i // 2].name)
+                str(obj.visibility[i // 2].name)

         return ann_elem

     @classmethod
-    def encode_polyline_object(cls, obj):
-        ann_elem = cls.encode_annotation(obj)
+    def encode_polygon_object(cls, obj, categories):
+        ann_elem = cls.encode_annotation_base(obj)

+        ET.SubElement(ann_elem, 'label').text = \
+            str(cls._get_label(obj.label, categories))
         ET.SubElement(ann_elem, 'label_id').text = str(obj.label)

         x, y, w, h = obj.get_bbox()
@@ -138,57 +146,142 @@ class DatasetItemEncoder:
         ET.SubElement(bbox_elem, 'h').text = str(h)
         ET.SubElement(bbox_elem, 'area').text = str(area)

-        points = ann_elem.points
+        points = obj.points
         for i in range(0, len(points), 2):
             point_elem = ET.SubElement(ann_elem, 'point')
-            ET.SubElement(point_elem, 'x').text = str(points[i * 2])
-            ET.SubElement(point_elem, 'y').text = str(points[i * 2 + 1])
+            ET.SubElement(point_elem, 'x').text = str(points[i])
+            ET.SubElement(point_elem, 'y').text = str(points[i + 1])
+
+        return ann_elem
+
+    @classmethod
+    def encode_polyline_object(cls, obj, categories):
+        ann_elem = cls.encode_annotation_base(obj)
+
+        ET.SubElement(ann_elem, 'label').text = \
+            str(cls._get_label(obj.label, categories))
+        ET.SubElement(ann_elem, 'label_id').text = str(obj.label)
+
+        x, y, w, h = obj.get_bbox()
+        area = w * h
+        bbox_elem = ET.SubElement(ann_elem, 'bbox')
+        ET.SubElement(bbox_elem, 'x').text = str(x)
+        ET.SubElement(bbox_elem, 'y').text = str(y)
+        ET.SubElement(bbox_elem, 'w').text = str(w)
+        ET.SubElement(bbox_elem, 'h').text = str(h)
+        ET.SubElement(bbox_elem, 'area').text = str(area)
+
+        points = obj.points
+        for i in range(0, len(points), 2):
+            point_elem = ET.SubElement(ann_elem, 'point')
+            ET.SubElement(point_elem, 'x').text = str(points[i])
+            ET.SubElement(point_elem, 'y').text = str(points[i + 1])

         return ann_elem

     @classmethod
     def encode_caption_object(cls, obj):
-        ann_elem = cls.encode_annotation(obj)
+        ann_elem = cls.encode_annotation_base(obj)

         ET.SubElement(ann_elem, 'caption').text = str(obj.caption)

         return ann_elem

-    def encode_object(self, o):
+    @classmethod
+    def encode_annotation(cls, o, categories=None):
         if isinstance(o, LabelObject):
-            return self.encode_label_object(o)
+            return cls.encode_label_object(o, categories)
         if isinstance(o, MaskObject):
-            return self.encode_mask_object(o)
+            return cls.encode_mask_object(o, categories)
         if isinstance(o, BboxObject):
-            return self.encode_bbox_object(o)
+            return cls.encode_bbox_object(o, categories)
         if isinstance(o, PointsObject):
-            return self.encode_points_object(o)
+            return cls.encode_points_object(o, categories)
         if isinstance(o, PolyLineObject):
-            return self.encode_polyline_object(o)
+            return cls.encode_polyline_object(o, categories)
         if isinstance(o, PolygonObject):
-            return self.encode_polygon_object(o)
+            return cls.encode_polygon_object(o, categories)
         if isinstance(o, CaptionObject):
-            return self.encode_caption_object(o)
-        if isinstance(o, Annotation): # keep after derived classes
-            return self.encode_annotation(o)
-        if isinstance(o, DatasetItem):
-            return self.encode_item(o)
-        return None
-
-class XPathDatasetFilter:
-    def __init__(self, filter_text=None):
-        self._filter = None
-        if filter_text is not None:
-            self._filter = ET.XPath(filter_text)
-        self._encoder = DatasetItemEncoder()
-
-    def __call__(self, item):
-        encoded_item = self._serialize_item(item)
-        if self._filter is None:
-            return True
-        return bool(self._filter(encoded_item))
-
-    def _serialize_item(self, item):
-        return self._encoder.encode_item(item)
+            return cls.encode_caption_object(o)
+        raise NotImplementedError("Unexpected annotation object passed: %s" % o)
+
+    @staticmethod
+    def to_string(encoded_item):
+        return ET.tostring(encoded_item, encoding='unicode', pretty_print=True)
+
+def XPathDatasetFilter(extractor, xpath=None):
+    if xpath is None:
+        return extractor
+    xpath = ET.XPath(xpath)
+    f = lambda item: bool(xpath(
+        DatasetItemEncoder.encode(item, extractor.categories())))
+    return extractor.select(f)
+
+class XPathAnnotationsFilter(Extractor): # NOTE: essentially, a transform
+    class ItemWrapper(DatasetItem):
+        def __init__(self, item, annotations):
+            self._item = item
+            self._annotations = annotations
+
+        @DatasetItem.id.getter
+        def id(self):
+            return self._item.id
+
+        @DatasetItem.subset.getter
+        def subset(self):
+            return self._item.subset
+
+        @DatasetItem.path.getter
+        def path(self):
+            return self._item.path
+
+        @DatasetItem.annotations.getter
+        def annotations(self):
+            return self._annotations
+
+        @DatasetItem.has_image.getter
+        def has_image(self):
+            return self._item.has_image
+
+        @DatasetItem.image.getter
+        def image(self):
+            return self._item.image
+
+    def __init__(self, extractor, xpath=None, remove_empty=False):
+        super().__init__()
+        self._extractor = extractor
+
+        if xpath is not None:
+            xpath = ET.XPath(xpath)
+        self._filter = xpath
+
+        self._remove_empty = remove_empty
+
+    def __len__(self):
+        return len(self._extractor)
+
+    def __iter__(self):
+        for item in self._extractor:
+            item = self._filter_item(item)
+            if item is not None:
+                yield item
+
+    def subsets(self):
+        return self._extractor.subsets()
+
+    def categories(self):
+        return self._extractor.categories()
+
+    def _filter_item(self, item):
+        if self._filter is None:
+            return item
+        encoded = DatasetItemEncoder.encode(item, self._extractor.categories())
+        filtered = self._filter(encoded)
+        filtered = [elem for elem in filtered if elem.tag == 'annotation']
+        encoded = encoded.findall('annotation')
+        annotations = [item.annotations[encoded.index(e)] for e in filtered]
+
+        if self._remove_empty and len(annotations) == 0:
+            return None
+        return self.ItemWrapper(item, annotations)
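
To make the XPath expressions concrete: both filters match against a per-item XML document produced by `DatasetItemEncoder`. A small sketch of what that document looks like (the printed structure is abridged; the exact fields come from the `encode_*` methods above):

```python
# Sketch: inspect the XML form that XPath filters run against.
from datumaro.components.extractor import DatasetItem, BboxObject
from datumaro.components.dataset_filter import DatasetItemEncoder

item = DatasetItem(id=1, subset='train',
    annotations=[ BboxObject(1, 2, 3, 4, label=0) ])
encoded = DatasetItemEncoder.encode(item)  # no categories: <label> is empty
print(DatasetItemEncoder.to_string(encoded))
# Prints roughly (abridged; see the encode_* methods above for all fields):
# <item>
#   <id>1</id>
#   <subset>train</subset>
#   <annotation>
#     ... <label/> <label_id>0</label_id>
#     <x>1</x> <y>2</y> ...
#   </annotation>
# </item>
```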

@@ -580,9 +580,9 @@ class _DatasetFilter:
         return filter(self.predicate, self.iterable)

 class _ExtractorBase(IExtractor):
-    def __init__(self, length=None):
+    def __init__(self, length=None, subsets=None):
         self._length = length
-        self._subsets = None
+        self._subsets = subsets

     def _init_cache(self):
         subsets = set()
@@ -612,9 +612,12 @@ class _ExtractorBase(IExtractor):
         else:
             raise Exception("Unknown subset '%s' requested" % name)

+    def transform(self, method, *args, **kwargs):
+        return method(self, *args, **kwargs)
+
 class DatasetIteratorWrapper(_ExtractorBase):
-    def __init__(self, iterable, categories):
-        super().__init__(length=None)
+    def __init__(self, iterable, categories, subsets=None):
+        super().__init__(length=None, subsets=subsets)
         self._iterable = iterable
         self._categories = categories
@@ -626,7 +629,7 @@ class DatasetIteratorWrapper(_ExtractorBase):

     def select(self, pred):
         return DatasetIteratorWrapper(
-            _DatasetFilter(self, pred), self.categories())
+            _DatasetFilter(self, pred), self.categories(), self.subsets())

 class Extractor(_ExtractorBase):
     def __init__(self, length=None):
@@ -637,7 +640,7 @@ class Extractor(_ExtractorBase):

     def select(self, pred):
         return DatasetIteratorWrapper(
-            _DatasetFilter(self, pred), self.categories())
+            _DatasetFilter(self, pred), self.categories(), self.subsets())

 DEFAULT_SUBSET_NAME = 'default'
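
The new `transform()` hook above is the generic plumbing both filters use: any callable that accepts an extractor as its first argument can act as a transform. A minimal sketch (`extractor` stands for any Extractor instance; the XPath strings are examples):

```python
# Minimal sketch of Extractor.transform(); `extractor` is assumed to be
# any Extractor instance, the XPath strings are examples.
from datumaro.components.dataset_filter import (XPathDatasetFilter,
    XPathAnnotationsFilter)

kept_items = extractor.transform(XPathDatasetFilter, '/item[id > 1]')
kept_annos = extractor.transform(XPathAnnotationsFilter,
    '/item/annotation[label_id = 0]', remove_empty=True)
```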

@@ -10,7 +10,7 @@ from datumaro.components.extractor import DatasetItem, Extractor

 # pylint: disable=no-self-use
 class Launcher:
-    def __init__(self):
+    def __init__(self, model_dir=None):
         pass

     def launch(self, inputs):

@@ -14,9 +14,10 @@ import sys

 from datumaro.components.config import Config, DEFAULT_FORMAT
 from datumaro.components.config_model import *
-from datumaro.components.extractor import *
-from datumaro.components.launcher import *
-from datumaro.components.dataset_filter import XPathDatasetFilter
+from datumaro.components.extractor import DatasetItem, Extractor
+from datumaro.components.launcher import InferenceWrapper
+from datumaro.components.dataset_filter import \
+    XPathDatasetFilter, XPathAnnotationsFilter

 def import_foreign_module(name, path):
def import_foreign_module(name, path): def import_foreign_module(name, path):
@@ -305,7 +306,131 @@ class DatasetItemWrapper(DatasetItem):
             return self._image
         return self._item.image

-class ProjectDataset(Extractor):
+class Dataset(Extractor):
+    @classmethod
+    def from_extractors(cls, *sources):
+        # merge categories
+        # TODO: implement properly with merging and annotations remapping
+        categories = {}
+        for source in sources:
+            categories.update(source.categories())
+        for source in sources:
+            for cat_type, source_cat in source.categories().items():
+                assert categories[cat_type] == source_cat
+        dataset = Dataset(categories=categories)
+
+        # merge items
+        subsets = defaultdict(lambda: Subset(dataset))
+        for source in sources:
+            for item in source:
+                path = None # NOTE: merge everything into our own dataset
+
+                existing_item = subsets[item.subset].items.get(item.id)
+                if existing_item is not None:
+                    image = None
+                    if existing_item.has_image:
+                        # TODO: think of image comparison
+                        image = cls._lazy_image(existing_item)
+
+                    item = DatasetItemWrapper(item=item, path=path,
+                        image=image, annotations=cls._merge_anno(
+                            existing_item.annotations, item.annotations))
+                else:
+                    item = DatasetItemWrapper(item=item, path=path,
+                        annotations=item.annotations)
+
+                subsets[item.subset].items[item.id] = item
+
+        dataset._subsets = dict(subsets)
+        return dataset
+
+    def __init__(self, categories=None):
+        super().__init__()
+
+        self._subsets = {}
+
+        if not categories:
+            categories = {}
+        self._categories = categories
+
+    def __iter__(self):
+        for subset in self._subsets.values():
+            for item in subset:
+                yield item
+
+    def __len__(self):
+        if self._length is None:
+            self._length = reduce(lambda s, x: s + len(x),
+                self._subsets.values(), 0)
+        return self._length
+
+    def get_subset(self, name):
+        return self._subsets[name]
+
+    def subsets(self):
+        return list(self._subsets)
+
+    def categories(self):
+        return self._categories
+
+    def get(self, item_id, subset=None, path=None):
+        if path:
+            raise KeyError("Requested dataset item path is not found")
+        return self._subsets[subset].items[item_id]
+
+    def put(self, item, item_id=None, subset=None, path=None):
+        if path:
+            raise KeyError("Requested dataset item path is not found")
+
+        if item_id is None:
+            item_id = item.id
+        if subset is None:
+            subset = item.subset
+
+        item = DatasetItemWrapper(item=item, path=None,
+            annotations=item.annotations)
+        if item.subset not in self._subsets:
+            self._subsets[item.subset] = Subset(self)
+        self._subsets[subset].items[item_id] = item
+        self._length = None
+
+        return item
+
+    def extract(self, filter_expr, filter_annotations=False, **kwargs):
+        if filter_annotations:
+            return self.transform(XPathAnnotationsFilter, filter_expr, **kwargs)
+        else:
+            return self.transform(XPathDatasetFilter, filter_expr, **kwargs)
+
+    def update(self, items):
+        for item in items:
+            self.put(item)
+        return self
+
+    def define_categories(self, categories):
+        assert not self._categories
+        self._categories = categories
+
+    @staticmethod
+    def _lazy_image(item):
+        # NOTE: avoid https://docs.python.org/3/faq/programming.html#why-do-lambdas-defined-in-a-loop-with-different-values-all-return-the-same-result
+        return lambda: item.image
+
+    @staticmethod
+    def _merge_anno(a, b):
+        from itertools import chain
+        merged = []
+        for item in chain(a, b):
+            found = False
+            for elem in merged:
+                if elem == item:
+                    found = True
+                    break
+            if not found:
+                merged.append(item)
+        return merged
+
+class ProjectDataset(Dataset):
     def __init__(self, project):
         super().__init__()
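
The `Dataset` class above is the new standalone in-memory dataset entity; `ProjectDataset` now only adds project wiring on top of it. A small usage sketch:

```python
# Sketch: the new in-memory Dataset entity, used without any project.
from datumaro.components.extractor import DatasetItem, LabelObject
from datumaro.components.project import Dataset

dataset = Dataset()
dataset.put(DatasetItem(id=0, subset='train',
    annotations=[ LabelObject(0), LabelObject(1) ]))

# Annotation-level filtering goes through the transform() hook shown earlier.
filtered = dataset.extract('/item/annotation[label_id = 0]',
    filter_annotations=True)
```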
@@ -313,11 +438,6 @@ class ProjectDataset(Dataset):
         config = self.config
         env = self.env

-        dataset_filter = None
-        if config.filter:
-            dataset_filter = XPathDatasetFilter(config.filter)
-        self._filter = dataset_filter
-
         sources = {}
         for s_name, source in config.sources.items():
             s_format = source.format
@@ -335,7 +455,7 @@ class ProjectDataset(Dataset):
         own_source = None
         own_source_dir = osp.join(config.project_dir, config.dataset_dir)
-        if osp.isdir(config.project_dir) and osp.isdir(own_source_dir):
+        if config.project_dir and osp.isdir(own_source_dir):
             log.disable(log.INFO)
             own_source = env.make_importer(DEFAULT_FORMAT)(own_source_dir) \
                 .make_dataset()
@@ -358,9 +478,6 @@ class ProjectDataset(Dataset):
         for source_name, source in self._sources.items():
             log.debug("Loading '%s' source contents..." % source_name)
             for item in source:
-                if dataset_filter and not dataset_filter(item):
-                    continue
-
                 existing_item = subsets[item.subset].items.get(item.id)
                 if existing_item is not None:
                     image = None
@@ -370,14 +487,14 @@ class ProjectDataset(Dataset):
                     path = existing_item.path
                     if item.path != path:
-                        path = None
+                        path = None # NOTE: move to our own dataset
                     item = DatasetItemWrapper(item=item, path=path,
                         image=image, annotations=self._merge_anno(
                             existing_item.annotations, item.annotations))
                 else:
                     s_config = config.sources[source_name]
                     if s_config and \
-                        s_config.format != self.env.PROJECT_EXTRACTOR_NAME:
+                        s_config.format != env.PROJECT_EXTRACTOR_NAME:
                         # NOTE: consider imported sources as our own dataset
                         path = None
                     else:
@@ -394,9 +511,6 @@ class ProjectDataset(Dataset):
         if own_source is not None:
             log.debug("Loading own dataset...")
             for item in own_source:
-                if dataset_filter and not dataset_filter(item):
-                    continue
-
                 if not item.has_image:
                     existing_item = subsets[item.subset].items.get(item.id)
                     if existing_item is not None:
@@ -417,55 +531,9 @@ class ProjectDataset(Dataset):

         self._length = None

-    @staticmethod
-    def _lazy_image(item):
-        # NOTE: avoid https://docs.python.org/3/faq/programming.html#why-do-lambdas-defined-in-a-loop-with-different-values-all-return-the-same-result
-        return lambda: item.image
-
-    @staticmethod
-    def _merge_anno(a, b):
-        from itertools import chain
-        merged = []
-        for item in chain(a, b):
-            found = False
-            for elem in merged:
-                if elem == item:
-                    found = True
-                    break
-            if not found:
-                merged.append(item)
-        return merged
-
     def iterate_own(self):
         return self.select(lambda item: not item.path)

-    def __iter__(self):
-        for subset in self._subsets.values():
-            for item in subset:
-                if self._filter and not self._filter(item):
-                    continue
-                yield item
-
-    def __len__(self):
-        if self._length is None:
-            self._length = reduce(lambda s, x: s + len(x),
-                self._subsets.values(), 0)
-        return self._length
-
-    def get_subset(self, name):
-        return self._subsets[name]
-
-    def subsets(self):
-        return list(self._subsets)
-
-    def categories(self):
-        return self._categories
-
-    def define_categories(self, categories):
-        assert not self._categories
-        self._categories = categories
-
     def get(self, item_id, subset=None, path=None):
         if path:
             source = path[0]
@@ -498,54 +566,6 @@ class ProjectDataset(Dataset):

         return item

-    def build(self, tasks=None):
-        pass
-
-    def docs(self):
-        pass
-
-    def transform(self, model_name, save_dir=None):
-        project = Project(self.config)
-        project.config.remove('sources')
-
-        if save_dir is None:
-            save_dir = self.config.project_dir
-        project.config.project_dir = save_dir
-
-        dataset = project.make_dataset()
-        launcher = self._project.make_executable_model(model_name)
-        inference = InferenceWrapper(self, launcher)
-        dataset.update(inference)
-
-        dataset.save(merge=True)
-
-    def export(self, save_dir, output_format,
-            filter_expr=None, **converter_kwargs):
-        save_dir = osp.abspath(save_dir)
-        os.makedirs(save_dir, exist_ok=True)
-
-        dataset = self
-        if filter_expr:
-            dataset_filter = XPathDatasetFilter(filter_expr)
-            dataset = dataset.select(dataset_filter)
-
-        converter = self.env.make_converter(output_format, **converter_kwargs)
-        converter(dataset, save_dir)
-
-    def extract(self, save_dir, filter_expr=None):
-        project = Project(self.config)
-        if filter_expr:
-            XPathDatasetFilter(filter_expr)
-            project.set_filter(filter_expr)
-        project.save(save_dir)
-
-    def update(self, items):
-        for item in items:
-            if self._filter and not self._filter(item):
-                continue
-            self.put(item)
-        return self
-
     def save(self, save_dir=None, merge=False, recursive=True,
             save_images=False):
         if save_dir is None:
@@ -600,6 +620,60 @@ class ProjectDataset(Dataset):
     def sources(self):
         return self._sources

+    def _save_branch_project(self, extractor, save_dir=None):
+        # NOTE: probably this function should be in the ViewModel layer
+        if save_dir:
+            save_dir = osp.abspath(save_dir)
+            dst_project = Project()
+        else:
+            if not self.config.project_dir:
+                raise Exception("Either a save directory or a project "
+                    "directory should be specified")
+            save_dir = self.config.project_dir
+
+            dst_project = Project(Config(self.config))
+            dst_project.config.remove('project_dir')
+            dst_project.config.remove('sources')
+        dst_dataset = dst_project.make_dataset()
+
+        dst_dataset.define_categories(extractor.categories())
+        dst_dataset.update(extractor)
+
+        dst_dataset.save(save_dir=save_dir, merge=True)
+
+    def transform_project(self, method, *args, save_dir=None, **kwargs):
+        # NOTE: probably this function should be in the ViewModel layer
+        transformed = self.transform(method, *args, **kwargs)
+        self._save_branch_project(transformed, save_dir=save_dir)
+
+    def apply_model(self, model_name, save_dir=None):
+        # NOTE: probably this function should be in the ViewModel layer
+        launcher = self._project.make_executable_model(model_name)
+
+        self.transform_project(InferenceWrapper, launcher, save_dir=save_dir)
+
+    def export_project(self, save_dir, output_format,
+            filter_expr=None, filter_annotations=False, **converter_kwargs):
+        # NOTE: probably this function should be in the ViewModel layer
+        save_dir = osp.abspath(save_dir)
+        os.makedirs(save_dir, exist_ok=True)
+
+        dataset = self
+        if filter_expr:
+            dataset = dataset.extract(filter_expr, filter_annotations)
+
+        converter = self.env.make_converter(output_format, **converter_kwargs)
+        converter(dataset, save_dir)
+
+    def extract_project(self, filter_expr, filter_annotations=False,
+            save_dir=None, remove_empty=False):
+        # NOTE: probably this function should be in the ViewModel layer
+        filtered = self
+        if filter_expr:
+            kwargs = {}
+            if filter_annotations:
+                kwargs['remove_empty'] = remove_empty
+            filtered = self.extract(filter_expr,
+                filter_annotations=filter_annotations, **kwargs)
+        self._save_branch_project(filtered, save_dir=save_dir)
+
 class Project:
     @staticmethod
     def load(path):
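
Taken together, the renamed project-level operations above wrap filtering, model inference, and saving. A brief sketch of their use (the output format name 'datumaro' and the paths are assumptions; `project` is an already loaded project):

```python
# Sketch of the renamed project-level operations; 'datumaro' stands for
# any registered converter name, paths are examples.
dataset = project.make_dataset()

# export_project(): optionally filter, then convert to an output format.
dataset.export_project(save_dir='export/', output_format='datumaro',
    filter_expr='/item/annotation[label = "person"]',
    filter_annotations=True)

# extract_project(): filter, then save the result as a new project.
dataset.extract_project(filter_expr='/item[image/width < image/height]',
    save_dir='subproject/')
```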
@@ -697,24 +771,10 @@ class Project:
         config = Config(self.config)
         config.remove('sources')
         config.remove('subsets')
-        config.remove('filter')
         project = Project(config)
         project.add_source(name, source)
         return project

-    def get_filter(self):
-        if 'filter' in self.config:
-            return self.config.filter
-        return ''
-
-    def set_filter(self, value=None):
-        if not value:
-            self.config.remove('filter')
-        else:
-            # check filter
-            XPathDatasetFilter(value)
-            self.config.filter = value
-
     def local_model_dir(self, model_name):
         return osp.join(
             self.config.env_dir, self.env.config.models_dir, model_name)
@@ -726,4 +786,4 @@ class Project:
 def load_project_as_dataset(url):
     # implement the function declared above
     return Project.load(url).make_dataset()
 # pylint: enable=function-redefined

@@ -1,3 +1,4 @@
+import numpy as np
 import os
 import os.path as osp
@@ -7,9 +8,13 @@ from datumaro.components.project import Project, Environment
 from datumaro.components.project import Source, Model
 from datumaro.components.launcher import Launcher, InferenceWrapper
 from datumaro.components.converter import Converter
-from datumaro.components.extractor import Extractor, DatasetItem, LabelObject
+from datumaro.components.extractor import (Extractor, DatasetItem,
+    LabelObject, MaskObject, PointsObject, PolygonObject,
+    PolyLineObject, BboxObject, CaptionObject,
+)
 from datumaro.components.config import Config, DefaultConfig, SchemaBuilder
-from datumaro.components.dataset_filter import XPathDatasetFilter
+from datumaro.components.dataset_filter import \
+    XPathDatasetFilter, XPathAnnotationsFilter, DatasetItemEncoder
 from datumaro.util.test_utils import TestDir
@@ -129,18 +134,11 @@ class ProjectTest(TestCase):
     def test_can_batch_launch_custom_model(self):
         class TestExtractor(Extractor):
-            def __init__(self, url, n=0):
-                super().__init__(length=n)
-                self.n = n
-
             def __iter__(self):
-                for i in range(self.n):
+                for i in range(5):
                     yield DatasetItem(id=i, subset='train', image=i)

         class TestLauncher(Launcher):
-            def __init__(self, **kwargs):
-                pass
-
             def launch(self, inputs):
                 for i, inp in enumerate(inputs):
                     yield [ LabelObject(attributes={'idx': i, 'data': inp}) ]
@@ -152,7 +150,7 @@ class ProjectTest(TestCase):
         project.env.launchers.register(launcher_name, TestLauncher)
         project.add_model(model_name, { 'launcher': launcher_name })
         model = project.make_executable_model(model_name)
-        extractor = TestExtractor('', n=5)
+        extractor = TestExtractor()

         batch_size = 3
         executor = InferenceWrapper(extractor, model, batch_size=batch_size)
@@ -166,19 +164,12 @@ class ProjectTest(TestCase):
     def test_can_do_transform_with_custom_model(self):
         class TestExtractorSrc(Extractor):
-            def __init__(self, url, n=2):
-                super().__init__(length=n)
-                self.n = n
-
             def __iter__(self):
-                for i in range(self.n):
+                for i in range(2):
                     yield DatasetItem(id=i, subset='train', image=i,
                         annotations=[ LabelObject(i) ])

         class TestLauncher(Launcher):
-            def __init__(self, **kwargs):
-                pass
-
             def launch(self, inputs):
                 for inp in inputs:
                     yield [ LabelObject(inp) ]
@@ -186,7 +177,7 @@ class ProjectTest(TestCase):
         class TestConverter(Converter):
             def __call__(self, extractor, save_dir):
                 for item in extractor:
-                    with open(osp.join(save_dir, '%s.txt' % item.id), 'w+') as f:
+                    with open(osp.join(save_dir, '%s.txt' % item.id), 'w') as f:
                         f.write(str(item.subset) + '\n')
                         f.write(str(item.annotations[0].label) + '\n')
@@ -199,8 +190,8 @@ class ProjectTest(TestCase):
                 for path in self.items:
                     with open(path, 'r') as f:
                         index = osp.splitext(osp.basename(path))[0]
-                        subset = f.readline()[:-1]
-                        label = int(f.readline()[:-1])
+                        subset = f.readline().strip()
+                        label = int(f.readline().strip())
                         assert subset == 'train'
                         yield DatasetItem(id=index, subset=subset,
                             annotations=[ LabelObject(label) ])
@@ -217,7 +208,8 @@ class ProjectTest(TestCase):
         project.add_source('source', { 'format': extractor_name })

         with TestDir() as test_dir:
-            project.make_dataset().transform(model_name, test_dir.path)
+            project.make_dataset().apply_model(model_name=model_name,
+                save_dir=test_dir.path)

             result = Project.load(test_dir.path)
             result.env.extractors.register(extractor_name, TestExtractorDst)
@@ -255,21 +247,16 @@ class ProjectTest(TestCase):
     def test_project_filter_can_be_applied(self):
         class TestExtractor(Extractor):
-            def __init__(self, url, n=10):
-                super().__init__(length=n)
-                self.n = n
-
             def __iter__(self):
-                for i in range(self.n):
+                for i in range(10):
                     yield DatasetItem(id=i, subset='train')

         e_type = 'type'
         project = Project()
         project.env.extractors.register(e_type, TestExtractor)
         project.add_source('source', { 'format': e_type })
-        project.set_filter('/item[id < 5]')

-        dataset = project.make_dataset()
+        dataset = project.make_dataset().extract('/item[id < 5]')

         self.assertEqual(5, len(dataset))
@@ -326,30 +313,23 @@ class ProjectTest(TestCase):
         self.assertEqual(1, len(dataset.sources['child2']))

     def test_project_can_merge_item_annotations(self):
-        class TestExtractor(Extractor):
-            def __init__(self, url, v=None):
-                super().__init__()
-                self.v = v
-
+        class TestExtractor1(Extractor):
             def __iter__(self):
-                v1_item = DatasetItem(id=1, subset='train', annotations=[
+                yield DatasetItem(id=1, subset='train', annotations=[
                     LabelObject(2, id=3),
                     LabelObject(3, attributes={ 'x': 1 }),
                 ])

-                v2_item = DatasetItem(id=1, subset='train', annotations=[
+        class TestExtractor2(Extractor):
+            def __iter__(self):
+                yield DatasetItem(id=1, subset='train', annotations=[
                     LabelObject(3, attributes={ 'x': 1 }),
                     LabelObject(4, id=4),
                 ])

-                if self.v == 1:
-                    yield v1_item
-                else:
-                    yield v2_item
-
         project = Project()
-        project.env.extractors.register('t1', lambda p: TestExtractor(p, v=1))
-        project.env.extractors.register('t2', lambda p: TestExtractor(p, v=2))
+        project.env.extractors.register('t1', TestExtractor1)
+        project.env.extractors.register('t2', TestExtractor2)
         project.add_source('source1', { 'format': 't1' })
         project.add_source('source2', { 'format': 't2' })
@@ -361,23 +341,103 @@ class ProjectTest(TestCase):
         self.assertEqual(3, len(item.annotations))

 class DatasetFilterTest(TestCase):
-    class TestExtractor(Extractor):
-        def __init__(self, url, n=0):
-            super().__init__(length=n)
-            self.n = n
-
-        def __iter__(self):
-            for i in range(self.n):
-                yield DatasetItem(id=i, subset='train')
-
-    def test_xpathfilter_can_be_applied(self):
-        extractor = self.TestExtractor('', n=4)
-        dataset_filter = XPathDatasetFilter('/item[id > 1]')
-
-        filtered = extractor.select(dataset_filter)
-
-        self.assertEqual(2, len(filtered))
+    @staticmethod
+    def test_item_representations():
+        item = DatasetItem(id=1, subset='subset', path=['a', 'b'],
+            image=np.ones((5, 4, 3)),
+            annotations=[
+                LabelObject(0, attributes={'a1': 1, 'a2': '2'}, id=1, group=2),
+                CaptionObject('hello', id=1),
+                CaptionObject('world', group=5),
+                LabelObject(2, id=3, attributes={ 'x': 1, 'y': '2' }),
+                BboxObject(1, 2, 3, 4, label=4, id=4, attributes={ 'a': 1.0 }),
+                BboxObject(5, 6, 7, 8, id=5, group=5),
+                PointsObject([1, 2, 2, 0, 1, 1], label=0, id=5),
+                MaskObject(label=3, id=5, image=np.ones((2, 3))),
+                PolyLineObject([1, 2, 3, 4, 5, 6, 7, 8], id=11),
+                PolygonObject([1, 2, 3, 4, 5, 6, 7, 8]),
+            ]
+        )
+
+        encoded = DatasetItemEncoder.encode(item)
+        DatasetItemEncoder.to_string(encoded)
+
+    def test_item_filter_can_be_applied(self):
+        class TestExtractor(Extractor):
+            def __iter__(self):
+                for i in range(4):
+                    yield DatasetItem(id=i, subset='train')
+
+        extractor = TestExtractor()
+
+        filtered = XPathDatasetFilter(extractor, '/item[id > 1]')
+
+        self.assertEqual(2, len(filtered))
+
+    def test_annotations_filter_can_be_applied(self):
+        class SrcTestExtractor(Extractor):
+            def __iter__(self):
+                return iter([
+                    DatasetItem(id=0),
+                    DatasetItem(id=1, annotations=[
+                        LabelObject(0),
+                        LabelObject(1),
+                    ]),
+                    DatasetItem(id=2, annotations=[
+                        LabelObject(0),
+                        LabelObject(2),
+                    ]),
+                ])
+
+        class DstTestExtractor(Extractor):
+            def __iter__(self):
+                return iter([
+                    DatasetItem(id=0),
+                    DatasetItem(id=1, annotations=[
+                        LabelObject(0),
+                    ]),
+                    DatasetItem(id=2, annotations=[
+                        LabelObject(0),
+                    ]),
+                ])
+
+        extractor = SrcTestExtractor()
+
+        filtered = XPathAnnotationsFilter(extractor,
+            '/item/annotation[label_id = 0]')
+
+        self.assertListEqual(list(filtered), list(DstTestExtractor()))
+
+    def test_annotations_filter_can_remove_empty_items(self):
+        class SrcTestExtractor(Extractor):
+            def __iter__(self):
+                return iter([
+                    DatasetItem(id=0),
+                    DatasetItem(id=1, annotations=[
+                        LabelObject(0),
+                        LabelObject(1),
+                    ]),
+                    DatasetItem(id=2, annotations=[
+                        LabelObject(0),
+                        LabelObject(2),
+                    ]),
+                ])
+
+        class DstTestExtractor(Extractor):
+            def __iter__(self):
+                return iter([
+                    DatasetItem(id=2, annotations=[
+                        LabelObject(2),
+                    ]),
+                ])
+
+        extractor = SrcTestExtractor()
+
+        filtered = XPathAnnotationsFilter(extractor,
+            '/item/annotation[label_id = 2]', remove_empty=True)
+
+        self.assertListEqual(list(filtered), list(DstTestExtractor()))

 class ConfigTest(TestCase):
     def test_can_produce_multilayer_config_from_dict(self):
         schema_low = SchemaBuilder() \
@@ -409,9 +469,6 @@ class ConfigTest(TestCase):

 class ExtractorTest(TestCase):
     def test_custom_extractor_can_be_created(self):
         class CustomExtractor(Extractor):
-            def __init__(self, url):
-                super().__init__()
-
             def __iter__(self):
                 return iter([
                     DatasetItem(id=0, subset='train'),
