[Datumaro] Use attrs for model declaration (#2000)

* bandit fixes

* t

* refactor extractor to use attrs

* Merge branch 'develop' into zm/extractor-with-attrs

* update requirements

* restore changelog

* fix bindings
main
zhiltsov-max 6 years ago committed by GitHub
parent eaeb67dc1d
commit 000e9c4dad
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -531,7 +531,7 @@ class CvatTaskDataExtractor(datumaro.SourceExtractor):
return dm_attr return dm_attr
for tag_obj in cvat_frame_anno.tags: for tag_obj in cvat_frame_anno.tags:
anno_group = tag_obj.group anno_group = tag_obj.group or 0
anno_label = map_label(tag_obj.label) anno_label = map_label(tag_obj.label)
anno_attr = convert_attrs(tag_obj.label, tag_obj.attributes) anno_attr = convert_attrs(tag_obj.label, tag_obj.attributes)
@ -540,7 +540,7 @@ class CvatTaskDataExtractor(datumaro.SourceExtractor):
item_anno.append(anno) item_anno.append(anno)
for shape_obj in cvat_frame_anno.labeled_shapes: for shape_obj in cvat_frame_anno.labeled_shapes:
anno_group = shape_obj.group anno_group = shape_obj.group or 0
anno_label = map_label(shape_obj.label) anno_label = map_label(shape_obj.label)
anno_attr = convert_attrs(shape_obj.label, shape_obj.attributes) anno_attr = convert_attrs(shape_obj.label, shape_obj.attributes)
anno_attr['occluded'] = shape_obj.occluded anno_attr['occluded'] = shape_obj.occluded

@ -4,7 +4,7 @@
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT
import logging as log import logging as log
from lxml import etree as ET # NOTE: lxml has proper XPath implementation from lxml import etree as ET # lxml has proper XPath implementation
from datumaro.components.extractor import (Transform, from datumaro.components.extractor import (Transform,
Annotation, AnnotationType, Annotation, AnnotationType,
Label, Mask, Points, Polygon, PolyLine, Bbox, Caption, Label, Mask, Points, Polygon, PolyLine, Bbox, Caption,

@ -7,7 +7,11 @@ from collections import namedtuple
from enum import Enum from enum import Enum
import numpy as np import numpy as np
import attr
from attr import attrs, attrib
from datumaro.util.image import Image from datumaro.util.image import Image
from datumaro.util.attrs_util import not_empty, default_if_none
AnnotationType = Enum('AnnotationType', AnnotationType = Enum('AnnotationType',
[ [
@ -22,59 +26,34 @@ AnnotationType = Enum('AnnotationType',
_COORDINATE_ROUNDING_DIGITS = 2 _COORDINATE_ROUNDING_DIGITS = 2
@attrs
class Annotation: class Annotation:
# pylint: disable=redefined-builtin id = attrib(converter=int, default=0, kw_only=True)
def __init__(self, id=None, type=None, attributes=None, group=None): attributes = attrib(converter=dict, factory=dict, kw_only=True)
if id is not None: group = attrib(converter=int, default=0, kw_only=True)
id = int(id)
self.id = id
assert type in AnnotationType def __attrs_post_init__(self):
self.type = type assert isinstance(self.type, AnnotationType)
if attributes is None:
attributes = {}
else:
attributes = dict(attributes)
self.attributes = attributes
if group is None: @property
group = 0 def type(self):
else: return self._type # must be set in subclasses
group = int(group)
self.group = group
# pylint: enable=redefined-builtin
def __eq__(self, other): def wrap(item, **kwargs):
if not isinstance(other, Annotation): return attr.evolve(item, **kwargs)
return False
return \
(self.id == other.id) and \
(self.type == other.type) and \
(self.attributes == other.attributes) and \
(self.group == other.group)
@attrs
class Categories: class Categories:
def __init__(self, attributes=None): attributes = attrib(factory=set, validator=default_if_none(set),
if attributes is None: kw_only=True)
attributes = set()
else:
if not isinstance(attributes, set):
attributes = set(attributes)
for attr in attributes:
assert isinstance(attr, str)
self.attributes = attributes
def __eq__(self, other):
if not isinstance(other, Categories):
return False
return \
(self.attributes == other.attributes)
@attrs
class LabelCategories(Categories): class LabelCategories(Categories):
Category = namedtuple('Category', ['name', 'parent', 'attributes']) Category = namedtuple('Category', ['name', 'parent', 'attributes'])
items = attrib(factory=list, validator=default_if_none(list))
_indices = attrib(factory=dict, init=False, eq=False)
@classmethod @classmethod
def from_iterable(cls, iterable): def from_iterable(cls, iterable):
"""Generation of LabelCategories from iterable object """Generation of LabelCategories from iterable object
@ -102,14 +81,7 @@ class LabelCategories(Categories):
return temp_categories return temp_categories
def __init__(self, items=None, attributes=None): def __attrs_post_init__(self):
super().__init__(attributes=attributes)
if items is None:
items = []
self.items = items
self._indices = {}
self._reindex() self._reindex()
def _reindex(self): def _reindex(self):
@ -142,49 +114,23 @@ class LabelCategories(Categories):
return index, self.items[index] return index, self.items[index]
return index, None return index, None
def __eq__(self, other): @attrs
if not super().__eq__(other):
return False
return \
(self.items == other.items)
class Label(Annotation): class Label(Annotation):
# pylint: disable=redefined-builtin _type = AnnotationType.label
def __init__(self, label=None, label = attrib(converter=int)
id=None, attributes=None, group=None):
super().__init__(id=id, type=AnnotationType.label,
attributes=attributes, group=group)
if label is not None:
label = int(label)
self.label = label
# pylint: enable=redefined-builtin
def __eq__(self, other):
if not super().__eq__(other):
return False
return \
(self.label == other.label)
@attrs(eq=False)
class MaskCategories(Categories): class MaskCategories(Categories):
def __init__(self, colormap=None, inverse_colormap=None, attributes=None): colormap = attrib(factory=dict, validator=default_if_none(dict))
super().__init__(attributes=attributes) _inverse_colormap = attrib(default=None,
validator=attr.validators.optional(dict))
# colormap: label id -> color
if colormap is None:
colormap = {}
self.colormap = colormap
self._inverse_colormap = inverse_colormap
@property @property
def inverse_colormap(self): def inverse_colormap(self):
from datumaro.util.mask_tools import invert_colormap from datumaro.util.mask_tools import invert_colormap
if self._inverse_colormap is None: if self._inverse_colormap is None:
if self.colormap is not None: if self.colormap is not None:
try: self._inverse_colormap = invert_colormap(self.colormap)
self._inverse_colormap = invert_colormap(self.colormap)
except Exception:
pass
return self._inverse_colormap return self._inverse_colormap
def __eq__(self, other): def __eq__(self, other):
@ -196,25 +142,13 @@ class MaskCategories(Categories):
return False return False
return True return True
@attrs(eq=False)
class Mask(Annotation): class Mask(Annotation):
# pylint: disable=redefined-builtin _type = AnnotationType.mask
def __init__(self, image=None, label=None, z_order=None, _image = attrib()
id=None, attributes=None, group=None): label = attrib(converter=attr.converters.optional(int),
super().__init__(type=AnnotationType.mask, default=None, kw_only=True)
id=id, attributes=attributes, group=group) z_order = attrib(converter=int, default=0, kw_only=True)
self._image = image
if label is not None:
label = int(label)
self._label = label
if z_order is None:
z_order = 0
else:
z_order = int(z_order)
self._z_order = z_order
# pylint: enable=redefined-builtin
@property @property
def image(self): def image(self):
@ -222,14 +156,6 @@ class Mask(Annotation):
return self._image() return self._image()
return self._image return self._image
@property
def label(self):
return self._label
@property
def z_order(self):
return self._z_order
def as_class_mask(self, label_id=None): def as_class_mask(self, label_id=None):
if label_id is None: if label_id is None:
label_id = self.label label_id = self.label
@ -255,19 +181,14 @@ class Mask(Annotation):
return \ return \
(self.label == other.label) and \ (self.label == other.label) and \
(self.z_order == other.z_order) and \ (self.z_order == other.z_order) and \
(self.image is not None and other.image is not None and \ (np.array_equal(self.image, other.image))
np.array_equal(self.image, other.image))
@attrs(eq=False)
class RleMask(Mask): class RleMask(Mask):
# pylint: disable=redefined-builtin rle = attrib()
def __init__(self, rle=None, label=None, z_order=None, _image = attrib(default=attr.Factory(
id=None, attributes=None, group=None): lambda self: self._lazy_decode(self.rle),
lazy_decode = self._lazy_decode(rle) takes_self=True), init=False)
super().__init__(image=lazy_decode, label=label, z_order=z_order,
id=id, attributes=attributes, group=group)
self._rle = rle
# pylint: enable=redefined-builtin
@staticmethod @staticmethod
def _lazy_decode(rle): def _lazy_decode(rle):
@ -276,20 +197,16 @@ class RleMask(Mask):
def get_area(self): def get_area(self):
from pycocotools import mask as mask_utils from pycocotools import mask as mask_utils
return mask_utils.area(self._rle) return mask_utils.area(self.rle)
def get_bbox(self): def get_bbox(self):
from pycocotools import mask as mask_utils from pycocotools import mask as mask_utils
return mask_utils.toBbox(self._rle) return mask_utils.toBbox(self.rle)
@property
def rle(self):
return self._rle
def __eq__(self, other): def __eq__(self, other):
if not isinstance(other, __class__): if not isinstance(other, __class__):
return super().__eq__(other) return super().__eq__(other)
return self._rle == other._rle return self.rle == other.rle
class CompiledMask: class CompiledMask:
@staticmethod @staticmethod
@ -375,38 +292,13 @@ def compute_iou(bbox_a, bbox_b):
return intersection / max(1.0, union) return intersection / max(1.0, union)
@attrs
class _Shape(Annotation): class _Shape(Annotation):
# pylint: disable=redefined-builtin points = attrib(converter=lambda x:
def __init__(self, type, points=None, label=None, z_order=None, [round(p, _COORDINATE_ROUNDING_DIGITS) for p in x])
id=None, attributes=None, group=None): label = attrib(converter=attr.converters.optional(int),
super().__init__(id=id, type=type, default=None, kw_only=True)
attributes=attributes, group=group) z_order = attrib(converter=int, default=0, kw_only=True)
if points is not None:
points = [round(p, _COORDINATE_ROUNDING_DIGITS) for p in points]
self._points = points
if label is not None:
label = int(label)
self._label = label
if z_order is None:
z_order = 0
else:
z_order = int(z_order)
self._z_order = z_order
# pylint: enable=redefined-builtin
@property
def points(self):
return self._points
@property
def label(self):
return self._label
@property
def z_order(self):
return self._z_order
def get_area(self): def get_area(self):
raise NotImplementedError() raise NotImplementedError()
@ -424,22 +316,9 @@ class _Shape(Annotation):
y1 = max(ys) y1 = max(ys)
return [x0, y0, x1 - x0, y1 - y0] return [x0, y0, x1 - x0, y1 - y0]
def __eq__(self, other): @attrs
if not super().__eq__(other):
return False
return \
(np.array_equal(self.points, other.points)) and \
(self.z_order == other.z_order) and \
(self.label == other.label)
class PolyLine(_Shape): class PolyLine(_Shape):
# pylint: disable=redefined-builtin _type = AnnotationType.polyline
def __init__(self, points=None, label=None, z_order=None,
id=None, attributes=None, group=None):
super().__init__(type=AnnotationType.polyline,
points=points, label=label, z_order=z_order,
id=id, attributes=attributes, group=group)
# pylint: enable=redefined-builtin
def as_polygon(self): def as_polygon(self):
return self.points[:] return self.points[:]
@ -447,18 +326,14 @@ class PolyLine(_Shape):
def get_area(self): def get_area(self):
return 0 return 0
@attrs
class Polygon(_Shape): class Polygon(_Shape):
# pylint: disable=redefined-builtin _type = AnnotationType.polygon
def __init__(self, points=None, label=None,
z_order=None, id=None, attributes=None, group=None): def __attrs_post_init__(self):
if points is not None: super().__attrs_post_init__()
# keep the message on the single line to produce # keep the message on a single line to produce informative output
# informative output assert len(self.points) % 2 == 0 and 3 <= len(self.points) // 2, "Wrong polygon points: %s" % self.points
assert len(points) % 2 == 0 and 3 <= len(points) // 2, "Wrong polygon points: %s" % points
super().__init__(type=AnnotationType.polygon,
points=points, label=label, z_order=z_order,
id=id, attributes=attributes, group=group)
# pylint: enable=redefined-builtin
def get_area(self): def get_area(self):
import pycocotools.mask as mask_utils import pycocotools.mask as mask_utils
@ -468,14 +343,16 @@ class Polygon(_Shape):
area = mask_utils.area(rle)[0] area = mask_utils.area(rle)[0]
return area return area
@attrs
class Bbox(_Shape): class Bbox(_Shape):
# pylint: disable=redefined-builtin _type = AnnotationType.bbox
def __init__(self, x=0, y=0, w=0, h=0, label=None, z_order=None,
id=None, attributes=None, group=None): # will be overridden by attrs, then will be overridden again by us
super().__init__(type=AnnotationType.bbox, # attrs' method will be renamed to __attrs_init__
points=[x, y, x + w, y + h], label=label, z_order=z_order, def __init__(self, x, y, w, h, *args, **kwargs):
id=id, attributes=attributes, group=group) kwargs.pop('points', None) # comes from wrap()
# pylint: enable=redefined-builtin self.__attrs_init__([x, y, x + w, y + h], *args, **kwargs)
__actual_init__ = __init__ # save pointer
@property @property
def x(self): def x(self):
@ -511,9 +388,21 @@ class Bbox(_Shape):
def iou(self, other): def iou(self, other):
return compute_iou(self.get_bbox(), other.get_bbox()) return compute_iou(self.get_bbox(), other.get_bbox())
def wrap(item, **kwargs):
d = {'x': item.x, 'y': item.y, 'w': item.w, 'h': item.h}
d.update(kwargs)
return attr.evolve(item, **d)
assert not hasattr(Bbox, '__attrs_init__') # hopefully, it will be supported
setattr(Bbox, '__attrs_init__', Bbox.__init__)
setattr(Bbox, '__init__', Bbox.__actual_init__)
@attrs
class PointsCategories(Categories): class PointsCategories(Categories):
Category = namedtuple('Category', ['labels', 'joints']) Category = namedtuple('Category', ['labels', 'joints'])
items = attrib(factory=dict, validator=default_if_none(dict))
@classmethod @classmethod
def from_iterable(cls, iterable): def from_iterable(cls, iterable):
"""Generation of PointsCategories from iterable object """Generation of PointsCategories from iterable object
@ -539,13 +428,6 @@ class PointsCategories(Categories):
temp_categories.add(*category) temp_categories.add(*category)
return temp_categories return temp_categories
def __init__(self, items=None, attributes=None):
super().__init__(attributes=attributes)
if items is None:
items = {}
self.items = items
def add(self, label_id, labels=None, joints=None): def add(self, label_id, labels=None, joints=None):
if labels is None: if labels is None:
labels = [] labels = []
@ -554,41 +436,30 @@ class PointsCategories(Categories):
joints = set(map(tuple, joints)) joints = set(map(tuple, joints))
self.items[label_id] = self.Category(labels, joints) self.items[label_id] = self.Category(labels, joints)
def __eq__(self, other): @attrs
if not super().__eq__(other):
return False
return \
(self.items == other.items)
class Points(_Shape): class Points(_Shape):
Visibility = Enum('Visibility', [ Visibility = Enum('Visibility', [
('absent', 0), ('absent', 0),
('hidden', 1), ('hidden', 1),
('visible', 2), ('visible', 2),
]) ])
_type = AnnotationType.points
# pylint: disable=redefined-builtin visibility = attrib(type=list, default=None)
def __init__(self, points=None, visibility=None, label=None, z_order=None, @visibility.validator
id=None, attributes=None, group=None): def _visibility_validator(self, attribute, visibility):
if points is not None: if visibility is None:
assert len(points) % 2 == 0 visibility = [self.Visibility.visible] * (len(self.points) // 2)
else:
if visibility is not None: for i, v in enumerate(visibility):
assert len(visibility) == len(points) // 2 if not isinstance(v, self.Visibility):
for i, v in enumerate(visibility): visibility[i] = self.Visibility(v)
if not isinstance(v, self.Visibility): assert len(visibility) == len(self.points) // 2
visibility[i] = self.Visibility(v)
else:
visibility = []
for _ in range(len(points) // 2):
visibility.append(self.Visibility.visible)
super().__init__(type=AnnotationType.points,
points=points, label=label, z_order=z_order,
id=id, attributes=attributes, group=group)
self.visibility = visibility self.visibility = visibility
# pylint: enable=redefined-builtin
def __attrs_post_init__(self):
super().__attrs_post_init__()
assert len(self.points) % 2 == 0, self.points
def get_area(self): def get_area(self):
return 0 return 0
@ -604,117 +475,37 @@ class Points(_Shape):
y1 = max(ys, default=0) y1 = max(ys, default=0)
return [x0, y0, x1 - x0, y1 - y0] return [x0, y0, x1 - x0, y1 - y0]
def __eq__(self, other): @attrs
if not super().__eq__(other):
return False
return \
(self.visibility == other.visibility)
class Caption(Annotation): class Caption(Annotation):
# pylint: disable=redefined-builtin _type = AnnotationType.caption
def __init__(self, caption=None, caption = attrib(converter=str)
id=None, attributes=None, group=None):
super().__init__(id=id, type=AnnotationType.caption,
attributes=attributes, group=group)
if caption is None:
caption = ''
else:
caption = str(caption)
self.caption = caption
# pylint: enable=redefined-builtin
def __eq__(self, other):
if not super().__eq__(other):
return False
return \
(self.caption == other.caption)
@attrs
class DatasetItem: class DatasetItem:
# pylint: disable=redefined-builtin id = attrib(converter=lambda x: str(x).replace('\\', '/'),
def __init__(self, id=None, annotations=None, type=str, validator=not_empty)
subset=None, path=None, image=None, attributes=None): annotations = attrib(factory=list, validator=default_if_none(list))
assert id is not None subset = attrib(default='', validator=default_if_none(str))
self._id = str(id).replace('\\', '/') path = attrib(factory=list, validator=default_if_none(list))
if subset is None: image = attrib(type=Image, default=None)
subset = '' @image.validator
else: def _image_validator(self, attribute, image):
subset = str(subset)
self._subset = subset
if path is None:
path = []
else:
path = list(path)
self._path = path
if annotations is None:
annotations = []
else:
annotations = list(annotations)
self._annotations = annotations
if callable(image) or isinstance(image, np.ndarray): if callable(image) or isinstance(image, np.ndarray):
image = Image(data=image) image = Image(data=image)
elif isinstance(image, str): elif isinstance(image, str):
image = Image(path=image) image = Image(path=image)
assert image is None or isinstance(image, Image) assert image is None or isinstance(image, Image)
self._image = image self.image = image
if attributes is None: attributes = attrib(factory=dict, validator=default_if_none(dict))
attributes = {}
else:
attributes = dict(attributes)
self._attributes = attributes
# pylint: enable=redefined-builtin
@property
def id(self):
return self._id
@property
def subset(self):
return self._subset
@property
def path(self):
return self._path
@property
def annotations(self):
return self._annotations
@property
def image(self):
return self._image
@property @property
def has_image(self): def has_image(self):
return self._image is not None return self.image is not None
@property
def attributes(self):
return self._attributes
def __eq__(self, other):
if not isinstance(other, __class__):
return False
return \
(self.id == other.id) and \
(self.subset == other.subset) and \
(self.path == other.path) and \
(self.annotations == other.annotations) and \
(self.image == other.image) and \
(self.attributes == other.attributes)
def wrap(item, **kwargs): def wrap(item, **kwargs):
expected_args = {'id', 'annotations', 'subset', return attr.evolve(item, **kwargs)
'path', 'image', 'attributes'}
for k in expected_args:
if k not in kwargs:
kwargs[k] = getattr(item, k)
return DatasetItem(**kwargs)
class IExtractor: class IExtractor:
def __iter__(self): def __iter__(self):

@ -102,7 +102,7 @@ class CvatExtractor(SourceExtractor):
else: else:
try: try:
attr_value = float(attr_value) attr_value = float(attr_value)
except Exception: except ValueError:
pass pass
attributes[el.attrib['name']] = attr_value attributes[el.attrib['name']] = attr_value
elif el.tag in cls._SUPPORTED_SHAPES: elif el.tag in cls._SUPPORTED_SHAPES:
@ -251,7 +251,7 @@ class CvatExtractor(SourceExtractor):
@classmethod @classmethod
def _parse_shape_ann(cls, ann, categories): def _parse_shape_ann(cls, ann, categories):
ann_id = ann.get('id') ann_id = ann.get('id', 0)
ann_type = ann['type'] ann_type = ann['type']
attributes = ann.get('attributes') or {} attributes = ann.get('attributes') or {}

@ -83,7 +83,7 @@ class LabelMeExtractor(SourceExtractor):
else: else:
try: try:
value = float(value) value = float(value)
except Exception: except ValueError:
pass pass
parsed.append((name, value)) parsed.append((name, value))
else: else:

@ -504,7 +504,6 @@ class RemapLabels(Transform, CliPlugin):
return self._categories return self._categories
def transform_item(self, item): def transform_item(self, item):
# TODO: provide non-inplace version
annotations = [] annotations = []
for ann in item.annotations: for ann in item.annotations:
if ann.type in { AnnotationType.label, AnnotationType.mask, if ann.type in { AnnotationType.label, AnnotationType.mask,
@ -513,9 +512,7 @@ class RemapLabels(Transform, CliPlugin):
} and ann.label is not None: } and ann.label is not None:
conv_label = self._map_id(ann.label) conv_label = self._map_id(ann.label)
if conv_label is not None: if conv_label is not None:
ann._label = conv_label annotations.append(ann.wrap(label=conv_label))
annotations.append(ann)
else: else:
annotations.append(ann) annotations.append(ann.wrap())
item._annotations = annotations return item.wrap(annotations=annotations)
return item

@ -0,0 +1,26 @@
# Copyright (C) 2020 Intel Corporation
#
# SPDX-License-Identifier: MIT
import attr
_NOTSET = object()
def not_empty(inst, attribute, x):
    """attrs validator: reject an empty value.

    Follows the attrs validator calling convention
    ``(instance, attribute, value)``; ``inst`` and ``attribute`` are unused.

    Raises:
        AssertionError: if ``len(x)`` is 0. The offending value is carried
            as the exception argument.
    """
    # Raise explicitly instead of using a bare ``assert`` statement, so the
    # check is not compiled out when Python runs with -O. The exception type
    # and arguments stay the same as with the assert form.
    if len(x) == 0:
        raise AssertionError(x)
def default_if_none(conv):
    """Build an attrs validator that replaces ``None`` with the default.

    The returned validator follows the attrs calling convention
    ``(instance, attribute, value)`` and always re-assigns the attribute
    on the instance:

    * if ``value`` is ``None``, the attribute default is used instead:
      a callable default is called, an ``attr.Factory`` default has its
      factory invoked, any other default is taken as is;
    * otherwise, if ``value`` is not an instance of ``attribute.type``
      (or of ``conv`` when the attribute declares no type), it is
      converted with ``conv(value)``.

    Args:
        conv: callable used to convert values of an unexpected type; it is
            also the fallback for the ``isinstance`` check, so it is
            expected to be a type (e.g. ``list``, ``dict``, ``str``).

    Returns:
        The validator function.
    """
    def validator(inst, attribute, value):
        default = attribute.default
        if value is None:
            if callable(default):
                value = default()
            elif isinstance(default, attr.Factory):
                # A factory declared with takes_self=True expects the
                # instance being initialized as its only argument.
                if default.takes_self:
                    value = default.factory(inst)
                else:
                    value = default.factory()
            else:
                value = default
        elif not isinstance(value, attribute.type or conv):
            value = conv(value)
        # Validators cannot return a replacement value, so the normalized
        # value is written back onto the instance directly.
        setattr(inst, attribute.name, value)
    return validator

@ -43,19 +43,6 @@ class TestDir(FileRemover):
super().__init__(path, is_dir=True, ignore_errors=ignore_errors) super().__init__(path, is_dir=True, ignore_errors=ignore_errors)
def ann_to_str(ann):
return vars(ann)
def item_to_str(item):
return '\n'.join(
[
'%s' % vars(item)
] + [
'ann[%s]: %s' % (i, ann_to_str(a))
for i, a in enumerate(item.annotations)
]
)
def compare_categories(test, expected, actual): def compare_categories(test, expected, actual):
test.assertEqual( test.assertEqual(
sorted(expected, key=lambda t: t.value), sorted(expected, key=lambda t: t.value),
@ -92,12 +79,11 @@ def compare_datasets(test, expected, actual):
for ann_a in item_a.annotations: for ann_a in item_a.annotations:
# We might find few corresponding items, so check them all # We might find few corresponding items, so check them all
ann_b_matches = [x for x in item_b.annotations ann_b_matches = [x for x in item_b.annotations
if x.id == ann_a.id and \ if x.type == ann_a.type]
x.type == ann_a.type and x.group == ann_a.group]
test.assertFalse(len(ann_b_matches) == 0, 'ann id: %s' % ann_a.id) test.assertFalse(len(ann_b_matches) == 0, 'ann id: %s' % ann_a.id)
ann_b = find(ann_b_matches, lambda x: x == ann_a) ann_b = find(ann_b_matches, lambda x: x == ann_a)
test.assertEqual(ann_a, ann_b, 'ann: %s' % ann_to_str(ann_a)) test.assertEqual(ann_a, ann_b, 'ann %s, candidates %s' % (ann_a, ann_b_matches))
item_b.annotations.remove(ann_b) # avoid repeats item_b.annotations.remove(ann_b) # avoid repeats
def compare_datasets_strict(test, expected, actual): def compare_datasets_strict(test, expected, actual):
@ -115,4 +101,4 @@ def compare_datasets_strict(test, expected, actual):
for idx, (item_a, item_b) in enumerate(zip(e_subset, a_subset)): for idx, (item_a, item_b) in enumerate(zip(e_subset, a_subset)):
test.assertEqual(item_a, item_b, test.assertEqual(item_a, item_b,
'%s:\n%s\nvs.\n%s\n' % \ '%s:\n%s\nvs.\n%s\n' % \
(idx, item_to_str(item_a), item_to_str(item_b))) (idx, item_a, item_b))

@ -1,3 +1,4 @@
attrs>=19.3.0
Cython>=0.27.3 # include before pycocotools Cython>=0.27.3 # include before pycocotools
defusedxml>=0.6.0 defusedxml>=0.6.0
GitPython>=3.0.8 GitPython>=3.0.8

@ -48,6 +48,7 @@ setuptools.setup(
], ],
python_requires='>=3.5', python_requires='>=3.5',
install_requires=[ install_requires=[
'attrs',
'defusedxml', 'defusedxml',
'GitPython', 'GitPython',
'lxml', 'lxml',

@ -65,7 +65,7 @@ class DatumaroConverterTest(TestCase):
annotations=[ annotations=[
Caption('test'), Caption('test'),
Label(2), Label(2),
Bbox(1, 2, 3, 4, 5, id=42, group=42) Bbox(1, 2, 3, 4, label=5, id=42, group=42)
]), ]),
DatasetItem(id=2, subset='val', DatasetItem(id=2, subset='val',

@ -133,15 +133,15 @@ class ProjectTest(TestCase):
self.assertTrue('project1' in dataset.sources) self.assertTrue('project1' in dataset.sources)
def test_can_batch_launch_custom_model(self): def test_can_batch_launch_custom_model(self):
class TestExtractor(Extractor): dataset = Dataset.from_iterable([
def __iter__(self): DatasetItem(id=i, subset='train', image=np.array([i]))
for i in range(5): for i in range(5)
yield DatasetItem(id=i, subset='train', image=np.array([i])) ], categories=['label'])
class TestLauncher(Launcher): class TestLauncher(Launcher):
def launch(self, inputs): def launch(self, inputs):
for i, inp in enumerate(inputs): for i, inp in enumerate(inputs):
yield [ Label(attributes={'idx': i, 'data': inp.item()}) ] yield [ Label(0, attributes={'idx': i, 'data': inp.item()}) ]
model_name = 'model' model_name = 'model'
launcher_name = 'custom_launcher' launcher_name = 'custom_launcher'
@ -150,10 +150,9 @@ class ProjectTest(TestCase):
project.env.launchers.register(launcher_name, TestLauncher) project.env.launchers.register(launcher_name, TestLauncher)
project.add_model(model_name, { 'launcher': launcher_name }) project.add_model(model_name, { 'launcher': launcher_name })
model = project.make_executable_model(model_name) model = project.make_executable_model(model_name)
extractor = TestExtractor()
batch_size = 3 batch_size = 3
executor = ModelTransform(extractor, model, batch_size=batch_size) executor = ModelTransform(dataset, model, batch_size=batch_size)
for item in executor: for item in executor:
self.assertEqual(1, len(item.annotations)) self.assertEqual(1, len(item.annotations))

@ -356,86 +356,58 @@ class TransformsTest(TestCase):
]) ])
def test_remap_labels(self): def test_remap_labels(self):
class SrcExtractor(Extractor): src_dataset = Dataset.from_iterable([
def __iter__(self): DatasetItem(id=1, annotations=[
return iter([ # Should be remapped
DatasetItem(id=1, annotations=[ Label(1),
# Should be remapped Bbox(1, 2, 3, 4, label=2),
Label(1), Mask(image=np.array([1]), label=3),
Bbox(1, 2, 3, 4, label=2),
Mask(image=np.array([1]), label=3), # Should be kept
Polygon([1, 1, 2, 2, 3, 4], label=4),
# Should be kept PolyLine([1, 3, 4, 2, 5, 6])
Polygon([1, 1, 2, 2, 3, 4], label=4), ])
PolyLine([1, 3, 4, 2, 5, 6], label=None) ], categories={
]), AnnotationType.label: LabelCategories.from_iterable(
]) 'label%s' % i for i in range(5)),
AnnotationType.mask: MaskCategories(
def categories(self): colormap=mask_tools.generate_colormap(5)),
label_cat = LabelCategories() })
label_cat.add('label0')
label_cat.add('label1')
label_cat.add('label2')
label_cat.add('label3')
label_cat.add('label4')
mask_cat = MaskCategories(
colormap=mask_tools.generate_colormap(5))
return {
AnnotationType.label: label_cat,
AnnotationType.mask: mask_cat,
}
class DstExtractor(Extractor):
def __iter__(self):
return iter([
DatasetItem(id=1, annotations=[
Label(1),
Bbox(1, 2, 3, 4, label=0),
Mask(image=np.array([1]), label=1),
Polygon([1, 1, 2, 2, 3, 4], label=2),
PolyLine([1, 3, 4, 2, 5, 6], label=None)
]),
])
def categories(self):
label_cat = LabelCategories()
label_cat.add('label0')
label_cat.add('label9')
label_cat.add('label4')
mask_cat = MaskCategories(colormap={ dst_dataset = Dataset.from_iterable([
k: v for k, v in mask_tools.generate_colormap(5).items() DatasetItem(id=1, annotations=[
if k in { 0, 1, 3, 4 } Label(1),
}) Bbox(1, 2, 3, 4, label=0),
Mask(image=np.array([1]), label=1),
return { Polygon([1, 1, 2, 2, 3, 4], label=2),
AnnotationType.label: label_cat, PolyLine([1, 3, 4, 2, 5, 6], label=None)
AnnotationType.mask: mask_cat, ]),
} ], categories={
AnnotationType.label: LabelCategories.from_iterable(
['label0', 'label9', 'label4']),
AnnotationType.mask: MaskCategories(colormap={
k: v for k, v in mask_tools.generate_colormap(5).items()
if k in { 0, 1, 3, 4 }
})
})
actual = transforms.RemapLabels(SrcExtractor(), mapping={ actual = transforms.RemapLabels(src_dataset, mapping={
'label1': 'label9', 'label1': 'label9',
'label2': 'label0', 'label2': 'label0',
'label3': 'label9', 'label3': 'label9',
}, default='keep') }, default='keep')
compare_datasets(self, DstExtractor(), actual) compare_datasets(self, dst_dataset, actual)
def test_remap_labels_delete_unspecified(self): def test_remap_labels_delete_unspecified(self):
source_dataset = Dataset.from_iterable([ source_dataset = Dataset.from_iterable([
DatasetItem(id=1, annotations=[ Label(0) ]) DatasetItem(id=1, annotations=[ Label(0) ])
], categories={ ], categories=['label0'])
AnnotationType.label: LabelCategories.from_iterable('label0'),
})
target_dataset = Dataset.from_iterable([ target_dataset = Dataset.from_iterable([
DatasetItem(id=1, annotations=[]), DatasetItem(id=1),
], categories={ ], categories=[])
AnnotationType.label: LabelCategories(),
})
actual = transforms.RemapLabels(source_dataset, actual = transforms.RemapLabels(source_dataset,
mapping={}, default='delete') mapping={}, default='delete')

Loading…
Cancel
Save