[Datumaro] Use attrs for model declaration (#2000)

* bandit fixes

* t

* refactor extractor to use attrs

* Merge branch 'develop' into zm/extractor-with-attrs

* update requirements

* restore changelog

* fix bindings
main
zhiltsov-max 6 years ago committed by GitHub
parent eaeb67dc1d
commit 000e9c4dad
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -531,7 +531,7 @@ class CvatTaskDataExtractor(datumaro.SourceExtractor):
return dm_attr
for tag_obj in cvat_frame_anno.tags:
anno_group = tag_obj.group
anno_group = tag_obj.group or 0
anno_label = map_label(tag_obj.label)
anno_attr = convert_attrs(tag_obj.label, tag_obj.attributes)
@ -540,7 +540,7 @@ class CvatTaskDataExtractor(datumaro.SourceExtractor):
item_anno.append(anno)
for shape_obj in cvat_frame_anno.labeled_shapes:
anno_group = shape_obj.group
anno_group = shape_obj.group or 0
anno_label = map_label(shape_obj.label)
anno_attr = convert_attrs(shape_obj.label, shape_obj.attributes)
anno_attr['occluded'] = shape_obj.occluded

@ -4,7 +4,7 @@
# SPDX-License-Identifier: MIT
import logging as log
from lxml import etree as ET # NOTE: lxml has proper XPath implementation
from lxml import etree as ET # lxml has proper XPath implementation
from datumaro.components.extractor import (Transform,
Annotation, AnnotationType,
Label, Mask, Points, Polygon, PolyLine, Bbox, Caption,

@ -7,7 +7,11 @@ from collections import namedtuple
from enum import Enum
import numpy as np
import attr
from attr import attrs, attrib
from datumaro.util.image import Image
from datumaro.util.attrs_util import not_empty, default_if_none
AnnotationType = Enum('AnnotationType',
[
@ -22,59 +26,34 @@ AnnotationType = Enum('AnnotationType',
_COORDINATE_ROUNDING_DIGITS = 2
@attrs
class Annotation:
# pylint: disable=redefined-builtin
def __init__(self, id=None, type=None, attributes=None, group=None):
if id is not None:
id = int(id)
self.id = id
id = attrib(converter=int, default=0, kw_only=True)
attributes = attrib(converter=dict, factory=dict, kw_only=True)
group = attrib(converter=int, default=0, kw_only=True)
assert type in AnnotationType
self.type = type
if attributes is None:
attributes = {}
else:
attributes = dict(attributes)
self.attributes = attributes
def __attrs_post_init__(self):
assert isinstance(self.type, AnnotationType)
if group is None:
group = 0
else:
group = int(group)
self.group = group
# pylint: enable=redefined-builtin
@property
def type(self):
return self._type # must be set in subclasses
def __eq__(self, other):
if not isinstance(other, Annotation):
return False
return \
(self.id == other.id) and \
(self.type == other.type) and \
(self.attributes == other.attributes) and \
(self.group == other.group)
def wrap(item, **kwargs):
return attr.evolve(item, **kwargs)
@attrs
class Categories:
def __init__(self, attributes=None):
if attributes is None:
attributes = set()
else:
if not isinstance(attributes, set):
attributes = set(attributes)
for attr in attributes:
assert isinstance(attr, str)
self.attributes = attributes
def __eq__(self, other):
if not isinstance(other, Categories):
return False
return \
(self.attributes == other.attributes)
attributes = attrib(factory=set, validator=default_if_none(set),
kw_only=True)
@attrs
class LabelCategories(Categories):
Category = namedtuple('Category', ['name', 'parent', 'attributes'])
items = attrib(factory=list, validator=default_if_none(list))
_indices = attrib(factory=dict, init=False, eq=False)
@classmethod
def from_iterable(cls, iterable):
"""Generation of LabelCategories from iterable object
@ -102,14 +81,7 @@ class LabelCategories(Categories):
return temp_categories
def __init__(self, items=None, attributes=None):
super().__init__(attributes=attributes)
if items is None:
items = []
self.items = items
self._indices = {}
def __attrs_post_init__(self):
self._reindex()
def _reindex(self):
@ -142,49 +114,23 @@ class LabelCategories(Categories):
return index, self.items[index]
return index, None
def __eq__(self, other):
if not super().__eq__(other):
return False
return \
(self.items == other.items)
@attrs
class Label(Annotation):
# pylint: disable=redefined-builtin
def __init__(self, label=None,
id=None, attributes=None, group=None):
super().__init__(id=id, type=AnnotationType.label,
attributes=attributes, group=group)
if label is not None:
label = int(label)
self.label = label
# pylint: enable=redefined-builtin
def __eq__(self, other):
if not super().__eq__(other):
return False
return \
(self.label == other.label)
_type = AnnotationType.label
label = attrib(converter=int)
@attrs(eq=False)
class MaskCategories(Categories):
def __init__(self, colormap=None, inverse_colormap=None, attributes=None):
super().__init__(attributes=attributes)
# colormap: label id -> color
if colormap is None:
colormap = {}
self.colormap = colormap
self._inverse_colormap = inverse_colormap
colormap = attrib(factory=dict, validator=default_if_none(dict))
_inverse_colormap = attrib(default=None,
validator=attr.validators.optional(dict))
@property
def inverse_colormap(self):
from datumaro.util.mask_tools import invert_colormap
if self._inverse_colormap is None:
if self.colormap is not None:
try:
self._inverse_colormap = invert_colormap(self.colormap)
except Exception:
pass
self._inverse_colormap = invert_colormap(self.colormap)
return self._inverse_colormap
def __eq__(self, other):
@ -196,25 +142,13 @@ class MaskCategories(Categories):
return False
return True
@attrs(eq=False)
class Mask(Annotation):
# pylint: disable=redefined-builtin
def __init__(self, image=None, label=None, z_order=None,
id=None, attributes=None, group=None):
super().__init__(type=AnnotationType.mask,
id=id, attributes=attributes, group=group)
self._image = image
if label is not None:
label = int(label)
self._label = label
if z_order is None:
z_order = 0
else:
z_order = int(z_order)
self._z_order = z_order
# pylint: enable=redefined-builtin
_type = AnnotationType.mask
_image = attrib()
label = attrib(converter=attr.converters.optional(int),
default=None, kw_only=True)
z_order = attrib(converter=int, default=0, kw_only=True)
@property
def image(self):
@ -222,14 +156,6 @@ class Mask(Annotation):
return self._image()
return self._image
@property
def label(self):
return self._label
@property
def z_order(self):
return self._z_order
def as_class_mask(self, label_id=None):
if label_id is None:
label_id = self.label
@ -255,19 +181,14 @@ class Mask(Annotation):
return \
(self.label == other.label) and \
(self.z_order == other.z_order) and \
(self.image is not None and other.image is not None and \
np.array_equal(self.image, other.image))
(np.array_equal(self.image, other.image))
@attrs(eq=False)
class RleMask(Mask):
# pylint: disable=redefined-builtin
def __init__(self, rle=None, label=None, z_order=None,
id=None, attributes=None, group=None):
lazy_decode = self._lazy_decode(rle)
super().__init__(image=lazy_decode, label=label, z_order=z_order,
id=id, attributes=attributes, group=group)
self._rle = rle
# pylint: enable=redefined-builtin
rle = attrib()
_image = attrib(default=attr.Factory(
lambda self: self._lazy_decode(self.rle),
takes_self=True), init=False)
@staticmethod
def _lazy_decode(rle):
@ -276,20 +197,16 @@ class RleMask(Mask):
def get_area(self):
from pycocotools import mask as mask_utils
return mask_utils.area(self._rle)
return mask_utils.area(self.rle)
def get_bbox(self):
from pycocotools import mask as mask_utils
return mask_utils.toBbox(self._rle)
@property
def rle(self):
return self._rle
return mask_utils.toBbox(self.rle)
def __eq__(self, other):
if not isinstance(other, __class__):
return super().__eq__(other)
return self._rle == other._rle
return self.rle == other.rle
class CompiledMask:
@staticmethod
@ -375,38 +292,13 @@ def compute_iou(bbox_a, bbox_b):
return intersection / max(1.0, union)
@attrs
class _Shape(Annotation):
# pylint: disable=redefined-builtin
def __init__(self, type, points=None, label=None, z_order=None,
id=None, attributes=None, group=None):
super().__init__(id=id, type=type,
attributes=attributes, group=group)
if points is not None:
points = [round(p, _COORDINATE_ROUNDING_DIGITS) for p in points]
self._points = points
if label is not None:
label = int(label)
self._label = label
if z_order is None:
z_order = 0
else:
z_order = int(z_order)
self._z_order = z_order
# pylint: enable=redefined-builtin
@property
def points(self):
return self._points
@property
def label(self):
return self._label
@property
def z_order(self):
return self._z_order
points = attrib(converter=lambda x:
[round(p, _COORDINATE_ROUNDING_DIGITS) for p in x])
label = attrib(converter=attr.converters.optional(int),
default=None, kw_only=True)
z_order = attrib(converter=int, default=0, kw_only=True)
def get_area(self):
raise NotImplementedError()
@ -424,22 +316,9 @@ class _Shape(Annotation):
y1 = max(ys)
return [x0, y0, x1 - x0, y1 - y0]
def __eq__(self, other):
if not super().__eq__(other):
return False
return \
(np.array_equal(self.points, other.points)) and \
(self.z_order == other.z_order) and \
(self.label == other.label)
@attrs
class PolyLine(_Shape):
# pylint: disable=redefined-builtin
def __init__(self, points=None, label=None, z_order=None,
id=None, attributes=None, group=None):
super().__init__(type=AnnotationType.polyline,
points=points, label=label, z_order=z_order,
id=id, attributes=attributes, group=group)
# pylint: enable=redefined-builtin
_type = AnnotationType.polyline
def as_polygon(self):
return self.points[:]
@ -447,18 +326,14 @@ class PolyLine(_Shape):
def get_area(self):
return 0
@attrs
class Polygon(_Shape):
# pylint: disable=redefined-builtin
def __init__(self, points=None, label=None,
z_order=None, id=None, attributes=None, group=None):
if points is not None:
# keep the message on the single line to produce
# informative output
assert len(points) % 2 == 0 and 3 <= len(points) // 2, "Wrong polygon points: %s" % points
super().__init__(type=AnnotationType.polygon,
points=points, label=label, z_order=z_order,
id=id, attributes=attributes, group=group)
# pylint: enable=redefined-builtin
_type = AnnotationType.polygon
def __attrs_post_init__(self):
super().__attrs_post_init__()
# keep the message on a single line to produce informative output
assert len(self.points) % 2 == 0 and 3 <= len(self.points) // 2, "Wrong polygon points: %s" % self.points
def get_area(self):
import pycocotools.mask as mask_utils
@ -468,14 +343,16 @@ class Polygon(_Shape):
area = mask_utils.area(rle)[0]
return area
@attrs
class Bbox(_Shape):
# pylint: disable=redefined-builtin
def __init__(self, x=0, y=0, w=0, h=0, label=None, z_order=None,
id=None, attributes=None, group=None):
super().__init__(type=AnnotationType.bbox,
points=[x, y, x + w, y + h], label=label, z_order=z_order,
id=id, attributes=attributes, group=group)
# pylint: enable=redefined-builtin
_type = AnnotationType.bbox
# will be overridden by attrs, then will be overridden again by us
# attrs' method will be renamed to __attrs_init__
def __init__(self, x, y, w, h, *args, **kwargs):
kwargs.pop('points', None) # comes from wrap()
self.__attrs_init__([x, y, x + w, y + h], *args, **kwargs)
__actual_init__ = __init__ # save pointer
@property
def x(self):
@ -511,9 +388,21 @@ class Bbox(_Shape):
def iou(self, other):
return compute_iou(self.get_bbox(), other.get_bbox())
def wrap(item, **kwargs):
d = {'x': item.x, 'y': item.y, 'w': item.w, 'h': item.h}
d.update(kwargs)
return attr.evolve(item, **d)
assert not hasattr(Bbox, '__attrs_init__') # hopefully, it will be supported
setattr(Bbox, '__attrs_init__', Bbox.__init__)
setattr(Bbox, '__init__', Bbox.__actual_init__)
@attrs
class PointsCategories(Categories):
Category = namedtuple('Category', ['labels', 'joints'])
items = attrib(factory=dict, validator=default_if_none(dict))
@classmethod
def from_iterable(cls, iterable):
"""Generation of PointsCategories from iterable object
@ -539,13 +428,6 @@ class PointsCategories(Categories):
temp_categories.add(*category)
return temp_categories
def __init__(self, items=None, attributes=None):
super().__init__(attributes=attributes)
if items is None:
items = {}
self.items = items
def add(self, label_id, labels=None, joints=None):
if labels is None:
labels = []
@ -554,41 +436,30 @@ class PointsCategories(Categories):
joints = set(map(tuple, joints))
self.items[label_id] = self.Category(labels, joints)
def __eq__(self, other):
if not super().__eq__(other):
return False
return \
(self.items == other.items)
@attrs
class Points(_Shape):
Visibility = Enum('Visibility', [
('absent', 0),
('hidden', 1),
('visible', 2),
])
_type = AnnotationType.points
# pylint: disable=redefined-builtin
def __init__(self, points=None, visibility=None, label=None, z_order=None,
id=None, attributes=None, group=None):
if points is not None:
assert len(points) % 2 == 0
if visibility is not None:
assert len(visibility) == len(points) // 2
for i, v in enumerate(visibility):
if not isinstance(v, self.Visibility):
visibility[i] = self.Visibility(v)
else:
visibility = []
for _ in range(len(points) // 2):
visibility.append(self.Visibility.visible)
super().__init__(type=AnnotationType.points,
points=points, label=label, z_order=z_order,
id=id, attributes=attributes, group=group)
visibility = attrib(type=list, default=None)
@visibility.validator
def _visibility_validator(self, attribute, visibility):
if visibility is None:
visibility = [self.Visibility.visible] * (len(self.points) // 2)
else:
for i, v in enumerate(visibility):
if not isinstance(v, self.Visibility):
visibility[i] = self.Visibility(v)
assert len(visibility) == len(self.points) // 2
self.visibility = visibility
# pylint: enable=redefined-builtin
def __attrs_post_init__(self):
super().__attrs_post_init__()
assert len(self.points) % 2 == 0, self.points
def get_area(self):
return 0
@ -604,117 +475,37 @@ class Points(_Shape):
y1 = max(ys, default=0)
return [x0, y0, x1 - x0, y1 - y0]
def __eq__(self, other):
if not super().__eq__(other):
return False
return \
(self.visibility == other.visibility)
@attrs
class Caption(Annotation):
# pylint: disable=redefined-builtin
def __init__(self, caption=None,
id=None, attributes=None, group=None):
super().__init__(id=id, type=AnnotationType.caption,
attributes=attributes, group=group)
if caption is None:
caption = ''
else:
caption = str(caption)
self.caption = caption
# pylint: enable=redefined-builtin
def __eq__(self, other):
if not super().__eq__(other):
return False
return \
(self.caption == other.caption)
_type = AnnotationType.caption
caption = attrib(converter=str)
@attrs
class DatasetItem:
# pylint: disable=redefined-builtin
def __init__(self, id=None, annotations=None,
subset=None, path=None, image=None, attributes=None):
assert id is not None
self._id = str(id).replace('\\', '/')
if subset is None:
subset = ''
else:
subset = str(subset)
self._subset = subset
if path is None:
path = []
else:
path = list(path)
self._path = path
if annotations is None:
annotations = []
else:
annotations = list(annotations)
self._annotations = annotations
id = attrib(converter=lambda x: str(x).replace('\\', '/'),
type=str, validator=not_empty)
annotations = attrib(factory=list, validator=default_if_none(list))
subset = attrib(default='', validator=default_if_none(str))
path = attrib(factory=list, validator=default_if_none(list))
image = attrib(type=Image, default=None)
@image.validator
def _image_validator(self, attribute, image):
if callable(image) or isinstance(image, np.ndarray):
image = Image(data=image)
elif isinstance(image, str):
image = Image(path=image)
assert image is None or isinstance(image, Image)
self._image = image
self.image = image
if attributes is None:
attributes = {}
else:
attributes = dict(attributes)
self._attributes = attributes
# pylint: enable=redefined-builtin
@property
def id(self):
return self._id
@property
def subset(self):
return self._subset
@property
def path(self):
return self._path
@property
def annotations(self):
return self._annotations
@property
def image(self):
return self._image
attributes = attrib(factory=dict, validator=default_if_none(dict))
@property
def has_image(self):
return self._image is not None
@property
def attributes(self):
return self._attributes
def __eq__(self, other):
if not isinstance(other, __class__):
return False
return \
(self.id == other.id) and \
(self.subset == other.subset) and \
(self.path == other.path) and \
(self.annotations == other.annotations) and \
(self.image == other.image) and \
(self.attributes == other.attributes)
return self.image is not None
def wrap(item, **kwargs):
expected_args = {'id', 'annotations', 'subset',
'path', 'image', 'attributes'}
for k in expected_args:
if k not in kwargs:
kwargs[k] = getattr(item, k)
return DatasetItem(**kwargs)
return attr.evolve(item, **kwargs)
class IExtractor:
def __iter__(self):

@ -102,7 +102,7 @@ class CvatExtractor(SourceExtractor):
else:
try:
attr_value = float(attr_value)
except Exception:
except ValueError:
pass
attributes[el.attrib['name']] = attr_value
elif el.tag in cls._SUPPORTED_SHAPES:
@ -251,7 +251,7 @@ class CvatExtractor(SourceExtractor):
@classmethod
def _parse_shape_ann(cls, ann, categories):
ann_id = ann.get('id')
ann_id = ann.get('id', 0)
ann_type = ann['type']
attributes = ann.get('attributes') or {}

@ -83,7 +83,7 @@ class LabelMeExtractor(SourceExtractor):
else:
try:
value = float(value)
except Exception:
except ValueError:
pass
parsed.append((name, value))
else:

@ -504,7 +504,6 @@ class RemapLabels(Transform, CliPlugin):
return self._categories
def transform_item(self, item):
# TODO: provide non-inplace version
annotations = []
for ann in item.annotations:
if ann.type in { AnnotationType.label, AnnotationType.mask,
@ -513,9 +512,7 @@ class RemapLabels(Transform, CliPlugin):
} and ann.label is not None:
conv_label = self._map_id(ann.label)
if conv_label is not None:
ann._label = conv_label
annotations.append(ann)
annotations.append(ann.wrap(label=conv_label))
else:
annotations.append(ann)
item._annotations = annotations
return item
annotations.append(ann.wrap())
return item.wrap(annotations=annotations)

@ -0,0 +1,26 @@
# Copyright (C) 2020 Intel Corporation
#
# SPDX-License-Identifier: MIT
import attr
_NOTSET = object()
def not_empty(inst, attribute, x):
assert len(x) != 0, x
def default_if_none(conv):
def validator(inst, attribute, value):
default = attribute.default
if value is None:
if callable(default):
value = default()
elif isinstance(default, attr.Factory):
value = default.factory()
else:
value = default
elif not isinstance(value, attribute.type or conv):
value = conv(value)
setattr(inst, attribute.name, value)
return validator

@ -43,19 +43,6 @@ class TestDir(FileRemover):
super().__init__(path, is_dir=True, ignore_errors=ignore_errors)
def ann_to_str(ann):
return vars(ann)
def item_to_str(item):
return '\n'.join(
[
'%s' % vars(item)
] + [
'ann[%s]: %s' % (i, ann_to_str(a))
for i, a in enumerate(item.annotations)
]
)
def compare_categories(test, expected, actual):
test.assertEqual(
sorted(expected, key=lambda t: t.value),
@ -92,12 +79,11 @@ def compare_datasets(test, expected, actual):
for ann_a in item_a.annotations:
# We might find few corresponding items, so check them all
ann_b_matches = [x for x in item_b.annotations
if x.id == ann_a.id and \
x.type == ann_a.type and x.group == ann_a.group]
if x.type == ann_a.type]
test.assertFalse(len(ann_b_matches) == 0, 'ann id: %s' % ann_a.id)
ann_b = find(ann_b_matches, lambda x: x == ann_a)
test.assertEqual(ann_a, ann_b, 'ann: %s' % ann_to_str(ann_a))
test.assertEqual(ann_a, ann_b, 'ann %s, candidates %s' % (ann_a, ann_b_matches))
item_b.annotations.remove(ann_b) # avoid repeats
def compare_datasets_strict(test, expected, actual):
@ -115,4 +101,4 @@ def compare_datasets_strict(test, expected, actual):
for idx, (item_a, item_b) in enumerate(zip(e_subset, a_subset)):
test.assertEqual(item_a, item_b,
'%s:\n%s\nvs.\n%s\n' % \
(idx, item_to_str(item_a), item_to_str(item_b)))
(idx, item_a, item_b))

@ -1,3 +1,4 @@
attrs>=19.3.0
Cython>=0.27.3 # include before pycocotools
defusedxml>=0.6.0
GitPython>=3.0.8

@ -48,6 +48,7 @@ setuptools.setup(
],
python_requires='>=3.5',
install_requires=[
'attrs',
'defusedxml',
'GitPython',
'lxml',

@ -65,7 +65,7 @@ class DatumaroConverterTest(TestCase):
annotations=[
Caption('test'),
Label(2),
Bbox(1, 2, 3, 4, 5, id=42, group=42)
Bbox(1, 2, 3, 4, label=5, id=42, group=42)
]),
DatasetItem(id=2, subset='val',

@ -133,15 +133,15 @@ class ProjectTest(TestCase):
self.assertTrue('project1' in dataset.sources)
def test_can_batch_launch_custom_model(self):
class TestExtractor(Extractor):
def __iter__(self):
for i in range(5):
yield DatasetItem(id=i, subset='train', image=np.array([i]))
dataset = Dataset.from_iterable([
DatasetItem(id=i, subset='train', image=np.array([i]))
for i in range(5)
], categories=['label'])
class TestLauncher(Launcher):
def launch(self, inputs):
for i, inp in enumerate(inputs):
yield [ Label(attributes={'idx': i, 'data': inp.item()}) ]
yield [ Label(0, attributes={'idx': i, 'data': inp.item()}) ]
model_name = 'model'
launcher_name = 'custom_launcher'
@ -150,10 +150,9 @@ class ProjectTest(TestCase):
project.env.launchers.register(launcher_name, TestLauncher)
project.add_model(model_name, { 'launcher': launcher_name })
model = project.make_executable_model(model_name)
extractor = TestExtractor()
batch_size = 3
executor = ModelTransform(extractor, model, batch_size=batch_size)
executor = ModelTransform(dataset, model, batch_size=batch_size)
for item in executor:
self.assertEqual(1, len(item.annotations))

@ -356,86 +356,58 @@ class TransformsTest(TestCase):
])
def test_remap_labels(self):
class SrcExtractor(Extractor):
def __iter__(self):
return iter([
DatasetItem(id=1, annotations=[
# Should be remapped
Label(1),
Bbox(1, 2, 3, 4, label=2),
Mask(image=np.array([1]), label=3),
# Should be kept
Polygon([1, 1, 2, 2, 3, 4], label=4),
PolyLine([1, 3, 4, 2, 5, 6], label=None)
]),
])
def categories(self):
label_cat = LabelCategories()
label_cat.add('label0')
label_cat.add('label1')
label_cat.add('label2')
label_cat.add('label3')
label_cat.add('label4')
mask_cat = MaskCategories(
colormap=mask_tools.generate_colormap(5))
return {
AnnotationType.label: label_cat,
AnnotationType.mask: mask_cat,
}
class DstExtractor(Extractor):
def __iter__(self):
return iter([
DatasetItem(id=1, annotations=[
Label(1),
Bbox(1, 2, 3, 4, label=0),
Mask(image=np.array([1]), label=1),
Polygon([1, 1, 2, 2, 3, 4], label=2),
PolyLine([1, 3, 4, 2, 5, 6], label=None)
]),
])
def categories(self):
label_cat = LabelCategories()
label_cat.add('label0')
label_cat.add('label9')
label_cat.add('label4')
src_dataset = Dataset.from_iterable([
DatasetItem(id=1, annotations=[
# Should be remapped
Label(1),
Bbox(1, 2, 3, 4, label=2),
Mask(image=np.array([1]), label=3),
# Should be kept
Polygon([1, 1, 2, 2, 3, 4], label=4),
PolyLine([1, 3, 4, 2, 5, 6])
])
], categories={
AnnotationType.label: LabelCategories.from_iterable(
'label%s' % i for i in range(5)),
AnnotationType.mask: MaskCategories(
colormap=mask_tools.generate_colormap(5)),
})
mask_cat = MaskCategories(colormap={
k: v for k, v in mask_tools.generate_colormap(5).items()
if k in { 0, 1, 3, 4 }
})
dst_dataset = Dataset.from_iterable([
DatasetItem(id=1, annotations=[
Label(1),
Bbox(1, 2, 3, 4, label=0),
Mask(image=np.array([1]), label=1),
return {
AnnotationType.label: label_cat,
AnnotationType.mask: mask_cat,
}
Polygon([1, 1, 2, 2, 3, 4], label=2),
PolyLine([1, 3, 4, 2, 5, 6], label=None)
]),
], categories={
AnnotationType.label: LabelCategories.from_iterable(
['label0', 'label9', 'label4']),
AnnotationType.mask: MaskCategories(colormap={
k: v for k, v in mask_tools.generate_colormap(5).items()
if k in { 0, 1, 3, 4 }
})
})
actual = transforms.RemapLabels(SrcExtractor(), mapping={
actual = transforms.RemapLabels(src_dataset, mapping={
'label1': 'label9',
'label2': 'label0',
'label3': 'label9',
}, default='keep')
compare_datasets(self, DstExtractor(), actual)
compare_datasets(self, dst_dataset, actual)
def test_remap_labels_delete_unspecified(self):
source_dataset = Dataset.from_iterable([
DatasetItem(id=1, annotations=[ Label(0) ])
], categories={
AnnotationType.label: LabelCategories.from_iterable('label0'),
})
], categories=['label0'])
target_dataset = Dataset.from_iterable([
DatasetItem(id=1, annotations=[]),
], categories={
AnnotationType.label: LabelCategories(),
})
DatasetItem(id=1),
], categories=[])
actual = transforms.RemapLabels(source_dataset,
mapping={}, default='delete')

Loading…
Cancel
Save