diff --git a/cvat/apps/dataset_manager/bindings.py b/cvat/apps/dataset_manager/bindings.py index 7e20fc33..64921555 100644 --- a/cvat/apps/dataset_manager/bindings.py +++ b/cvat/apps/dataset_manager/bindings.py @@ -531,7 +531,7 @@ class CvatTaskDataExtractor(datumaro.SourceExtractor): return dm_attr for tag_obj in cvat_frame_anno.tags: - anno_group = tag_obj.group + anno_group = tag_obj.group or 0 anno_label = map_label(tag_obj.label) anno_attr = convert_attrs(tag_obj.label, tag_obj.attributes) @@ -540,7 +540,7 @@ class CvatTaskDataExtractor(datumaro.SourceExtractor): item_anno.append(anno) for shape_obj in cvat_frame_anno.labeled_shapes: - anno_group = shape_obj.group + anno_group = shape_obj.group or 0 anno_label = map_label(shape_obj.label) anno_attr = convert_attrs(shape_obj.label, shape_obj.attributes) anno_attr['occluded'] = shape_obj.occluded diff --git a/datumaro/datumaro/components/dataset_filter.py b/datumaro/datumaro/components/dataset_filter.py index 351cc6c8..e9fc5e35 100644 --- a/datumaro/datumaro/components/dataset_filter.py +++ b/datumaro/datumaro/components/dataset_filter.py @@ -4,7 +4,7 @@ # SPDX-License-Identifier: MIT import logging as log -from lxml import etree as ET # NOTE: lxml has proper XPath implementation +from lxml import etree as ET # lxml has proper XPath implementation from datumaro.components.extractor import (Transform, Annotation, AnnotationType, Label, Mask, Points, Polygon, PolyLine, Bbox, Caption, diff --git a/datumaro/datumaro/components/extractor.py b/datumaro/datumaro/components/extractor.py index 1df3c287..609c6e9e 100644 --- a/datumaro/datumaro/components/extractor.py +++ b/datumaro/datumaro/components/extractor.py @@ -7,7 +7,11 @@ from collections import namedtuple from enum import Enum import numpy as np +import attr +from attr import attrs, attrib + from datumaro.util.image import Image +from datumaro.util.attrs_util import not_empty, default_if_none AnnotationType = Enum('AnnotationType', [ @@ -22,59 +26,34 @@ AnnotationType = Enum('AnnotationType', _COORDINATE_ROUNDING_DIGITS = 2 - +@attrs class Annotation: - # pylint: disable=redefined-builtin - def __init__(self, id=None, type=None, attributes=None, group=None): - if id is not None: - id = int(id) - self.id = id + id = attrib(converter=int, default=0, kw_only=True) + attributes = attrib(converter=dict, factory=dict, kw_only=True) + group = attrib(converter=int, default=0, kw_only=True) - assert type in AnnotationType - self.type = type - - if attributes is None: - attributes = {} - else: - attributes = dict(attributes) - self.attributes = attributes + def __attrs_post_init__(self): + assert isinstance(self.type, AnnotationType) - if group is None: - group = 0 - else: - group = int(group) - self.group = group - # pylint: enable=redefined-builtin + @property + def type(self): + return self._type # must be set in subclasses - def __eq__(self, other): - if not isinstance(other, Annotation): - return False - return \ - (self.id == other.id) and \ - (self.type == other.type) and \ - (self.attributes == other.attributes) and \ - (self.group == other.group) + def wrap(item, **kwargs): + return attr.evolve(item, **kwargs) +@attrs class Categories: - def __init__(self, attributes=None): - if attributes is None: - attributes = set() - else: - if not isinstance(attributes, set): - attributes = set(attributes) - for attr in attributes: - assert isinstance(attr, str) - self.attributes = attributes - - def __eq__(self, other): - if not isinstance(other, Categories): - return False - return \ - (self.attributes == other.attributes) + attributes = attrib(factory=set, validator=default_if_none(set), + kw_only=True) +@attrs class LabelCategories(Categories): Category = namedtuple('Category', ['name', 'parent', 'attributes']) + items = attrib(factory=list, validator=default_if_none(list)) + _indices = attrib(factory=dict, init=False, eq=False) + @classmethod def from_iterable(cls, iterable): """Generation of LabelCategories from iterable object @@ -102,14 +81,7 @@ class LabelCategories(Categories): return temp_categories - def __init__(self, items=None, attributes=None): - super().__init__(attributes=attributes) - - if items is None: - items = [] - self.items = items - - self._indices = {} + def __attrs_post_init__(self): self._reindex() def _reindex(self): @@ -142,49 +114,23 @@ class LabelCategories(Categories): return index, self.items[index] return index, None - def __eq__(self, other): - if not super().__eq__(other): - return False - return \ - (self.items == other.items) - +@attrs class Label(Annotation): - # pylint: disable=redefined-builtin - def __init__(self, label=None, - id=None, attributes=None, group=None): - super().__init__(id=id, type=AnnotationType.label, - attributes=attributes, group=group) - - if label is not None: - label = int(label) - self.label = label - # pylint: enable=redefined-builtin - - def __eq__(self, other): - if not super().__eq__(other): - return False - return \ - (self.label == other.label) + _type = AnnotationType.label + label = attrib(converter=int) +@attrs(eq=False) class MaskCategories(Categories): - def __init__(self, colormap=None, inverse_colormap=None, attributes=None): - super().__init__(attributes=attributes) - - # colormap: label id -> color - if colormap is None: - colormap = {} - self.colormap = colormap - self._inverse_colormap = inverse_colormap + colormap = attrib(factory=dict, validator=default_if_none(dict)) + _inverse_colormap = attrib(default=None, + validator=attr.validators.optional(dict)) @property def inverse_colormap(self): from datumaro.util.mask_tools import invert_colormap if self._inverse_colormap is None: if self.colormap is not None: - try: - self._inverse_colormap = invert_colormap(self.colormap) - except Exception: - pass + self._inverse_colormap = invert_colormap(self.colormap) return self._inverse_colormap def __eq__(self, other): @@ -196,25 +142,13 @@ class MaskCategories(Categories): return False return True +@attrs(eq=False) class Mask(Annotation): - # pylint: disable=redefined-builtin - def __init__(self, image=None, label=None, z_order=None, - id=None, attributes=None, group=None): - super().__init__(type=AnnotationType.mask, - id=id, attributes=attributes, group=group) - - self._image = image - - if label is not None: - label = int(label) - self._label = label - - if z_order is None: - z_order = 0 - else: - z_order = int(z_order) - self._z_order = z_order - # pylint: enable=redefined-builtin + _type = AnnotationType.mask + _image = attrib() + label = attrib(converter=attr.converters.optional(int), + default=None, kw_only=True) + z_order = attrib(converter=int, default=0, kw_only=True) @property def image(self): @@ -222,14 +156,6 @@ class Mask(Annotation): return self._image() return self._image - @property - def label(self): - return self._label - - @property - def z_order(self): - return self._z_order - def as_class_mask(self, label_id=None): if label_id is None: label_id = self.label @@ -255,19 +181,14 @@ class Mask(Annotation): return \ (self.label == other.label) and \ (self.z_order == other.z_order) and \ - (self.image is not None and other.image is not None and \ - np.array_equal(self.image, other.image)) + (np.array_equal(self.image, other.image)) +@attrs(eq=False) class RleMask(Mask): - # pylint: disable=redefined-builtin - def __init__(self, rle=None, label=None, z_order=None, - id=None, attributes=None, group=None): - lazy_decode = self._lazy_decode(rle) - super().__init__(image=lazy_decode, label=label, z_order=z_order, - id=id, attributes=attributes, group=group) - - self._rle = rle - # pylint: enable=redefined-builtin + rle = attrib() + _image = attrib(default=attr.Factory( + lambda self: self._lazy_decode(self.rle), + takes_self=True), init=False) @staticmethod def _lazy_decode(rle): @@ -276,20 +197,16 @@ class RleMask(Mask): def get_area(self): from pycocotools import mask as mask_utils - return mask_utils.area(self._rle) + return mask_utils.area(self.rle) def get_bbox(self): from pycocotools import mask as mask_utils - return mask_utils.toBbox(self._rle) - - @property - def rle(self): - return self._rle + return mask_utils.toBbox(self.rle) def __eq__(self, other): if not isinstance(other, __class__): return super().__eq__(other) - return self._rle == other._rle + return self.rle == other.rle class CompiledMask: @staticmethod @@ -375,38 +292,13 @@ def compute_iou(bbox_a, bbox_b): return intersection / max(1.0, union) +@attrs class _Shape(Annotation): - # pylint: disable=redefined-builtin - def __init__(self, type, points=None, label=None, z_order=None, - id=None, attributes=None, group=None): - super().__init__(id=id, type=type, - attributes=attributes, group=group) - if points is not None: - points = [round(p, _COORDINATE_ROUNDING_DIGITS) for p in points] - self._points = points - - if label is not None: - label = int(label) - self._label = label - - if z_order is None: - z_order = 0 - else: - z_order = int(z_order) - self._z_order = z_order - # pylint: enable=redefined-builtin - - @property - def points(self): - return self._points - - @property - def label(self): - return self._label - - @property - def z_order(self): - return self._z_order + points = attrib(converter=lambda x: + [round(p, _COORDINATE_ROUNDING_DIGITS) for p in x]) + label = attrib(converter=attr.converters.optional(int), + default=None, kw_only=True) + z_order = attrib(converter=int, default=0, kw_only=True) def get_area(self): raise NotImplementedError() @@ -424,22 +316,9 @@ class _Shape(Annotation): y1 = max(ys) return [x0, y0, x1 - x0, y1 - y0] - def __eq__(self, other): - if not super().__eq__(other): - return False - return \ - (np.array_equal(self.points, other.points)) and \ - (self.z_order == other.z_order) and \ - (self.label == other.label) - +@attrs class PolyLine(_Shape): - # pylint: disable=redefined-builtin - def __init__(self, points=None, label=None, z_order=None, - id=None, attributes=None, group=None): - super().__init__(type=AnnotationType.polyline, - points=points, label=label, z_order=z_order, - id=id, attributes=attributes, group=group) - # pylint: enable=redefined-builtin + _type = AnnotationType.polyline def as_polygon(self): return self.points[:] @@ -447,18 +326,14 @@ class PolyLine(_Shape): def get_area(self): return 0 +@attrs class Polygon(_Shape): - # pylint: disable=redefined-builtin - def __init__(self, points=None, label=None, - z_order=None, id=None, attributes=None, group=None): - if points is not None: - # keep the message on the single line to produce - # informative output - assert len(points) % 2 == 0 and 3 <= len(points) // 2, "Wrong polygon points: %s" % points - super().__init__(type=AnnotationType.polygon, - points=points, label=label, z_order=z_order, - id=id, attributes=attributes, group=group) - # pylint: enable=redefined-builtin + _type = AnnotationType.polygon + + def __attrs_post_init__(self): + super().__attrs_post_init__() + # keep the message on a single line to produce informative output + assert len(self.points) % 2 == 0 and 3 <= len(self.points) // 2, "Wrong polygon points: %s" % self.points def get_area(self): import pycocotools.mask as mask_utils @@ -468,14 +343,16 @@ class Polygon(_Shape): area = mask_utils.area(rle)[0] return area +@attrs class Bbox(_Shape): - # pylint: disable=redefined-builtin - def __init__(self, x=0, y=0, w=0, h=0, label=None, z_order=None, - id=None, attributes=None, group=None): - super().__init__(type=AnnotationType.bbox, - points=[x, y, x + w, y + h], label=label, z_order=z_order, - id=id, attributes=attributes, group=group) - # pylint: enable=redefined-builtin + _type = AnnotationType.bbox + + # will be overridden by attrs, then will be overridden again by us + # attrs' method will be renamed to __attrs_init__ + def __init__(self, x, y, w, h, *args, **kwargs): + kwargs.pop('points', None) # comes from wrap() + self.__attrs_init__([x, y, x + w, y + h], *args, **kwargs) + __actual_init__ = __init__ # save pointer @property def x(self): @@ -511,9 +388,21 @@ class Bbox(_Shape): def iou(self, other): return compute_iou(self.get_bbox(), other.get_bbox()) + def wrap(item, **kwargs): + d = {'x': item.x, 'y': item.y, 'w': item.w, 'h': item.h} + d.update(kwargs) + return attr.evolve(item, **d) + +assert not hasattr(Bbox, '__attrs_init__') # hopefully, it will be supported +setattr(Bbox, '__attrs_init__', Bbox.__init__) +setattr(Bbox, '__init__', Bbox.__actual_init__) + +@attrs class PointsCategories(Categories): Category = namedtuple('Category', ['labels', 'joints']) + items = attrib(factory=dict, validator=default_if_none(dict)) + @classmethod def from_iterable(cls, iterable): """Generation of PointsCategories from iterable object @@ -539,13 +428,6 @@ class PointsCategories(Categories): temp_categories.add(*category) return temp_categories - def __init__(self, items=None, attributes=None): - super().__init__(attributes=attributes) - - if items is None: - items = {} - self.items = items - def add(self, label_id, labels=None, joints=None): if labels is None: labels = [] @@ -554,41 +436,30 @@ class PointsCategories(Categories): joints = set(map(tuple, joints)) self.items[label_id] = self.Category(labels, joints) - def __eq__(self, other): - if not super().__eq__(other): - return False - return \ - (self.items == other.items) - +@attrs class Points(_Shape): Visibility = Enum('Visibility', [ ('absent', 0), ('hidden', 1), ('visible', 2), ]) + _type = AnnotationType.points - # pylint: disable=redefined-builtin - def __init__(self, points=None, visibility=None, label=None, z_order=None, - id=None, attributes=None, group=None): - if points is not None: - assert len(points) % 2 == 0 - - if visibility is not None: - assert len(visibility) == len(points) // 2 - for i, v in enumerate(visibility): - if not isinstance(v, self.Visibility): - visibility[i] = self.Visibility(v) - else: - visibility = [] - for _ in range(len(points) // 2): - visibility.append(self.Visibility.visible) - - super().__init__(type=AnnotationType.points, - points=points, label=label, z_order=z_order, - id=id, attributes=attributes, group=group) - + visibility = attrib(type=list, default=None) + @visibility.validator + def _visibility_validator(self, attribute, visibility): + if visibility is None: + visibility = [self.Visibility.visible] * (len(self.points) // 2) + else: + for i, v in enumerate(visibility): + if not isinstance(v, self.Visibility): + visibility[i] = self.Visibility(v) + assert len(visibility) == len(self.points) // 2 self.visibility = visibility - # pylint: enable=redefined-builtin + + def __attrs_post_init__(self): + super().__attrs_post_init__() + assert len(self.points) % 2 == 0, self.points def get_area(self): return 0 @@ -604,117 +475,37 @@ class Points(_Shape): y1 = max(ys, default=0) return [x0, y0, x1 - x0, y1 - y0] - def __eq__(self, other): - if not super().__eq__(other): - return False - return \ - (self.visibility == other.visibility) - +@attrs class Caption(Annotation): - # pylint: disable=redefined-builtin - def __init__(self, caption=None, - id=None, attributes=None, group=None): - super().__init__(id=id, type=AnnotationType.caption, - attributes=attributes, group=group) - - if caption is None: - caption = '' - else: - caption = str(caption) - self.caption = caption - # pylint: enable=redefined-builtin - - def __eq__(self, other): - if not super().__eq__(other): - return False - return \ - (self.caption == other.caption) + _type = AnnotationType.caption + caption = attrib(converter=str) +@attrs class DatasetItem: - # pylint: disable=redefined-builtin - def __init__(self, id=None, annotations=None, - subset=None, path=None, image=None, attributes=None): - assert id is not None - self._id = str(id).replace('\\', '/') - - if subset is None: - subset = '' - else: - subset = str(subset) - self._subset = subset - - if path is None: - path = [] - else: - path = list(path) - self._path = path - - if annotations is None: - annotations = [] - else: - annotations = list(annotations) - self._annotations = annotations - + id = attrib(converter=lambda x: str(x).replace('\\', '/'), + type=str, validator=not_empty) + annotations = attrib(factory=list, validator=default_if_none(list)) + subset = attrib(default='', validator=default_if_none(str)) + path = attrib(factory=list, validator=default_if_none(list)) + + image = attrib(type=Image, default=None) + @image.validator + def _image_validator(self, attribute, image): if callable(image) or isinstance(image, np.ndarray): image = Image(data=image) elif isinstance(image, str): image = Image(path=image) assert image is None or isinstance(image, Image) - self._image = image + self.image = image - if attributes is None: - attributes = {} - else: - attributes = dict(attributes) - self._attributes = attributes - # pylint: enable=redefined-builtin - - @property - def id(self): - return self._id - - @property - def subset(self): - return self._subset - - @property - def path(self): - return self._path - - @property - def annotations(self): - return self._annotations - - @property - def image(self): - return self._image + attributes = attrib(factory=dict, validator=default_if_none(dict)) @property def has_image(self): - return self._image is not None - - @property - def attributes(self): - return self._attributes - - def __eq__(self, other): - if not isinstance(other, __class__): - return False - return \ - (self.id == other.id) and \ - (self.subset == other.subset) and \ - (self.path == other.path) and \ - (self.annotations == other.annotations) and \ - (self.image == other.image) and \ - (self.attributes == other.attributes) + return self.image is not None def wrap(item, **kwargs): - expected_args = {'id', 'annotations', 'subset', - 'path', 'image', 'attributes'} - for k in expected_args: - if k not in kwargs: - kwargs[k] = getattr(item, k) - return DatasetItem(**kwargs) + return attr.evolve(item, **kwargs) class IExtractor: def __iter__(self): diff --git a/datumaro/datumaro/plugins/cvat_format/extractor.py b/datumaro/datumaro/plugins/cvat_format/extractor.py index 2c63bd2e..75a3e5d8 100644 --- a/datumaro/datumaro/plugins/cvat_format/extractor.py +++ b/datumaro/datumaro/plugins/cvat_format/extractor.py @@ -102,7 +102,7 @@ class CvatExtractor(SourceExtractor): else: try: attr_value = float(attr_value) - except Exception: + except ValueError: pass attributes[el.attrib['name']] = attr_value elif el.tag in cls._SUPPORTED_SHAPES: @@ -251,7 +251,7 @@ class CvatExtractor(SourceExtractor): @classmethod def _parse_shape_ann(cls, ann, categories): - ann_id = ann.get('id') + ann_id = ann.get('id', 0) ann_type = ann['type'] attributes = ann.get('attributes') or {} diff --git a/datumaro/datumaro/plugins/labelme_format.py b/datumaro/datumaro/plugins/labelme_format.py index d8c71716..5218e36f 100644 --- a/datumaro/datumaro/plugins/labelme_format.py +++ b/datumaro/datumaro/plugins/labelme_format.py @@ -83,7 +83,7 @@ class LabelMeExtractor(SourceExtractor): else: try: value = float(value) - except Exception: + except ValueError: pass parsed.append((name, value)) else: diff --git a/datumaro/datumaro/plugins/transforms.py b/datumaro/datumaro/plugins/transforms.py index 79f0d3ae..b31b4762 100644 --- a/datumaro/datumaro/plugins/transforms.py +++ b/datumaro/datumaro/plugins/transforms.py @@ -504,7 +504,6 @@ class RemapLabels(Transform, CliPlugin): return self._categories def transform_item(self, item): - # TODO: provide non-inplace version annotations = [] for ann in item.annotations: if ann.type in { AnnotationType.label, AnnotationType.mask, @@ -513,9 +512,7 @@ class RemapLabels(Transform, CliPlugin): } and ann.label is not None: conv_label = self._map_id(ann.label) if conv_label is not None: - ann._label = conv_label - annotations.append(ann) + annotations.append(ann.wrap(label=conv_label)) else: - annotations.append(ann) - item._annotations = annotations - return item \ No newline at end of file + annotations.append(ann.wrap()) + return item.wrap(annotations=annotations) \ No newline at end of file diff --git a/datumaro/datumaro/util/attrs_util.py b/datumaro/datumaro/util/attrs_util.py new file mode 100644 index 00000000..af92c549 --- /dev/null +++ b/datumaro/datumaro/util/attrs_util.py @@ -0,0 +1,26 @@ + +# Copyright (C) 2020 Intel Corporation +# +# SPDX-License-Identifier: MIT + +import attr + +_NOTSET = object() + +def not_empty(inst, attribute, x): + assert len(x) != 0, x + +def default_if_none(conv): + def validator(inst, attribute, value): + default = attribute.default + if value is None: + if callable(default): + value = default() + elif isinstance(default, attr.Factory): + value = default.factory() + else: + value = default + elif not isinstance(value, attribute.type or conv): + value = conv(value) + setattr(inst, attribute.name, value) + return validator \ No newline at end of file diff --git a/datumaro/datumaro/util/test_utils.py b/datumaro/datumaro/util/test_utils.py index f9ce0369..cca95278 100644 --- a/datumaro/datumaro/util/test_utils.py +++ b/datumaro/datumaro/util/test_utils.py @@ -43,19 +43,6 @@ class TestDir(FileRemover): super().__init__(path, is_dir=True, ignore_errors=ignore_errors) -def ann_to_str(ann): - return vars(ann) - -def item_to_str(item): - return '\n'.join( - [ - '%s' % vars(item) - ] + [ - 'ann[%s]: %s' % (i, ann_to_str(a)) - for i, a in enumerate(item.annotations) - ] - ) - def compare_categories(test, expected, actual): test.assertEqual( sorted(expected, key=lambda t: t.value), @@ -92,12 +79,11 @@ def compare_datasets(test, expected, actual): for ann_a in item_a.annotations: # We might find few corresponding items, so check them all ann_b_matches = [x for x in item_b.annotations - if x.id == ann_a.id and \ - x.type == ann_a.type and x.group == ann_a.group] + if x.type == ann_a.type] test.assertFalse(len(ann_b_matches) == 0, 'ann id: %s' % ann_a.id) ann_b = find(ann_b_matches, lambda x: x == ann_a) - test.assertEqual(ann_a, ann_b, 'ann: %s' % ann_to_str(ann_a)) + test.assertEqual(ann_a, ann_b, 'ann %s, candidates %s' % (ann_a, ann_b_matches)) item_b.annotations.remove(ann_b) # avoid repeats def compare_datasets_strict(test, expected, actual): @@ -115,4 +101,4 @@ def compare_datasets_strict(test, expected, actual): for idx, (item_a, item_b) in enumerate(zip(e_subset, a_subset)): test.assertEqual(item_a, item_b, '%s:\n%s\nvs.\n%s\n' % \ - (idx, item_to_str(item_a), item_to_str(item_b))) \ No newline at end of file + (idx, item_a, item_b)) \ No newline at end of file diff --git a/datumaro/requirements.txt b/datumaro/requirements.txt index c75978fc..ce583783 100644 --- a/datumaro/requirements.txt +++ b/datumaro/requirements.txt @@ -1,3 +1,4 @@ +attrs>=19.3.0 Cython>=0.27.3 # include before pycocotools defusedxml>=0.6.0 GitPython>=3.0.8 diff --git a/datumaro/setup.py b/datumaro/setup.py index 90c39ce4..4ebf1119 100644 --- a/datumaro/setup.py +++ b/datumaro/setup.py @@ -48,6 +48,7 @@ setuptools.setup( ], python_requires='>=3.5', install_requires=[ + 'attrs', 'defusedxml', 'GitPython', 'lxml', diff --git a/datumaro/tests/test_datumaro_format.py b/datumaro/tests/test_datumaro_format.py index 421627a8..8faf5ef0 100644 --- a/datumaro/tests/test_datumaro_format.py +++ b/datumaro/tests/test_datumaro_format.py @@ -65,7 +65,7 @@ class DatumaroConverterTest(TestCase): annotations=[ Caption('test'), Label(2), - Bbox(1, 2, 3, 4, 5, id=42, group=42) + Bbox(1, 2, 3, 4, label=5, id=42, group=42) ]), DatasetItem(id=2, subset='val', diff --git a/datumaro/tests/test_project.py b/datumaro/tests/test_project.py index 023341c6..ed4ad976 100644 --- a/datumaro/tests/test_project.py +++ b/datumaro/tests/test_project.py @@ -133,15 +133,15 @@ class ProjectTest(TestCase): self.assertTrue('project1' in dataset.sources) def test_can_batch_launch_custom_model(self): - class TestExtractor(Extractor): - def __iter__(self): - for i in range(5): - yield DatasetItem(id=i, subset='train', image=np.array([i])) + dataset = Dataset.from_iterable([ + DatasetItem(id=i, subset='train', image=np.array([i])) + for i in range(5) + ], categories=['label']) class TestLauncher(Launcher): def launch(self, inputs): for i, inp in enumerate(inputs): - yield [ Label(attributes={'idx': i, 'data': inp.item()}) ] + yield [ Label(0, attributes={'idx': i, 'data': inp.item()}) ] model_name = 'model' launcher_name = 'custom_launcher' @@ -150,10 +150,9 @@ class ProjectTest(TestCase): project.env.launchers.register(launcher_name, TestLauncher) project.add_model(model_name, { 'launcher': launcher_name }) model = project.make_executable_model(model_name) - extractor = TestExtractor() batch_size = 3 - executor = ModelTransform(extractor, model, batch_size=batch_size) + executor = ModelTransform(dataset, model, batch_size=batch_size) for item in executor: self.assertEqual(1, len(item.annotations)) diff --git a/datumaro/tests/test_transforms.py b/datumaro/tests/test_transforms.py index a55a446e..ed072a67 100644 --- a/datumaro/tests/test_transforms.py +++ b/datumaro/tests/test_transforms.py @@ -356,86 +356,58 @@ class TransformsTest(TestCase): ]) def test_remap_labels(self): - class SrcExtractor(Extractor): - def __iter__(self): - return iter([ - DatasetItem(id=1, annotations=[ - # Should be remapped - Label(1), - Bbox(1, 2, 3, 4, label=2), - Mask(image=np.array([1]), label=3), - - # Should be kept - Polygon([1, 1, 2, 2, 3, 4], label=4), - PolyLine([1, 3, 4, 2, 5, 6], label=None) - ]), - ]) - - def categories(self): - label_cat = LabelCategories() - label_cat.add('label0') - label_cat.add('label1') - label_cat.add('label2') - label_cat.add('label3') - label_cat.add('label4') - - mask_cat = MaskCategories( - colormap=mask_tools.generate_colormap(5)) - - return { - AnnotationType.label: label_cat, - AnnotationType.mask: mask_cat, - } - - class DstExtractor(Extractor): - def __iter__(self): - return iter([ - DatasetItem(id=1, annotations=[ - Label(1), - Bbox(1, 2, 3, 4, label=0), - Mask(image=np.array([1]), label=1), - - Polygon([1, 1, 2, 2, 3, 4], label=2), - PolyLine([1, 3, 4, 2, 5, 6], label=None) - ]), - ]) - - def categories(self): - label_cat = LabelCategories() - label_cat.add('label0') - label_cat.add('label9') - label_cat.add('label4') + src_dataset = Dataset.from_iterable([ + DatasetItem(id=1, annotations=[ + # Should be remapped + Label(1), + Bbox(1, 2, 3, 4, label=2), + Mask(image=np.array([1]), label=3), + + # Should be kept + Polygon([1, 1, 2, 2, 3, 4], label=4), + PolyLine([1, 3, 4, 2, 5, 6]) + ]) + ], categories={ + AnnotationType.label: LabelCategories.from_iterable( + 'label%s' % i for i in range(5)), + AnnotationType.mask: MaskCategories( + colormap=mask_tools.generate_colormap(5)), + }) - mask_cat = MaskCategories(colormap={ - k: v for k, v in mask_tools.generate_colormap(5).items() - if k in { 0, 1, 3, 4 } - }) + dst_dataset = Dataset.from_iterable([ + DatasetItem(id=1, annotations=[ + Label(1), + Bbox(1, 2, 3, 4, label=0), + Mask(image=np.array([1]), label=1), - return { - AnnotationType.label: label_cat, - AnnotationType.mask: mask_cat, - } + Polygon([1, 1, 2, 2, 3, 4], label=2), + PolyLine([1, 3, 4, 2, 5, 6], label=None) + ]), + ], categories={ + AnnotationType.label: LabelCategories.from_iterable( + ['label0', 'label9', 'label4']), + AnnotationType.mask: MaskCategories(colormap={ + k: v for k, v in mask_tools.generate_colormap(5).items() + if k in { 0, 1, 3, 4 } + }) + }) - actual = transforms.RemapLabels(SrcExtractor(), mapping={ + actual = transforms.RemapLabels(src_dataset, mapping={ 'label1': 'label9', 'label2': 'label0', 'label3': 'label9', }, default='keep') - compare_datasets(self, DstExtractor(), actual) + compare_datasets(self, dst_dataset, actual) def test_remap_labels_delete_unspecified(self): source_dataset = Dataset.from_iterable([ DatasetItem(id=1, annotations=[ Label(0) ]) - ], categories={ - AnnotationType.label: LabelCategories.from_iterable('label0'), - }) + ], categories=['label0']) target_dataset = Dataset.from_iterable([ - DatasetItem(id=1, annotations=[]), - ], categories={ - AnnotationType.label: LabelCategories(), - }) + DatasetItem(id=1), + ], categories=[]) actual = transforms.RemapLabels(source_dataset, mapping={}, default='delete')