[Datumaro] Merge with different categories (#2098)

* Add category merging

* Update error message

* Add category merging test

* update changelog

* Fix field access

* remove import

* Update CHANGELOG.md

Co-authored-by: Nikita Manovich <nikita.manovich@intel.com>
main
Maxim Zhiltsov 6 years ago committed by GitHub
parent 4dbfa3bfdf
commit ffb71fb7a2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Ability to work with data on the fly (https://github.com/opencv/cvat/pull/2007)
- Annotation in process outline color wheel (<https://github.com/opencv/cvat/pull/2084>)
- [Datumaro] CLI command for dataset equality comparison (<https://github.com/opencv/cvat/pull/1989>)
- [Datumaro] Merging of datasets with different labels (<https://github.com/opencv/cvat/pull/2098>)
### Changed
- UI models (like DEXTR) were redesigned to be more interactive (<https://github.com/opencv/cvat/pull/2054>)

@ -49,7 +49,11 @@ class Categories:
@attrs
class LabelCategories(Categories):
Category = namedtuple('Category', ['name', 'parent', 'attributes'])
@attrs(repr_ns='LabelCategories')
class Category:
name = attrib(converter=str, validator=not_empty)
parent = attrib(default='', validator=default_if_none(str))
attributes = attrib(factory=set, validator=default_if_none(set))
items = attrib(factory=list, validator=default_if_none(list))
_indices = attrib(factory=dict, init=False, eq=False)
@ -93,15 +97,6 @@ class LabelCategories(Categories):
def add(self, name: str, parent: str = None, attributes: dict = None):
assert name not in self._indices, name
if attributes is None:
attributes = set()
else:
if not isinstance(attributes, set):
attributes = set(attributes)
for attr in attributes:
assert isinstance(attr, str)
if parent is None:
parent = ''
index = len(self.items)
self.items.append(self.Category(name, parent, attributes))
@ -386,7 +381,10 @@ setattr(Bbox, '__init__', Bbox.__actual_init__)
@attrs
class PointsCategories(Categories):
Category = namedtuple('Category', ['labels', 'joints'])
@attrs(repr_ns="PointsCategories")
class Category:
labels = attrib(factory=list, validator=default_if_none(list))
joints = attrib(factory=set, validator=default_if_none(set))
items = attrib(factory=dict, validator=default_if_none(dict))
@ -396,28 +394,19 @@ class PointsCategories(Categories):
Args:
iterable ([type]): This iterable object can be:
1)simple int - will generate one Category with int as label
2)list of int - will interpreted as list of Category labels
3)list of positional argumetns - will generate Categories
with this arguments
1) list of positional argumetns - will generate Categories
with these arguments
Returns:
PointsCategories: PointsCategories object
"""
temp_categories = cls()
if isinstance(iterable, int):
iterable = [[iterable]]
for category in iterable:
if isinstance(category, int):
category = [category]
temp_categories.add(*category)
return temp_categories
def add(self, label_id, labels=None, joints=None):
if labels is None:
labels = []
if joints is None:
joints = []
joints = set(map(tuple, joints))

@ -15,7 +15,8 @@ from attr import attrib, attrs
from unittest import TestCase
from datumaro.components.cli_plugin import CliPlugin
from datumaro.components.extractor import AnnotationType, Bbox, Label
from datumaro.components.extractor import (AnnotationType, Bbox, Label,
LabelCategories, PointsCategories, MaskCategories)
from datumaro.components.project import Dataset
from datumaro.util import find, filter_dict
from datumaro.util.attrs_util import ensure_cls, default_if_none
@ -53,7 +54,8 @@ def merge_categories(sources):
for cat_type, source_cat in source.items():
if not categories[cat_type] == source_cat:
raise NotImplementedError(
"Merging different categories is not implemented yet")
"Merging of datasets with different categories is "
"only allowed in 'merge' command.")
return categories
class MergingStrategy(CliPlugin):
@ -180,7 +182,8 @@ class IntersectMerge(MergingStrategy):
_categories = attrib(init=False) # merged categories
def __call__(self, datasets):
self._categories = merge_categories(d.categories() for d in datasets)
self._categories = self._merge_categories(
[d.categories() for d in datasets])
merged = Dataset(categories=self._categories)
self._check_groups_definition()
@ -283,6 +286,126 @@ class IntersectMerge(MergingStrategy):
return matches, item_map
def _merge_label_categories(self, sources):
same = True
common = None
for src_categories in sources:
src_cat = src_categories.get(AnnotationType.label)
if common is None:
common = src_cat
elif common != src_cat:
same = False
break
if same:
return common
dst_cat = LabelCategories()
for src_id, src_categories in enumerate(sources):
src_cat = src_categories.get(AnnotationType.label)
if src_cat is None:
continue
for src_label in src_cat.items:
dst_label = dst_cat.find(src_label.name)[1]
if dst_label is not None:
if dst_label != src_label:
if src_label.parent and dst_label.parent and \
src_label.parent != dst_label.parent:
raise ValueError("Can't merge label category "
"%s (from #%s): "
"parent label conflict: %s vs. %s" % \
(src_label.name, src_id,
src_label.parent, dst_label.parent)
)
dst_label.parent = dst_label.parent or src_label.parent
dst_label.attributes |= src_label.attributes
else:
pass
else:
dst_cat.add(src_label.name,
src_label.parent, src_label.attributes)
return dst_cat
def _merge_point_categories(self, sources, label_cat):
dst_point_cat = PointsCategories()
for src_id, src_categories in enumerate(sources):
src_label_cat = src_categories.get(AnnotationType.label)
src_point_cat = src_categories.get(AnnotationType.points)
if src_label_cat is None or src_point_cat is None:
continue
for src_label_id, src_cat in src_point_cat.items.items():
src_label = src_label_cat.items[src_label_id].name
dst_label_id = label_cat.find(src_label)[0]
dst_cat = dst_point_cat.items.get(dst_label_id)
if dst_cat is not None:
if dst_cat != src_cat:
raise ValueError("Can't merge point category for label "
"%s (from #%s): %s vs. %s" % \
(src_label, src_id, src_cat, dst_cat)
)
else:
pass
else:
dst_point_cat.add(dst_label_id,
src_cat.labels, src_cat.joints)
if len(dst_point_cat.items) == 0:
return None
return dst_point_cat
def _merge_mask_categories(self, sources, label_cat):
dst_mask_cat = MaskCategories()
for src_id, src_categories in enumerate(sources):
src_label_cat = src_categories.get(AnnotationType.label)
src_mask_cat = src_categories.get(AnnotationType.mask)
if src_label_cat is None or src_mask_cat is None:
continue
for src_label_id, src_cat in src_mask_cat.colormap.items():
src_label = src_label_cat.items[src_label_id].name
dst_label_id = label_cat.find(src_label)[0]
dst_cat = dst_mask_cat.colormap.get(dst_label_id)
if dst_cat is not None:
if dst_cat != src_cat:
raise ValueError("Can't merge mask category for label "
"%s (from #%s): %s vs. %s" % \
(src_label, src_id, src_cat, dst_cat)
)
else:
pass
else:
dst_mask_cat.colormap[dst_label_id] = src_cat
if len(dst_mask_cat.colormap) == 0:
return None
return dst_mask_cat
def _merge_categories(self, sources):
dst_categories = {}
label_cat = self._merge_label_categories(sources)
if label_cat is None:
return dst_categories
dst_categories[AnnotationType.label] = label_cat
points_cat = self._merge_point_categories(sources, label_cat)
if points_cat is not None:
dst_categories[AnnotationType.points] = points_cat
mask_cat = self._merge_mask_categories(sources, label_cat)
if mask_cat is not None:
dst_categories[AnnotationType.mask] = mask_cat
return dst_categories
def _match_annotations(self, sources):
all_by_type = {}
for s in sources:
@ -473,8 +596,29 @@ class IntersectMerge(MergingStrategy):
_check_group(group_labels, group)
def _get_label_name(self, label_id):
if label_id is None:
return None
return self._categories[AnnotationType.label].items[label_id].name
def _get_label_id(self, label):
return self._categories[AnnotationType.label].find(label)[0]
def _get_src_label_name(self, ann, label_id):
if label_id is None:
return None
item_id = self._ann_map[id(ann)][1]
dataset_id = self._item_map[item_id][1]
return self._dataset_map[dataset_id][0] \
.categories()[AnnotationType.label].items[label_id].name
def _get_any_label_name(self, ann, label_id):
if label_id is None:
return None
try:
return self._get_src_label_name(ann, label_id)
except KeyError:
return self._get_label_name(label_id)
def _check_groups_definition(self):
for group in self.conf.groups:
for label, _ in group:
@ -486,16 +630,19 @@ class IntersectMerge(MergingStrategy):
self._categories[AnnotationType.label].items])
)
@attrs
@attrs(kw_only=True)
class AnnotationMatcher:
_context = attrib(type=IntersectMerge, default=None)
def match_annotations(self, sources):
raise NotImplementedError()
@attrs
class LabelMatcher(AnnotationMatcher):
@staticmethod
def distance(a, b):
return a.label == b.label
def distance(self, a, b):
a_label = self._context._get_any_label_name(a, a.label)
b_label = self._context._get_any_label_name(b, b.label)
return a_label == b_label
def match_annotations(self, sources):
return [sum(sources, [])]
@ -507,6 +654,7 @@ class _ShapeMatcher(AnnotationMatcher):
def match_annotations(self, sources):
distance = self.distance
label_matcher = self.label_matcher
pairwise_dist = self.pairwise_dist
cluster_dist = self.cluster_dist
@ -537,9 +685,10 @@ class _ShapeMatcher(AnnotationMatcher):
for a_idx, src_a in enumerate(sources):
for src_b in sources[a_idx+1 :]:
matches, _, _, _ = match_segments(src_a, src_b,
dist_thresh=pairwise_dist, distance=distance)
for m in matches:
adjacent[id(m[0])].append(id(m[1]))
dist_thresh=pairwise_dist,
distance=distance, label_matcher=label_matcher)
for a, b in matches:
adjacent[id(a)].append(id(b))
# join all segments into matching clusters
clusters = []
@ -573,6 +722,11 @@ class _ShapeMatcher(AnnotationMatcher):
def distance(a, b):
return segment_iou(a, b)
def label_matcher(self, a, b):
a_label = self._context._get_any_label_name(a, a.label)
b_label = self._context._get_any_label_name(b, b.label)
return a_label == b_label
@attrs
class BboxMatcher(_ShapeMatcher):
pass
@ -626,8 +780,6 @@ class CaptionsMatcher(AnnotationMatcher):
@attrs(kw_only=True)
class AnnotationMerger:
_context = attrib(type=IntersectMerge, default=None)
def merge_clusters(self, clusters):
raise NotImplementedError()
@ -641,20 +793,22 @@ class LabelMerger(AnnotationMerger, LabelMatcher):
return []
votes = {} # label -> score
for label_ann in clusters[0]:
votes[label_ann.label] = 1 + votes.get(label_ann.label, 0)
for ann in clusters[0]:
label = self._context._get_src_label_name(ann, ann.label)
votes[label] = 1 + votes.get(label, 0)
merged = []
for label, count in votes.items():
if count < self.quorum:
sources = set(self.get_ann_source(id(a)) for a in clusters[0]
if label not in [l.label for l in a])
if label not in [self._context._get_src_label_name(l, l.label)
for l in a])
sources = [self._context._dataset_map[s][1] for s in sources]
self._context.add_item_error(FailedLabelVotingError,
sources, votes)
continue
merged.append(Label(label, attributes={
merged.append(Label(self._context._get_label_id(label), attributes={
'score': count / len(self._context._dataset_map)
}))
@ -682,14 +836,17 @@ class _ShapeMerger(AnnotationMerger, _ShapeMatcher):
def find_cluster_label(self, cluster):
votes = {}
for s in cluster:
state = votes.setdefault(s.label, [0, 0])
label = self._context._get_src_label_name(s, s.label)
state = votes.setdefault(label, [0, 0])
state[0] += s.attributes.get('score', 1.0)
state[1] += 1
label, (score, count) = max(votes.items(), key=lambda e: e[1][0])
if count < self.quorum:
self._context.add_item_error(FailedLabelVotingError, votes)
score = score / count if count else None
label = None
score = score / len(self._context._dataset_map)
label = self._context._get_label_id(label)
return label, score
@staticmethod
@ -729,11 +886,10 @@ class LineMerger(_ShapeMerger, LineMatcher):
class CaptionsMerger(AnnotationMerger, CaptionsMatcher):
pass
def match_segments(a_segms, b_segms, distance='iou', dist_thresh=1.0):
if distance == 'iou':
distance = segment_iou
else:
assert callable(distance)
def match_segments(a_segms, b_segms, distance=segment_iou, dist_thresh=1.0,
label_matcher=lambda a, b: a.label == b.label):
assert callable(distance), distance
assert callable(label_matcher), label_matcher
a_segms.sort(key=lambda ann: 1 - ann.attributes.get('score', 1))
b_segms.sort(key=lambda ann: 1 - ann.attributes.get('score', 1))
@ -753,13 +909,16 @@ def match_segments(a_segms, b_segms, distance='iou', dist_thresh=1.0):
for a_idx, a_segm in enumerate(a_segms):
if len(b_segms) == 0:
break
matched_b = a_matches[a_idx]
max_dist = max(distances[a_idx, matched_b], dist_thresh)
for b_idx, b_segm in enumerate(b_segms):
matched_b = -1
max_dist = -1
b_indices = np.argsort([not label_matcher(a_segm, b_segm)
for b_segm in b_segms],
kind='stable') # prioritize those with same label, keep score order
for b_idx in b_indices:
if 0 <= b_matches[b_idx]: # assign a_segm with max conf
continue
d = distances[a_idx, b_idx]
if d < max_dist:
if d < dist_thresh or d <= max_dist:
continue
max_dist = d
matched_b = b_idx
@ -771,7 +930,7 @@ def match_segments(a_segms, b_segms, distance='iou', dist_thresh=1.0):
b_segm = b_segms[matched_b]
if a_segm.label == b_segm.label:
if label_matcher(a_segm, b_segm):
matches.append( (a_segm, b_segm) )
else:
mispred.append( (a_segm, b_segm) )

@ -118,7 +118,7 @@ def segment_iou(a, b):
if ann.type == AnnotationType.polygon:
return mask_utils.frPyObjects([ann.points], h, w)
elif isinstance(ann, RleMask):
return [ann._rle]
return [ann.rle]
elif ann.type == AnnotationType.mask:
return mask_utils.frPyObjects([mask_to_rle(ann.image)], h, w)
else:

@ -7,6 +7,6 @@ matplotlib>=3.3.1
opencv-python-headless>=4.1.0.25
Pillow>=6.1.0
pycocotools>=2.0.0
PyYAML>=5.1.1
PyYAML>=5.3.1
scikit-image>=0.15.0
tensorboardX>=1.8

@ -3,7 +3,8 @@ from unittest import TestCase
import numpy as np
from datumaro.components.extractor import (Bbox, Caption, DatasetItem,
Extractor, Label, Mask, Points, Polygon, PolyLine)
Extractor, Label, Mask, Points, Polygon, PolyLine,
LabelCategories, PointsCategories, MaskCategories, AnnotationType)
from datumaro.components.operations import (FailedAttrVotingError,
IntersectMerge, NoMatchingAnnError, NoMatchingItemError, WrongGroupError,
compute_ann_statistics, mean_std)
@ -198,7 +199,7 @@ class TestMultimerge(TestCase):
Bbox(1, 2, 3, 4, label=1),
# common
Mask(label=3, z_order=2, image=np.array([
Mask(label=2, z_order=2, image=np.array([
[0, 0, 0, 0],
[0, 0, 0, 0],
[1, 1, 1, 0],
@ -218,7 +219,7 @@ class TestMultimerge(TestCase):
source1 = Dataset.from_iterable([
DatasetItem(1, annotations=[
# common
Mask(label=3, image=np.array([
Mask(label=2, image=np.array([
[0, 0, 0, 0],
[0, 1, 1, 1],
[0, 1, 1, 1],
@ -238,7 +239,7 @@ class TestMultimerge(TestCase):
source2 = Dataset.from_iterable([
DatasetItem(1, annotations=[
# common
Mask(label=3, z_order=3, image=np.array([
Mask(label=2, z_order=3, image=np.array([
[0, 0, 1, 1],
[0, 1, 1, 1],
[1, 1, 1, 1],
@ -261,7 +262,7 @@ class TestMultimerge(TestCase):
# common
# nearest to mean bbox
Mask(label=3, z_order=3, image=np.array([
Mask(label=2, z_order=3, image=np.array([
[0, 0, 0, 0],
[0, 1, 1, 1],
[0, 1, 1, 1],
@ -365,3 +366,86 @@ class TestMultimerge(TestCase):
self.assertEqual(3, len([e for e in merger.errors
if isinstance(e, WrongGroupError)]), merger.errors
)
def test_can_merge_classes(self):
source0 = Dataset.from_iterable([
DatasetItem(1, annotations=[
Label(0),
Label(1),
Bbox(0, 0, 1, 1, label=1),
]),
], categories=['a', 'b'])
source1 = Dataset.from_iterable([
DatasetItem(1, annotations=[
Label(0),
Label(1),
Bbox(0, 0, 1, 1, label=0),
Bbox(0, 0, 1, 1, label=1),
]),
], categories=['b', 'c'])
expected = Dataset.from_iterable([
DatasetItem(1, annotations=[
Label(0),
Label(1),
Label(2),
Bbox(0, 0, 1, 1, label=1),
Bbox(0, 0, 1, 1, label=2),
]),
], categories=['a', 'b', 'c'])
merger = IntersectMerge()
merged = merger([source0, source1])
compare_datasets(self, expected, merged, ignored_attrs={'score'})
def test_can_merge_categories(self):
source0 = Dataset.from_iterable([
DatasetItem(1, annotations=[ Label(0), ]),
], categories={
AnnotationType.label: LabelCategories.from_iterable(['a', 'b']),
AnnotationType.points: PointsCategories.from_iterable([
(0, ['l0', 'l1']),
(1, ['l2', 'l3']),
]),
AnnotationType.mask: MaskCategories({
0: (0, 1, 2),
1: (1, 2, 3),
}),
})
source1 = Dataset.from_iterable([
DatasetItem(1, annotations=[ Label(0), ]),
], categories={
AnnotationType.label: LabelCategories.from_iterable(['c', 'b']),
AnnotationType.points: PointsCategories.from_iterable([
(0, []),
(1, ['l2', 'l3']),
]),
AnnotationType.mask: MaskCategories({
0: (0, 2, 4),
1: (1, 2, 3),
}),
})
expected = Dataset.from_iterable([
DatasetItem(1, annotations=[ Label(0), Label(2), ]),
], categories={
AnnotationType.label: LabelCategories.from_iterable(['a', 'b', 'c']),
AnnotationType.points: PointsCategories.from_iterable([
(0, ['l0', 'l1']),
(1, ['l2', 'l3']),
(2, []),
]),
AnnotationType.mask: MaskCategories({
0: (0, 1, 2),
1: (1, 2, 3),
2: (0, 2, 4),
}),
})
merger = IntersectMerge()
merged = merger([source0, source1])
compare_datasets(self, expected, merged, ignored_attrs={'score'})
Loading…
Cancel
Save