[Datumaro] Dataset annotations filter (#1053)

* Fix deprecation message

* Update launcher interface

* Add dataset entity, anno filter, remove filter from project config, update transform

* Update project and source cli

* Fix help message

* Refactor tests
zhiltsov-max authored 6 years ago, committed by Nikita Manovich
parent e0bcc4652b
commit e5b4c194bb
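
In short, the commit splits filtering into two modes: item-level filters that
keep or drop whole dataset items, and annotation-level filters that prune
annotations inside items. A minimal sketch of the resulting API, using the
names introduced in the diff below ('dataset' stands for any project dataset;
the expressions are the ones quoted in the CLI help):

    # keep only portrait-oriented images
    portrait = dataset.extract('/item[image/width < image/height]')
    # keep only 'person' annotations; also drop items left empty
    people = dataset.extract('/item/annotation[label = "person"]',
        filter_annotations=True, remove_empty=True)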

@ -11,6 +11,7 @@ import shutil
from datumaro.components.project import Project
from datumaro.components.comparator import Comparator
from datumaro.components.dataset_filter import DatasetItemEncoder
from .diff import DiffVisualizer
from ..util.project import make_project_path, load_project
@ -131,7 +132,12 @@ def build_export_parser(parser):
"'/item[image/width < image/height]'; "
"extract images with large-area bboxes: "
"'/item[annotation/type=\"bbox\" and annotation/area>2000]'"
"filter out irrelevant annotations from items: "
"'/item/annotation[label = \"person\"]'"
)
parser.add_argument('-a', '--filter-annotations', action='store_true',
help="Filter annotations instead of dataset "
"items (default: %(default)s)")
parser.add_argument('-d', '--dest', dest='dst_dir', required=True,
help="Directory to save output")
parser.add_argument('-f', '--output-format', required=True,
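
As an aside: these filter expressions are plain lxml XPath evaluated against
an XML encoding of each item (see DatasetItemEncoder further down). A sketch
of how the two help examples above evaluate, with the element structure
abbreviated:

    from lxml import etree as ET

    item = ET.fromstring(
        "<item><id>1</id>"
        "<annotation><type>bbox</type><area>2400.0</area></annotation>"
        "<annotation><label>person</label></annotation></item>")

    # item-level expression: a non-empty result keeps the whole item
    assert ET.XPath('/item[annotation/type="bbox" and annotation/area>2000]')(item)
    # annotation-level expression: yields only the matching <annotation> elements
    assert len(ET.XPath('/item/annotation[label = "person"]')(item)) == 1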
@ -158,10 +164,11 @@ def export_command(args):
dataset = project.make_dataset()
log.info("Exporting the project...")
dataset.export(
dataset.export_project(
save_dir=dst_dir,
output_format=args.output_format,
filter_expr=args.filter,
filter_annotations=args.filter_annotations,
cmdline_args=args.extra_args)
log.info("Project exported to '%s' as '%s'" % \
(dst_dir, args.output_format))
@ -177,12 +184,21 @@ def build_docs_parser(parser):
def build_extract_parser(parser):
parser.add_argument('-e', '--filter', default=None,
help="Filter expression for dataset items. Examples: "
help="XML XPath filter expression for dataset items. Examples: "
"extract images with width < height: "
"'/item[image/width < image/height]'; "
"extract images with large-area bboxes: "
"'/item[annotation/type=\"bbox\" and annotation/area>2000]'"
"'/item[annotation/type=\"bbox\" and annotation/area>2000]' "
"filter out irrelevant annotations from items: "
"'/item/annotation[label = \"person\"]'"
)
parser.add_argument('-a', '--filter-annotations', action='store_true',
help="Filter annotations instead of dataset "
"items (default: %(default)s)")
parser.add_argument('--remove-empty', action='store_true',
help="Remove an item if there are no annotations left after filtration")
parser.add_argument('--dry-run', action='store_true',
help="Print XML representations to be filtered and exit")
parser.add_argument('-d', '--dest', dest='dst_dir', required=True,
help="Output directory")
parser.add_argument('-p', '--project', dest='project_dir', default='.',
@ -193,9 +209,27 @@ def extract_command(args):
project = load_project(args.project_dir)
dst_dir = osp.abspath(args.dst_dir)
os.makedirs(dst_dir, exist_ok=False)
if not args.dry_run:
os.makedirs(dst_dir, exist_ok=False)
dataset = project.make_dataset()
kwargs = {}
if args.filter_annotations:
kwargs['remove_empty'] = args.remove_empty
if args.dry_run:
dataset = dataset.extract(filter_expr=args.filter,
filter_annotations=args.filter_annotations, **kwargs)
for item in dataset:
encoded_item = DatasetItemEncoder.encode(item, dataset.categories())
xml_item = DatasetItemEncoder.to_string(encoded_item)
print(xml_item)
return 0
dataset.extract_project(save_dir=dst_dir, filter_expr=args.filter,
filter_annotations=args.filter_annotations, **kwargs)
project.make_dataset().extract(filter_expr=args.filter, save_dir=dst_dir)
log.info("Subproject extracted to '%s'" % (dst_dir))
return 0
@ -279,7 +313,7 @@ def transform_command(args):
dst_dir = osp.abspath(args.dst_dir)
os.makedirs(dst_dir, exist_ok=False)
project.make_dataset().transform(
project.make_dataset().apply_model(
save_dir=dst_dir,
model_name=args.model_name)

@ -188,6 +188,9 @@ def build_export_parser(parser):
"extract images with large-area bboxes: "
"'/item[annotation/type=\"bbox\" and annotation/area>2000]'"
)
parser.add_argument('-a', '--filter-annotations', action='store_true',
help="Filter annotations instead of dataset "
"items (default: %(default)s)")
parser.add_argument('-d', '--dest', dest='dst_dir', required=True,
help="Directory to save output")
parser.add_argument('-f', '--output-format', required=True,
@ -215,10 +218,11 @@ def export_command(args):
dataset = source_project.make_dataset()
log.info("Exporting the project...")
dataset.export(
dataset.export_project(
save_dir=dst_dir,
output_format=args.output_format,
filter_expr=args.filter,
filter_annotations=args.filter_annotations,
cmdline_args=args.extra_args)
log.info("Source '%s' exported to '%s' as '%s'" % \
(args.name, dst_dir, args.output_format))

@ -60,7 +60,6 @@ PROJECT_SCHEMA = _SchemaBuilder() \
.add('subsets', list) \
.add('sources', lambda: _DefaultConfig(
lambda v=None: Source(v))) \
.add('filter', str) \
\
.add('project_filename', str, internal=True) \
.add('project_dir', str, internal=True) \

@ -462,7 +462,7 @@ class _Converter:
void_labels = [src_label for src_id, src_label in source_labels.items()
if src_label not in target_labels]
if void_labels:
log.warn("The following labels are remapped to background: %s" %
log.warning("The following labels are remapped to background: %s" %
', '.join(void_labels))
def map_id(src_id):

@ -4,38 +4,27 @@
# SPDX-License-Identifier: MIT
from lxml import etree as ET # NOTE: lxml has proper XPath implementation
from datumaro.components.extractor import (DatasetItem, Annotation,
from datumaro.components.extractor import (DatasetItem, Extractor,
Annotation, AnnotationType,
LabelObject, MaskObject, PointsObject, PolygonObject,
PolyLineObject, BboxObject, CaptionObject,
)
def _cast(value, type_conv, default=None):
if value is None:
return default
try:
return type_conv(value)
except Exception:
return default
class DatasetItemEncoder:
def encode_item(self, item):
@classmethod
def encode(cls, item, categories=None):
item_elem = ET.Element('item')
ET.SubElement(item_elem, 'id').text = str(item.id)
ET.SubElement(item_elem, 'subset').text = str(item.subset)
# Dataset wrapper-specific
ET.SubElement(item_elem, 'source').text = \
str(getattr(item, 'source', None))
ET.SubElement(item_elem, 'extractor').text = \
str(getattr(item, 'extractor', None))
ET.SubElement(item_elem, 'path').text = str('/'.join(item.path))
image = item.image
if image is not None:
item_elem.append(self.encode_image(image))
item_elem.append(cls.encode_image(image))
for ann in item.annotations:
item_elem.append(self.encode_object(ann))
item_elem.append(cls.encode_annotation(ann, categories))
return item_elem
@ -52,7 +41,7 @@ class DatasetItemEncoder:
return image_elem
@classmethod
def encode_annotation(cls, annotation):
def encode_annotation_base(cls, annotation):
assert isinstance(annotation, Annotation)
ann_elem = ET.Element('annotation')
ET.SubElement(ann_elem, 'id').text = str(annotation.id)
@ -65,18 +54,31 @@ class DatasetItemEncoder:
return ann_elem
@staticmethod
def _get_label(label_id, categories):
label = ''
if categories is not None:
label_cat = categories.get(AnnotationType.label)
if label_cat is not None:
label = label_cat.items[label_id].name
return label
@classmethod
def encode_label_object(cls, obj):
ann_elem = cls.encode_annotation(obj)
def encode_label_object(cls, obj, categories):
ann_elem = cls.encode_annotation_base(obj)
ET.SubElement(ann_elem, 'label').text = \
str(cls._get_label(obj.label, categories))
ET.SubElement(ann_elem, 'label_id').text = str(obj.label)
return ann_elem
@classmethod
def encode_mask_object(cls, obj):
ann_elem = cls.encode_annotation(obj)
def encode_mask_object(cls, obj, categories):
ann_elem = cls.encode_annotation_base(obj)
ET.SubElement(ann_elem, 'label').text = \
str(cls._get_label(obj.label, categories))
ET.SubElement(ann_elem, 'label_id').text = str(obj.label)
mask = obj.image
@ -86,9 +88,11 @@ class DatasetItemEncoder:
return ann_elem
@classmethod
def encode_bbox_object(cls, obj):
ann_elem = cls.encode_annotation(obj)
def encode_bbox_object(cls, obj, categories):
ann_elem = cls.encode_annotation_base(obj)
ET.SubElement(ann_elem, 'label').text = \
str(cls._get_label(obj.label, categories))
ET.SubElement(ann_elem, 'label_id').text = str(obj.label)
ET.SubElement(ann_elem, 'x').text = str(obj.x)
ET.SubElement(ann_elem, 'y').text = str(obj.y)
@ -99,9 +103,11 @@ class DatasetItemEncoder:
return ann_elem
@classmethod
def encode_points_object(cls, obj):
ann_elem = cls.encode_annotation(obj)
def encode_points_object(cls, obj, categories):
ann_elem = cls.encode_annotation_base(obj)
ET.SubElement(ann_elem, 'label').text = \
str(cls._get_label(obj.label, categories))
ET.SubElement(ann_elem, 'label_id').text = str(obj.label)
x, y, w, h = obj.get_bbox()
@ -113,20 +119,22 @@ class DatasetItemEncoder:
ET.SubElement(bbox_elem, 'h').text = str(h)
ET.SubElement(bbox_elem, 'area').text = str(area)
points = ann_elem.points
points = obj.points
for i in range(0, len(points), 2):
point_elem = ET.SubElement(ann_elem, 'point')
ET.SubElement(point_elem, 'x').text = str(points[i * 2])
ET.SubElement(point_elem, 'y').text = str(points[i * 2 + 1])
ET.SubElement(point_elem, 'x').text = str(points[i])
ET.SubElement(point_elem, 'y').text = str(points[i + 1])
ET.SubElement(point_elem, 'visible').text = \
str(ann_elem.visibility[i // 2].name)
str(obj.visibility[i // 2].name)
return ann_elem
@classmethod
def encode_polyline_object(cls, obj):
ann_elem = cls.encode_annotation(obj)
def encode_polygon_object(cls, obj, categories):
ann_elem = cls.encode_annotation_base(obj)
ET.SubElement(ann_elem, 'label').text = \
str(cls._get_label(obj.label, categories))
ET.SubElement(ann_elem, 'label_id').text = str(obj.label)
x, y, w, h = obj.get_bbox()
@ -138,57 +146,142 @@ class DatasetItemEncoder:
ET.SubElement(bbox_elem, 'h').text = str(h)
ET.SubElement(bbox_elem, 'area').text = str(area)
points = ann_elem.points
points = obj.points
for i in range(0, len(points), 2):
point_elem = ET.SubElement(ann_elem, 'point')
ET.SubElement(point_elem, 'x').text = str(points[i * 2])
ET.SubElement(point_elem, 'y').text = str(points[i * 2 + 1])
ET.SubElement(point_elem, 'x').text = str(points[i])
ET.SubElement(point_elem, 'y').text = str(points[i + 1])
return ann_elem
@classmethod
def encode_polyline_object(cls, obj, categories):
ann_elem = cls.encode_annotation_base(obj)
ET.SubElement(ann_elem, 'label').text = \
str(cls._get_label(obj.label, categories))
ET.SubElement(ann_elem, 'label_id').text = str(obj.label)
x, y, w, h = obj.get_bbox()
area = w * h
bbox_elem = ET.SubElement(ann_elem, 'bbox')
ET.SubElement(bbox_elem, 'x').text = str(x)
ET.SubElement(bbox_elem, 'y').text = str(y)
ET.SubElement(bbox_elem, 'w').text = str(w)
ET.SubElement(bbox_elem, 'h').text = str(h)
ET.SubElement(bbox_elem, 'area').text = str(area)
points = obj.points
for i in range(0, len(points), 2):
point_elem = ET.SubElement(ann_elem, 'point')
ET.SubElement(point_elem, 'x').text = str(points[i])
ET.SubElement(point_elem, 'y').text = str(points[i + 1])
return ann_elem
@classmethod
def encode_caption_object(cls, obj):
ann_elem = cls.encode_annotation(obj)
ann_elem = cls.encode_annotation_base(obj)
ET.SubElement(ann_elem, 'caption').text = str(obj.caption)
return ann_elem
def encode_object(self, o):
@classmethod
def encode_annotation(cls, o, categories=None):
if isinstance(o, LabelObject):
return self.encode_label_object(o)
return cls.encode_label_object(o, categories)
if isinstance(o, MaskObject):
return self.encode_mask_object(o)
return cls.encode_mask_object(o, categories)
if isinstance(o, BboxObject):
return self.encode_bbox_object(o)
return cls.encode_bbox_object(o, categories)
if isinstance(o, PointsObject):
return self.encode_points_object(o)
return cls.encode_points_object(o, categories)
if isinstance(o, PolyLineObject):
return self.encode_polyline_object(o)
return cls.encode_polyline_object(o, categories)
if isinstance(o, PolygonObject):
return self.encode_polygon_object(o)
return cls.encode_polygon_object(o, categories)
if isinstance(o, CaptionObject):
return self.encode_caption_object(o)
if isinstance(o, Annotation): # keep after derived classes
return self.encode_annotation(o)
return cls.encode_caption_object(o)
raise NotImplementedError("Unexpected annotation object passed: %s" % o)
@staticmethod
def to_string(encoded_item):
return ET.tostring(encoded_item, encoding='unicode', pretty_print=True)
def XPathDatasetFilter(extractor, xpath=None):
if xpath is None:
return extractor
xpath = ET.XPath(xpath)
f = lambda item: bool(xpath(
DatasetItemEncoder.encode(item, extractor.categories())))
return extractor.select(f)
class XPathAnnotationsFilter(Extractor): # NOTE: essentially, a transform
class ItemWrapper(DatasetItem):
def __init__(self, item, annotations):
self._item = item
self._annotations = annotations
@DatasetItem.id.getter
def id(self):
return self._item.id
@DatasetItem.subset.getter
def subset(self):
return self._item.subset
if isinstance(o, DatasetItem):
return self.encode_item(o)
@DatasetItem.path.getter
def path(self):
return self._item.path
return None
@DatasetItem.annotations.getter
def annotations(self):
return self._annotations
class XPathDatasetFilter:
def __init__(self, filter_text=None):
self._filter = None
if filter_text is not None:
self._filter = ET.XPath(filter_text)
self._encoder = DatasetItemEncoder()
@DatasetItem.has_image.getter
def has_image(self):
return self._item.has_image
def __call__(self, item):
encoded_item = self._serialize_item(item)
@DatasetItem.image.getter
def image(self):
return self._item.image
def __init__(self, extractor, xpath=None, remove_empty=False):
super().__init__()
self._extractor = extractor
if xpath is not None:
xpath = ET.XPath(xpath)
self._filter = xpath
self._remove_empty = remove_empty
def __len__(self):
return len(self._extractor)
def __iter__(self):
for item in self._extractor:
item = self._filter_item(item)
if item is not None:
yield item
def subsets(self):
return self._extractor.subsets()
def categories(self):
return self._extractor.categories()
def _filter_item(self, item):
if self._filter is None:
return True
return bool(self._filter(encoded_item))
return item
encoded = DatasetItemEncoder.encode(item, self._extractor.categories())
filtered = self._filter(encoded)
filtered = [elem for elem in filtered if elem.tag == 'annotation']
encoded = encoded.findall('annotation')
annotations = [item.annotations[encoded.index(e)] for e in filtered]
def _serialize_item(self, item):
return self._encoder.encode_item(item)
if self._remove_empty and len(annotations) == 0:
return None
return self.ItemWrapper(item, annotations)
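
Both entry points can be used directly; a usage sketch against a toy extractor
(mirroring the tests at the bottom of this diff, and assuming the default empty
categories provided by the Extractor base, as the tests rely on):

    class ToyExtractor(Extractor):
        def __iter__(self):
            yield DatasetItem(id=1, annotations=[LabelObject(0), LabelObject(2)])

    # item-level: drops whole items that do not match
    items = XPathDatasetFilter(ToyExtractor(), '/item[id > 0]')
    # annotation-level: keeps items, drops non-matching annotations;
    # remove_empty=True also drops items left without annotations
    anns = XPathAnnotationsFilter(ToyExtractor(),
        '/item/annotation[label_id = 2]', remove_empty=True)
    assert len(list(anns)[0].annotations) == 1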

@ -580,9 +580,9 @@ class _DatasetFilter:
return filter(self.predicate, self.iterable)
class _ExtractorBase(IExtractor):
def __init__(self, length=None):
def __init__(self, length=None, subsets=None):
self._length = length
self._subsets = None
self._subsets = subsets
def _init_cache(self):
subsets = set()
@ -612,9 +612,12 @@ class _ExtractorBase(IExtractor):
else:
raise Exception("Unknown subset '%s' requested" % name)
def transform(self, method, *args, **kwargs):
return method(self, *args, **kwargs)
class DatasetIteratorWrapper(_ExtractorBase):
def __init__(self, iterable, categories):
super().__init__(length=None)
def __init__(self, iterable, categories, subsets=None):
super().__init__(length=None, subsets=subsets)
self._iterable = iterable
self._categories = categories
@ -626,7 +629,7 @@ class DatasetIteratorWrapper(_ExtractorBase):
def select(self, pred):
return DatasetIteratorWrapper(
_DatasetFilter(self, pred), self.categories())
_DatasetFilter(self, pred), self.categories(), self.subsets())
class Extractor(_ExtractorBase):
def __init__(self, length=None):
@ -637,7 +640,7 @@ class Extractor(_ExtractorBase):
def select(self, pred):
return DatasetIteratorWrapper(
_DatasetFilter(self, pred), self.categories())
_DatasetFilter(self, pred), self.categories(), self.subsets())
DEFAULT_SUBSET_NAME = 'default'
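
With the transform() hook added above, filters plug into extractors uniformly;
a one-line sketch ('extractor' is any Extractor instance):

    filtered = extractor.transform(XPathAnnotationsFilter,
        '/item/annotation[label_id = 0]', remove_empty=True)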

@ -10,7 +10,7 @@ from datumaro.components.extractor import DatasetItem, Extractor
# pylint: disable=no-self-use
class Launcher:
def __init__(self):
def __init__(self, model_dir=None):
pass
def launch(self, inputs):

@ -14,9 +14,10 @@ import sys
from datumaro.components.config import Config, DEFAULT_FORMAT
from datumaro.components.config_model import *
from datumaro.components.extractor import *
from datumaro.components.launcher import *
from datumaro.components.dataset_filter import XPathDatasetFilter
from datumaro.components.extractor import DatasetItem, Extractor
from datumaro.components.launcher import InferenceWrapper
from datumaro.components.dataset_filter import \
XPathDatasetFilter, XPathAnnotationsFilter
def import_foreign_module(name, path):
@ -305,7 +306,131 @@ class DatasetItemWrapper(DatasetItem):
return self._image
return self._item.image
class ProjectDataset(Extractor):
class Dataset(Extractor):
@classmethod
def from_extractors(cls, *sources):
# merge categories
# TODO: implement properly with merging and annotations remapping
categories = {}
for source in sources:
categories.update(source.categories())
for source in sources:
for cat_type, source_cat in source.categories().items():
assert categories[cat_type] == source_cat
dataset = Dataset(categories=categories)
# merge items
subsets = defaultdict(lambda: Subset(dataset))
for source in sources:
for item in source:
path = None # NOTE: merge everything into our own dataset
existing_item = subsets[item.subset].items.get(item.id)
if existing_item is not None:
image = None
if existing_item.has_image:
# TODO: think of image comparison
image = cls._lazy_image(existing_item)
item = DatasetItemWrapper(item=item, path=path,
image=image, annotations=cls._merge_anno(
existing_item.annotations, item.annotations))
else:
item = DatasetItemWrapper(item=item, path=path,
annotations=item.annotations)
subsets[item.subset].items[item.id] = item
dataset._subsets = dict(subsets)
return dataset
def __init__(self, categories=None):
super().__init__()
self._subsets = {}
if not categories:
categories = {}
self._categories = categories
def __iter__(self):
for subset in self._subsets.values():
for item in subset:
yield item
def __len__(self):
if self._length is None:
self._length = reduce(lambda s, x: s + len(x),
self._subsets.values(), 0)
return self._length
def get_subset(self, name):
return self._subsets[name]
def subsets(self):
return list(self._subsets)
def categories(self):
return self._categories
def get(self, item_id, subset=None, path=None):
if path:
raise KeyError("Requested dataset item path is not found")
return self._subsets[subset].items[item_id]
def put(self, item, item_id=None, subset=None, path=None):
if path:
raise KeyError("Requested dataset item path is not found")
if item_id is None:
item_id = item.id
if subset is None:
subset = item.subset
item = DatasetItemWrapper(item=item, path=None,
annotations=item.annotations)
if subset not in self._subsets:
self._subsets[subset] = Subset(self)
self._subsets[subset].items[item_id] = item
self._length = None
return item
def extract(self, filter_expr, filter_annotations=False, **kwargs):
if filter_annotations:
return self.transform(XPathAnnotationsFilter, filter_expr, **kwargs)
else:
return self.transform(XPathDatasetFilter, filter_expr, **kwargs)
def update(self, items):
for item in items:
self.put(item)
return self
def define_categories(self, categories):
assert not self._categories
self._categories = categories
@staticmethod
def _lazy_image(item):
# NOTE: avoid https://docs.python.org/3/faq/programming.html#why-do-lambdas-defined-in-a-loop-with-different-values-all-return-the-same-result
return lambda: item.image
@staticmethod
def _merge_anno(a, b):
from itertools import chain
merged = []
for item in chain(a, b):
found = False
for elem in merged:
if elem == item:
found = True
break
if not found:
merged.append(item)
return merged
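
The new Dataset entity is usable on its own; a sketch of composing two sources
and filtering the result (source_a/source_b are assumed extractors with
identical categories, which from_extractors asserts):

    merged = Dataset.from_extractors(source_a, source_b)
    people = merged.extract('/item/annotation[label = "person"]',
        filter_annotations=True, remove_empty=True)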
class ProjectDataset(Dataset):
def __init__(self, project):
super().__init__()
@ -313,11 +438,6 @@ class ProjectDataset(Extractor):
config = self.config
env = self.env
dataset_filter = None
if config.filter:
dataset_filter = XPathDatasetFilter(config.filter)
self._filter = dataset_filter
sources = {}
for s_name, source in config.sources.items():
s_format = source.format
@ -335,7 +455,7 @@ class ProjectDataset(Extractor):
own_source = None
own_source_dir = osp.join(config.project_dir, config.dataset_dir)
if osp.isdir(config.project_dir) and osp.isdir(own_source_dir):
if config.project_dir and osp.isdir(own_source_dir):
log.disable(log.INFO)
own_source = env.make_importer(DEFAULT_FORMAT)(own_source_dir) \
.make_dataset()
@ -358,9 +478,6 @@ class ProjectDataset(Extractor):
for source_name, source in self._sources.items():
log.debug("Loading '%s' source contents..." % source_name)
for item in source:
if dataset_filter and not dataset_filter(item):
continue
existing_item = subsets[item.subset].items.get(item.id)
if existing_item is not None:
image = None
@ -370,14 +487,14 @@ class ProjectDataset(Extractor):
path = existing_item.path
if item.path != path:
path = None
path = None # NOTE: move to our own dataset
item = DatasetItemWrapper(item=item, path=path,
image=image, annotations=self._merge_anno(
existing_item.annotations, item.annotations))
else:
s_config = config.sources[source_name]
if s_config and \
s_config.format != self.env.PROJECT_EXTRACTOR_NAME:
s_config.format != env.PROJECT_EXTRACTOR_NAME:
# NOTE: consider imported sources as our own dataset
path = None
else:
@ -394,9 +511,6 @@ class ProjectDataset(Extractor):
if own_source is not None:
log.debug("Loading own dataset...")
for item in own_source:
if dataset_filter and not dataset_filter(item):
continue
if not item.has_image:
existing_item = subsets[item.subset].items.get(item.id)
if existing_item is not None:
@ -417,55 +531,9 @@ class ProjectDataset(Extractor):
self._length = None
@staticmethod
def _lazy_image(item):
# NOTE: avoid https://docs.python.org/3/faq/programming.html#why-do-lambdas-defined-in-a-loop-with-different-values-all-return-the-same-result
return lambda: item.image
@staticmethod
def _merge_anno(a, b):
from itertools import chain
merged = []
for item in chain(a, b):
found = False
for elem in merged:
if elem == item:
found = True
break
if not found:
merged.append(item)
return merged
def iterate_own(self):
return self.select(lambda item: not item.path)
def __iter__(self):
for subset in self._subsets.values():
for item in subset:
if self._filter and not self._filter(item):
continue
yield item
def __len__(self):
if self._length is None:
self._length = reduce(lambda s, x: s + len(x),
self._subsets.values(), 0)
return self._length
def get_subset(self, name):
return self._subsets[name]
def subsets(self):
return list(self._subsets)
def categories(self):
return self._categories
def define_categories(self, categories):
assert not self._categories
self._categories = categories
def get(self, item_id, subset=None, path=None):
if path:
source = path[0]
@ -498,54 +566,6 @@ class ProjectDataset(Extractor):
return item
def build(self, tasks=None):
pass
def docs(self):
pass
def transform(self, model_name, save_dir=None):
project = Project(self.config)
project.config.remove('sources')
if save_dir is None:
save_dir = self.config.project_dir
project.config.project_dir = save_dir
dataset = project.make_dataset()
launcher = self._project.make_executable_model(model_name)
inference = InferenceWrapper(self, launcher)
dataset.update(inference)
dataset.save(merge=True)
def export(self, save_dir, output_format,
filter_expr=None, **converter_kwargs):
save_dir = osp.abspath(save_dir)
os.makedirs(save_dir, exist_ok=True)
dataset = self
if filter_expr:
dataset_filter = XPathDatasetFilter(filter_expr)
dataset = dataset.select(dataset_filter)
converter = self.env.make_converter(output_format, **converter_kwargs)
converter(dataset, save_dir)
def extract(self, save_dir, filter_expr=None):
project = Project(self.config)
if filter_expr:
XPathDatasetFilter(filter_expr)
project.set_filter(filter_expr)
project.save(save_dir)
def update(self, items):
for item in items:
if self._filter and not self._filter(item):
continue
self.put(item)
return self
def save(self, save_dir=None, merge=False, recursive=True,
save_images=False):
if save_dir is None:
@ -600,6 +620,60 @@ class ProjectDataset(Extractor):
def sources(self):
return self._sources
def _save_branch_project(self, extractor, save_dir=None):
# NOTE: probably this function should be in the ViewModel layer
if save_dir:
save_dir = osp.abspath(save_dir)
dst_project = Project()
else:
if not self.config.project_dir:
raise Exception("Either a save directory or a project "
"directory should be specified")
save_dir = self.config.project_dir
dst_project = Project(Config(self.config))
dst_project.config.remove('project_dir')
dst_project.config.remove('sources')
dst_dataset = dst_project.make_dataset()
dst_dataset.define_categories(extractor.categories())
dst_dataset.update(extractor)
dst_dataset.save(save_dir=save_dir, merge=True)
def transform_project(self, method, *args, save_dir=None, **kwargs):
# NOTE: probably this function should be in the ViewModel layer
transformed = self.transform(method, *args, **kwargs)
self._save_branch_project(transformed, save_dir=save_dir)
def apply_model(self, model_name, save_dir=None):
# NOTE: probably this function should be in the ViewModel layer
launcher = self._project.make_executable_model(model_name)
self.transform_project(InferenceWrapper, launcher, save_dir=save_dir)
def export_project(self, save_dir, output_format,
filter_expr=None, filter_annotations=False, **converter_kwargs):
# NOTE: probably this function should be in the ViewModel layer
save_dir = osp.abspath(save_dir)
os.makedirs(save_dir, exist_ok=True)
dataset = self
if filter_expr:
dataset = dataset.extract(filter_expr, filter_annotations)
converter = self.env.make_converter(output_format, **converter_kwargs)
converter(dataset, save_dir)
def extract_project(self, filter_expr, filter_annotations=False,
save_dir=None, remove_empty=False):
# NOTE: probably this function should be in the ViewModel layer
filtered = self
if filter_expr:
filtered = self.extract(filter_expr,
filter_annotations=filter_annotations,
remove_empty=remove_empty)
self._save_branch_project(filtered, save_dir=save_dir)
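
End to end, the export and extract commands above reduce to these
project-level calls (a sketch; the paths and the output format name are
illustrative):

    dataset = Project.load('my_project').make_dataset()
    # export_command equivalent
    dataset.export_project(save_dir='exported', output_format='datumaro',
        filter_expr='/item/annotation[label = "person"]',
        filter_annotations=True)
    # extract_command equivalent
    dataset.extract_project('/item/annotation[label = "person"]',
        filter_annotations=True, remove_empty=True, save_dir='extracted')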
class Project:
@staticmethod
def load(path):
@ -697,24 +771,10 @@ class Project:
config = Config(self.config)
config.remove('sources')
config.remove('subsets')
config.remove('filter')
project = Project(config)
project.add_source(name, source)
return project
def get_filter(self):
if 'filter' in self.config:
return self.config.filter
return ''
def set_filter(self, value=None):
if not value:
self.config.remove('filter')
else:
# check filter
XPathDatasetFilter(value)
self.config.filter = value
def local_model_dir(self, model_name):
return osp.join(
self.config.env_dir, self.env.config.models_dir, model_name)
@ -726,4 +786,4 @@ class Project:
def load_project_as_dataset(url):
# implement the function declared above
return Project.load(url).make_dataset()
# pylint: enable=function-redefined
# pylint: enable=function-redefined

@ -1,3 +1,4 @@
import numpy as np
import os
import os.path as osp
@ -7,9 +8,13 @@ from datumaro.components.project import Project, Environment
from datumaro.components.project import Source, Model
from datumaro.components.launcher import Launcher, InferenceWrapper
from datumaro.components.converter import Converter
from datumaro.components.extractor import Extractor, DatasetItem, LabelObject
from datumaro.components.extractor import (Extractor, DatasetItem,
LabelObject, MaskObject, PointsObject, PolygonObject,
PolyLineObject, BboxObject, CaptionObject,
)
from datumaro.components.config import Config, DefaultConfig, SchemaBuilder
from datumaro.components.dataset_filter import XPathDatasetFilter
from datumaro.components.dataset_filter import \
XPathDatasetFilter, XPathAnnotationsFilter, DatasetItemEncoder
from datumaro.util.test_utils import TestDir
@ -129,18 +134,11 @@ class ProjectTest(TestCase):
def test_can_batch_launch_custom_model(self):
class TestExtractor(Extractor):
def __init__(self, url, n=0):
super().__init__(length=n)
self.n = n
def __iter__(self):
for i in range(self.n):
for i in range(5):
yield DatasetItem(id=i, subset='train', image=i)
class TestLauncher(Launcher):
def __init__(self, **kwargs):
pass
def launch(self, inputs):
for i, inp in enumerate(inputs):
yield [ LabelObject(attributes={'idx': i, 'data': inp}) ]
@ -152,7 +150,7 @@ class ProjectTest(TestCase):
project.env.launchers.register(launcher_name, TestLauncher)
project.add_model(model_name, { 'launcher': launcher_name })
model = project.make_executable_model(model_name)
extractor = TestExtractor('', n=5)
extractor = TestExtractor()
batch_size = 3
executor = InferenceWrapper(extractor, model, batch_size=batch_size)
@ -166,19 +164,12 @@ class ProjectTest(TestCase):
def test_can_do_transform_with_custom_model(self):
class TestExtractorSrc(Extractor):
def __init__(self, url, n=2):
super().__init__(length=n)
self.n = n
def __iter__(self):
for i in range(self.n):
for i in range(2):
yield DatasetItem(id=i, subset='train', image=i,
annotations=[ LabelObject(i) ])
class TestLauncher(Launcher):
def __init__(self, **kwargs):
pass
def launch(self, inputs):
for inp in inputs:
yield [ LabelObject(inp) ]
@ -186,7 +177,7 @@ class ProjectTest(TestCase):
class TestConverter(Converter):
def __call__(self, extractor, save_dir):
for item in extractor:
with open(osp.join(save_dir, '%s.txt' % item.id), 'w+') as f:
with open(osp.join(save_dir, '%s.txt' % item.id), 'w') as f:
f.write(str(item.subset) + '\n')
f.write(str(item.annotations[0].label) + '\n')
@ -199,8 +190,8 @@ class ProjectTest(TestCase):
for path in self.items:
with open(path, 'r') as f:
index = osp.splitext(osp.basename(path))[0]
subset = f.readline()[:-1]
label = int(f.readline()[:-1])
subset = f.readline().strip()
label = int(f.readline().strip())
assert subset == 'train'
yield DatasetItem(id=index, subset=subset,
annotations=[ LabelObject(label) ])
@ -217,7 +208,8 @@ class ProjectTest(TestCase):
project.add_source('source', { 'format': extractor_name })
with TestDir() as test_dir:
project.make_dataset().transform(model_name, test_dir.path)
project.make_dataset().apply_model(model_name=model_name,
save_dir=test_dir.path)
result = Project.load(test_dir.path)
result.env.extractors.register(extractor_name, TestExtractorDst)
@ -255,21 +247,16 @@ class ProjectTest(TestCase):
def test_project_filter_can_be_applied(self):
class TestExtractor(Extractor):
def __init__(self, url, n=10):
super().__init__(length=n)
self.n = n
def __iter__(self):
for i in range(self.n):
for i in range(10):
yield DatasetItem(id=i, subset='train')
e_type = 'type'
project = Project()
project.env.extractors.register(e_type, TestExtractor)
project.add_source('source', { 'format': e_type })
project.set_filter('/item[id < 5]')
dataset = project.make_dataset()
dataset = project.make_dataset().extract('/item[id < 5]')
self.assertEqual(5, len(dataset))
@ -326,30 +313,23 @@ class ProjectTest(TestCase):
self.assertEqual(1, len(dataset.sources['child2']))
def test_project_can_merge_item_annotations(self):
class TestExtractor(Extractor):
def __init__(self, url, v=None):
super().__init__()
self.v = v
class TestExtractor1(Extractor):
def __iter__(self):
v1_item = DatasetItem(id=1, subset='train', annotations=[
yield DatasetItem(id=1, subset='train', annotations=[
LabelObject(2, id=3),
LabelObject(3, attributes={ 'x': 1 }),
])
v2_item = DatasetItem(id=1, subset='train', annotations=[
class TestExtractor2(Extractor):
def __iter__(self):
yield DatasetItem(id=1, subset='train', annotations=[
LabelObject(3, attributes={ 'x': 1 }),
LabelObject(4, id=4),
])
if self.v == 1:
yield v1_item
else:
yield v2_item
project = Project()
project.env.extractors.register('t1', lambda p: TestExtractor(p, v=1))
project.env.extractors.register('t2', lambda p: TestExtractor(p, v=2))
project.env.extractors.register('t1', TestExtractor1)
project.env.extractors.register('t2', TestExtractor2)
project.add_source('source1', { 'format': 't1' })
project.add_source('source2', { 'format': 't2' })
@ -361,23 +341,103 @@ class ProjectTest(TestCase):
self.assertEqual(3, len(item.annotations))
class DatasetFilterTest(TestCase):
class TestExtractor(Extractor):
def __init__(self, url, n=0):
super().__init__(length=n)
self.n = n
def __iter__(self):
for i in range(self.n):
yield DatasetItem(id=i, subset='train')
@staticmethod
def test_item_representations():
item = DatasetItem(id=1, subset='subset', path=['a', 'b'],
image=np.ones((5, 4, 3)),
annotations=[
LabelObject(0, attributes={'a1': 1, 'a2': '2'}, id=1, group=2),
CaptionObject('hello', id=1),
CaptionObject('world', group=5),
LabelObject(2, id=3, attributes={ 'x': 1, 'y': '2' }),
BboxObject(1, 2, 3, 4, label=4, id=4, attributes={ 'a': 1.0 }),
BboxObject(5, 6, 7, 8, id=5, group=5),
PointsObject([1, 2, 2, 0, 1, 1], label=0, id=5),
MaskObject(label=3, id=5, image=np.ones((2, 3))),
PolyLineObject([1, 2, 3, 4, 5, 6, 7, 8], id=11),
PolygonObject([1, 2, 3, 4, 5, 6, 7, 8]),
]
)
encoded = DatasetItemEncoder.encode(item)
DatasetItemEncoder.to_string(encoded)
def test_item_filter_can_be_applied(self):
class TestExtractor(Extractor):
def __iter__(self):
for i in range(4):
yield DatasetItem(id=i, subset='train')
def test_xpathfilter_can_be_applied(self):
extractor = self.TestExtractor('', n=4)
dataset_filter = XPathDatasetFilter('/item[id > 1]')
extractor = TestExtractor()
filtered = extractor.select(dataset_filter)
filtered = XPathDatasetFilter(extractor, '/item[id > 1]')
self.assertEqual(2, len(filtered))
def test_annotations_filter_can_be_applied(self):
class SrcTestExtractor(Extractor):
def __iter__(self):
return iter([
DatasetItem(id=0),
DatasetItem(id=1, annotations=[
LabelObject(0),
LabelObject(1),
]),
DatasetItem(id=2, annotations=[
LabelObject(0),
LabelObject(2),
]),
])
class DstTestExtractor(Extractor):
def __iter__(self):
return iter([
DatasetItem(id=0),
DatasetItem(id=1, annotations=[
LabelObject(0),
]),
DatasetItem(id=2, annotations=[
LabelObject(0),
]),
])
extractor = SrcTestExtractor()
filtered = XPathAnnotationsFilter(extractor,
'/item/annotation[label_id = 0]')
self.assertListEqual(list(filtered), list(DstTestExtractor()))
def test_annotations_filter_can_remove_empty_items(self):
class SrcTestExtractor(Extractor):
def __iter__(self):
return iter([
DatasetItem(id=0),
DatasetItem(id=1, annotations=[
LabelObject(0),
LabelObject(1),
]),
DatasetItem(id=2, annotations=[
LabelObject(0),
LabelObject(2),
]),
])
class DstTestExtractor(Extractor):
def __iter__(self):
return iter([
DatasetItem(id=2, annotations=[
LabelObject(2),
]),
])
extractor = SrcTestExtractor()
filtered = XPathAnnotationsFilter(extractor,
'/item/annotation[label_id = 2]', remove_empty=True)
self.assertListEqual(list(filtered), list(DstTestExtractor()))
class ConfigTest(TestCase):
def test_can_produce_multilayer_config_from_dict(self):
schema_low = SchemaBuilder() \
@ -409,9 +469,6 @@ class ConfigTest(TestCase):
class ExtractorTest(TestCase):
def test_custom_extractor_can_be_created(self):
class CustomExtractor(Extractor):
def __init__(self, url):
super().__init__()
def __iter__(self):
return iter([
DatasetItem(id=0, subset='train'),
