[Datumaro] Merge with different categories (#2098)

* Add category merging

* Update error message

* Add category merging test

* update changelog

* Fix field access

* remove import

* Update CHANGELOG.md

Co-authored-by: Nikita Manovich <nikita.manovich@intel.com>
main
Maxim Zhiltsov 6 years ago committed by GitHub
parent 4dbfa3bfdf
commit ffb71fb7a2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Ability to work with data on the fly (https://github.com/opencv/cvat/pull/2007) - Ability to work with data on the fly (https://github.com/opencv/cvat/pull/2007)
- Annotation in process outline color wheel (<https://github.com/opencv/cvat/pull/2084>) - Annotation in process outline color wheel (<https://github.com/opencv/cvat/pull/2084>)
- [Datumaro] CLI command for dataset equality comparison (<https://github.com/opencv/cvat/pull/1989>) - [Datumaro] CLI command for dataset equality comparison (<https://github.com/opencv/cvat/pull/1989>)
- [Datumaro] Merging of datasets with different labels (<https://github.com/opencv/cvat/pull/2098>)
### Changed ### Changed
- UI models (like DEXTR) were redesigned to be more interactive (<https://github.com/opencv/cvat/pull/2054>) - UI models (like DEXTR) were redesigned to be more interactive (<https://github.com/opencv/cvat/pull/2054>)

@ -49,7 +49,11 @@ class Categories:
@attrs @attrs
class LabelCategories(Categories): class LabelCategories(Categories):
Category = namedtuple('Category', ['name', 'parent', 'attributes']) @attrs(repr_ns='LabelCategories')
class Category:
name = attrib(converter=str, validator=not_empty)
parent = attrib(default='', validator=default_if_none(str))
attributes = attrib(factory=set, validator=default_if_none(set))
items = attrib(factory=list, validator=default_if_none(list)) items = attrib(factory=list, validator=default_if_none(list))
_indices = attrib(factory=dict, init=False, eq=False) _indices = attrib(factory=dict, init=False, eq=False)
@ -93,15 +97,6 @@ class LabelCategories(Categories):
def add(self, name: str, parent: str = None, attributes: dict = None): def add(self, name: str, parent: str = None, attributes: dict = None):
assert name not in self._indices, name assert name not in self._indices, name
if attributes is None:
attributes = set()
else:
if not isinstance(attributes, set):
attributes = set(attributes)
for attr in attributes:
assert isinstance(attr, str)
if parent is None:
parent = ''
index = len(self.items) index = len(self.items)
self.items.append(self.Category(name, parent, attributes)) self.items.append(self.Category(name, parent, attributes))
@ -386,7 +381,10 @@ setattr(Bbox, '__init__', Bbox.__actual_init__)
@attrs @attrs
class PointsCategories(Categories): class PointsCategories(Categories):
Category = namedtuple('Category', ['labels', 'joints']) @attrs(repr_ns="PointsCategories")
class Category:
labels = attrib(factory=list, validator=default_if_none(list))
joints = attrib(factory=set, validator=default_if_none(set))
items = attrib(factory=dict, validator=default_if_none(dict)) items = attrib(factory=dict, validator=default_if_none(dict))
@ -396,28 +394,19 @@ class PointsCategories(Categories):
Args: Args:
iterable ([type]): This iterable object can be: iterable ([type]): This iterable object can be:
1)simple int - will generate one Category with int as label 1) list of positional arguments - will generate Categories
2)list of int - will interpreted as list of Category labels with these arguments
3)list of positional argumetns - will generate Categories
with this arguments
Returns: Returns:
PointsCategories: PointsCategories object PointsCategories: PointsCategories object
""" """
temp_categories = cls() temp_categories = cls()
if isinstance(iterable, int):
iterable = [[iterable]]
for category in iterable: for category in iterable:
if isinstance(category, int):
category = [category]
temp_categories.add(*category) temp_categories.add(*category)
return temp_categories return temp_categories
def add(self, label_id, labels=None, joints=None): def add(self, label_id, labels=None, joints=None):
if labels is None:
labels = []
if joints is None: if joints is None:
joints = [] joints = []
joints = set(map(tuple, joints)) joints = set(map(tuple, joints))

@ -15,7 +15,8 @@ from attr import attrib, attrs
from unittest import TestCase from unittest import TestCase
from datumaro.components.cli_plugin import CliPlugin from datumaro.components.cli_plugin import CliPlugin
from datumaro.components.extractor import AnnotationType, Bbox, Label from datumaro.components.extractor import (AnnotationType, Bbox, Label,
LabelCategories, PointsCategories, MaskCategories)
from datumaro.components.project import Dataset from datumaro.components.project import Dataset
from datumaro.util import find, filter_dict from datumaro.util import find, filter_dict
from datumaro.util.attrs_util import ensure_cls, default_if_none from datumaro.util.attrs_util import ensure_cls, default_if_none
@ -53,7 +54,8 @@ def merge_categories(sources):
for cat_type, source_cat in source.items(): for cat_type, source_cat in source.items():
if not categories[cat_type] == source_cat: if not categories[cat_type] == source_cat:
raise NotImplementedError( raise NotImplementedError(
"Merging different categories is not implemented yet") "Merging of datasets with different categories is "
"only allowed in 'merge' command.")
return categories return categories
class MergingStrategy(CliPlugin): class MergingStrategy(CliPlugin):
@ -180,7 +182,8 @@ class IntersectMerge(MergingStrategy):
_categories = attrib(init=False) # merged categories _categories = attrib(init=False) # merged categories
def __call__(self, datasets): def __call__(self, datasets):
self._categories = merge_categories(d.categories() for d in datasets) self._categories = self._merge_categories(
[d.categories() for d in datasets])
merged = Dataset(categories=self._categories) merged = Dataset(categories=self._categories)
self._check_groups_definition() self._check_groups_definition()
@ -283,6 +286,126 @@ class IntersectMerge(MergingStrategy):
return matches, item_map return matches, item_map
def _merge_label_categories(self, sources):
    """
    Builds label categories covering the labels of all the sources.

    Args:
        sources: a sequence of per-dataset category mappings, each queried
            with `.get(AnnotationType.label)`. The sequence is traversed
            twice, so it must not be a one-shot iterator.

    Returns:
        The label categories object shared by the sources when they all
        compare equal (this can be None), otherwise a new LabelCategories
        holding the union of the labels in first-seen order.

    Raises:
        ValueError: when a label with the same name has different
            non-empty parents in two sources.
    """
    # Fast path: nothing to merge if every source has the same categories
    same = True
    common = None
    for src_categories in sources:
        src_cat = src_categories.get(AnnotationType.label)
        if common is None:
            common = src_cat
        elif common != src_cat:
            same = False
            break

    if same:
        return common

    dst_cat = LabelCategories()
    for src_id, src_categories in enumerate(sources):
        src_cat = src_categories.get(AnnotationType.label)
        if src_cat is None:
            continue

        for src_label in src_cat.items:
            dst_label = dst_cat.find(src_label.name)[1]
            if dst_label is None:
                # Store a copy of the attribute set. Passing the source set
                # itself would make the in-place union below modify
                # the categories of the source dataset.
                dst_cat.add(src_label.name,
                    src_label.parent, set(src_label.attributes))
                continue

            if dst_label != src_label:
                if src_label.parent and dst_label.parent and \
                        src_label.parent != dst_label.parent:
                    raise ValueError("Can't merge label category "
                        "%s (from #%s): "
                        "parent label conflict: %s vs. %s" % \
                        (src_label.name, src_id,
                            src_label.parent, dst_label.parent)
                    )
                # An empty parent is filled in from whichever source has it
                dst_label.parent = dst_label.parent or src_label.parent
                dst_label.attributes |= src_label.attributes

    return dst_cat
def _merge_point_categories(self, sources, label_cat):
    """
    Collects point categories of the sources, re-keyed by merged label ids.

    Args:
        sources: a sequence of per-dataset category mappings.
        label_cat: the merged label categories, used to translate
            source label names to destination label ids.

    Returns:
        A PointsCategories object, or None when no source contributed
        any point category.

    Raises:
        ValueError: when two sources define different point categories
            for the same label.
    """
    merged = PointsCategories()

    for source_idx, categories in enumerate(sources):
        labels = categories.get(AnnotationType.label)
        points = categories.get(AnnotationType.points)
        if labels is None or points is None:
            # point categories can't be mapped without label names
            continue

        for label_idx, point_cat in points.items.items():
            label_name = labels.items[label_idx].name
            merged_idx = label_cat.find(label_name)[0]

            existing = merged.items.get(merged_idx)
            if existing is None:
                merged.add(merged_idx, point_cat.labels, point_cat.joints)
            elif existing != point_cat:
                raise ValueError("Can't merge point category for label "
                    "%s (from #%s): %s vs. %s" % \
                    (label_name, source_idx, point_cat, existing)
                )

    return merged if merged.items else None
def _merge_mask_categories(self, sources, label_cat):
    """
    Collects mask colormaps of the sources, re-keyed by merged label ids.

    Args:
        sources: a sequence of per-dataset category mappings.
        label_cat: the merged label categories, used to translate
            source label names to destination label ids.

    Returns:
        A MaskCategories object, or None when no source contributed
        any colormap entry.

    Raises:
        ValueError: when two sources map the same label to different
            colormap values.
    """
    merged = MaskCategories()

    for source_idx, categories in enumerate(sources):
        labels = categories.get(AnnotationType.label)
        masks = categories.get(AnnotationType.mask)
        if labels is None or masks is None:
            # colormap entries can't be mapped without label names
            continue

        for label_idx, color in masks.colormap.items():
            label_name = labels.items[label_idx].name
            merged_idx = label_cat.find(label_name)[0]

            existing = merged.colormap.get(merged_idx)
            if existing is None:
                merged.colormap[merged_idx] = color
            elif existing != color:
                raise ValueError("Can't merge mask category for label "
                    "%s (from #%s): %s vs. %s" % \
                    (label_name, source_idx, color, existing)
                )

    return merged if merged.colormap else None
def _merge_categories(self, sources):
    """
    Merges label, points and mask categories of the sources.

    Returns:
        A dict AnnotationType -> categories. It is empty when there are
        no label categories; points and mask entries are present only
        when the corresponding merge produced something.
    """
    merged = {}

    label_cat = self._merge_label_categories(sources)
    if label_cat is None:
        return merged
    merged[AnnotationType.label] = label_cat

    # These categories are keyed by label id, so they are merged
    # against the already merged labels
    dependent = (
        (AnnotationType.points, self._merge_point_categories),
        (AnnotationType.mask, self._merge_mask_categories),
    )
    for ann_type, merge in dependent:
        cat = merge(sources, label_cat)
        if cat is not None:
            merged[ann_type] = cat

    return merged
def _match_annotations(self, sources): def _match_annotations(self, sources):
all_by_type = {} all_by_type = {}
for s in sources: for s in sources:
@ -473,8 +596,29 @@ class IntersectMerge(MergingStrategy):
_check_group(group_labels, group) _check_group(group_labels, group)
def _get_label_name(self, label_id): def _get_label_name(self, label_id):
if label_id is None:
return None
return self._categories[AnnotationType.label].items[label_id].name return self._categories[AnnotationType.label].items[label_id].name
def _get_label_id(self, label):
    """Returns the first element of the merged label categories'
    find() result for the given label name."""
    label_cat = self._categories[AnnotationType.label]
    found = label_cat.find(label)
    return found[0]
def _get_src_label_name(self, ann, label_id):
    """
    Returns the name of a label id as defined in the dataset
    the annotation is registered for, or None for a None label id.

    The dataset is found through the annotation, item and dataset maps.
    """
    if label_id is None:
        return None

    item_id = self._ann_map[id(ann)][1]
    dataset_id = self._item_map[item_id][1]
    dataset = self._dataset_map[dataset_id][0]
    label_cat = dataset.categories()[AnnotationType.label]
    return label_cat.items[label_id].name
def _get_any_label_name(self, ann, label_id):
if label_id is None:
return None
try:
return self._get_src_label_name(ann, label_id)
except KeyError:
return self._get_label_name(label_id)
def _check_groups_definition(self): def _check_groups_definition(self):
for group in self.conf.groups: for group in self.conf.groups:
for label, _ in group: for label, _ in group:
@ -486,16 +630,19 @@ class IntersectMerge(MergingStrategy):
self._categories[AnnotationType.label].items]) self._categories[AnnotationType.label].items])
) )
@attrs @attrs(kw_only=True)
class AnnotationMatcher: class AnnotationMatcher:
_context = attrib(type=IntersectMerge, default=None)
def match_annotations(self, sources): def match_annotations(self, sources):
raise NotImplementedError() raise NotImplementedError()
@attrs @attrs
class LabelMatcher(AnnotationMatcher): class LabelMatcher(AnnotationMatcher):
@staticmethod def distance(self, a, b):
def distance(a, b): a_label = self._context._get_any_label_name(a, a.label)
return a.label == b.label b_label = self._context._get_any_label_name(b, b.label)
return a_label == b_label
def match_annotations(self, sources): def match_annotations(self, sources):
return [sum(sources, [])] return [sum(sources, [])]
@ -507,6 +654,7 @@ class _ShapeMatcher(AnnotationMatcher):
def match_annotations(self, sources): def match_annotations(self, sources):
distance = self.distance distance = self.distance
label_matcher = self.label_matcher
pairwise_dist = self.pairwise_dist pairwise_dist = self.pairwise_dist
cluster_dist = self.cluster_dist cluster_dist = self.cluster_dist
@ -537,9 +685,10 @@ class _ShapeMatcher(AnnotationMatcher):
for a_idx, src_a in enumerate(sources): for a_idx, src_a in enumerate(sources):
for src_b in sources[a_idx+1 :]: for src_b in sources[a_idx+1 :]:
matches, _, _, _ = match_segments(src_a, src_b, matches, _, _, _ = match_segments(src_a, src_b,
dist_thresh=pairwise_dist, distance=distance) dist_thresh=pairwise_dist,
for m in matches: distance=distance, label_matcher=label_matcher)
adjacent[id(m[0])].append(id(m[1])) for a, b in matches:
adjacent[id(a)].append(id(b))
# join all segments into matching clusters # join all segments into matching clusters
clusters = [] clusters = []
@ -573,6 +722,11 @@ class _ShapeMatcher(AnnotationMatcher):
def distance(a, b): def distance(a, b):
return segment_iou(a, b) return segment_iou(a, b)
def label_matcher(self, a, b):
    """Tells whether two annotations have the same label *name*.

    Names are resolved through the merge context rather than comparing
    the raw `label` values of the annotations.
    """
    resolve = self._context._get_any_label_name
    return resolve(a, a.label) == resolve(b, b.label)
@attrs @attrs
class BboxMatcher(_ShapeMatcher): class BboxMatcher(_ShapeMatcher):
pass pass
@ -626,8 +780,6 @@ class CaptionsMatcher(AnnotationMatcher):
@attrs(kw_only=True) @attrs(kw_only=True)
class AnnotationMerger: class AnnotationMerger:
_context = attrib(type=IntersectMerge, default=None)
def merge_clusters(self, clusters): def merge_clusters(self, clusters):
raise NotImplementedError() raise NotImplementedError()
@ -641,20 +793,22 @@ class LabelMerger(AnnotationMerger, LabelMatcher):
return [] return []
votes = {} # label -> score votes = {} # label -> score
for label_ann in clusters[0]: for ann in clusters[0]:
votes[label_ann.label] = 1 + votes.get(label_ann.label, 0) label = self._context._get_src_label_name(ann, ann.label)
votes[label] = 1 + votes.get(label, 0)
merged = [] merged = []
for label, count in votes.items(): for label, count in votes.items():
if count < self.quorum: if count < self.quorum:
sources = set(self.get_ann_source(id(a)) for a in clusters[0] sources = set(self.get_ann_source(id(a)) for a in clusters[0]
if label not in [l.label for l in a]) if label not in [self._context._get_src_label_name(l, l.label)
for l in a])
sources = [self._context._dataset_map[s][1] for s in sources] sources = [self._context._dataset_map[s][1] for s in sources]
self._context.add_item_error(FailedLabelVotingError, self._context.add_item_error(FailedLabelVotingError,
sources, votes) sources, votes)
continue continue
merged.append(Label(label, attributes={ merged.append(Label(self._context._get_label_id(label), attributes={
'score': count / len(self._context._dataset_map) 'score': count / len(self._context._dataset_map)
})) }))
@ -682,14 +836,17 @@ class _ShapeMerger(AnnotationMerger, _ShapeMatcher):
def find_cluster_label(self, cluster): def find_cluster_label(self, cluster):
votes = {} votes = {}
for s in cluster: for s in cluster:
state = votes.setdefault(s.label, [0, 0]) label = self._context._get_src_label_name(s, s.label)
state = votes.setdefault(label, [0, 0])
state[0] += s.attributes.get('score', 1.0) state[0] += s.attributes.get('score', 1.0)
state[1] += 1 state[1] += 1
label, (score, count) = max(votes.items(), key=lambda e: e[1][0]) label, (score, count) = max(votes.items(), key=lambda e: e[1][0])
if count < self.quorum: if count < self.quorum:
self._context.add_item_error(FailedLabelVotingError, votes) self._context.add_item_error(FailedLabelVotingError, votes)
score = score / count if count else None label = None
score = score / len(self._context._dataset_map)
label = self._context._get_label_id(label)
return label, score return label, score
@staticmethod @staticmethod
@ -729,11 +886,10 @@ class LineMerger(_ShapeMerger, LineMatcher):
class CaptionsMerger(AnnotationMerger, CaptionsMatcher): class CaptionsMerger(AnnotationMerger, CaptionsMatcher):
pass pass
def match_segments(a_segms, b_segms, distance='iou', dist_thresh=1.0): def match_segments(a_segms, b_segms, distance=segment_iou, dist_thresh=1.0,
if distance == 'iou': label_matcher=lambda a, b: a.label == b.label):
distance = segment_iou assert callable(distance), distance
else: assert callable(label_matcher), label_matcher
assert callable(distance)
a_segms.sort(key=lambda ann: 1 - ann.attributes.get('score', 1)) a_segms.sort(key=lambda ann: 1 - ann.attributes.get('score', 1))
b_segms.sort(key=lambda ann: 1 - ann.attributes.get('score', 1)) b_segms.sort(key=lambda ann: 1 - ann.attributes.get('score', 1))
@ -753,13 +909,16 @@ def match_segments(a_segms, b_segms, distance='iou', dist_thresh=1.0):
for a_idx, a_segm in enumerate(a_segms): for a_idx, a_segm in enumerate(a_segms):
if len(b_segms) == 0: if len(b_segms) == 0:
break break
matched_b = a_matches[a_idx] matched_b = -1
max_dist = max(distances[a_idx, matched_b], dist_thresh) max_dist = -1
for b_idx, b_segm in enumerate(b_segms): b_indices = np.argsort([not label_matcher(a_segm, b_segm)
for b_segm in b_segms],
kind='stable') # prioritize those with same label, keep score order
for b_idx in b_indices:
if 0 <= b_matches[b_idx]: # assign a_segm with max conf if 0 <= b_matches[b_idx]: # assign a_segm with max conf
continue continue
d = distances[a_idx, b_idx] d = distances[a_idx, b_idx]
if d < max_dist: if d < dist_thresh or d <= max_dist:
continue continue
max_dist = d max_dist = d
matched_b = b_idx matched_b = b_idx
@ -771,7 +930,7 @@ def match_segments(a_segms, b_segms, distance='iou', dist_thresh=1.0):
b_segm = b_segms[matched_b] b_segm = b_segms[matched_b]
if a_segm.label == b_segm.label: if label_matcher(a_segm, b_segm):
matches.append( (a_segm, b_segm) ) matches.append( (a_segm, b_segm) )
else: else:
mispred.append( (a_segm, b_segm) ) mispred.append( (a_segm, b_segm) )

@ -118,7 +118,7 @@ def segment_iou(a, b):
if ann.type == AnnotationType.polygon: if ann.type == AnnotationType.polygon:
return mask_utils.frPyObjects([ann.points], h, w) return mask_utils.frPyObjects([ann.points], h, w)
elif isinstance(ann, RleMask): elif isinstance(ann, RleMask):
return [ann._rle] return [ann.rle]
elif ann.type == AnnotationType.mask: elif ann.type == AnnotationType.mask:
return mask_utils.frPyObjects([mask_to_rle(ann.image)], h, w) return mask_utils.frPyObjects([mask_to_rle(ann.image)], h, w)
else: else:

@ -7,6 +7,6 @@ matplotlib>=3.3.1
opencv-python-headless>=4.1.0.25 opencv-python-headless>=4.1.0.25
Pillow>=6.1.0 Pillow>=6.1.0
pycocotools>=2.0.0 pycocotools>=2.0.0
PyYAML>=5.1.1 PyYAML>=5.3.1
scikit-image>=0.15.0 scikit-image>=0.15.0
tensorboardX>=1.8 tensorboardX>=1.8

@ -3,7 +3,8 @@ from unittest import TestCase
import numpy as np import numpy as np
from datumaro.components.extractor import (Bbox, Caption, DatasetItem, from datumaro.components.extractor import (Bbox, Caption, DatasetItem,
Extractor, Label, Mask, Points, Polygon, PolyLine) Extractor, Label, Mask, Points, Polygon, PolyLine,
LabelCategories, PointsCategories, MaskCategories, AnnotationType)
from datumaro.components.operations import (FailedAttrVotingError, from datumaro.components.operations import (FailedAttrVotingError,
IntersectMerge, NoMatchingAnnError, NoMatchingItemError, WrongGroupError, IntersectMerge, NoMatchingAnnError, NoMatchingItemError, WrongGroupError,
compute_ann_statistics, mean_std) compute_ann_statistics, mean_std)
@ -198,7 +199,7 @@ class TestMultimerge(TestCase):
Bbox(1, 2, 3, 4, label=1), Bbox(1, 2, 3, 4, label=1),
# common # common
Mask(label=3, z_order=2, image=np.array([ Mask(label=2, z_order=2, image=np.array([
[0, 0, 0, 0], [0, 0, 0, 0],
[0, 0, 0, 0], [0, 0, 0, 0],
[1, 1, 1, 0], [1, 1, 1, 0],
@ -218,7 +219,7 @@ class TestMultimerge(TestCase):
source1 = Dataset.from_iterable([ source1 = Dataset.from_iterable([
DatasetItem(1, annotations=[ DatasetItem(1, annotations=[
# common # common
Mask(label=3, image=np.array([ Mask(label=2, image=np.array([
[0, 0, 0, 0], [0, 0, 0, 0],
[0, 1, 1, 1], [0, 1, 1, 1],
[0, 1, 1, 1], [0, 1, 1, 1],
@ -238,7 +239,7 @@ class TestMultimerge(TestCase):
source2 = Dataset.from_iterable([ source2 = Dataset.from_iterable([
DatasetItem(1, annotations=[ DatasetItem(1, annotations=[
# common # common
Mask(label=3, z_order=3, image=np.array([ Mask(label=2, z_order=3, image=np.array([
[0, 0, 1, 1], [0, 0, 1, 1],
[0, 1, 1, 1], [0, 1, 1, 1],
[1, 1, 1, 1], [1, 1, 1, 1],
@ -261,7 +262,7 @@ class TestMultimerge(TestCase):
# common # common
# nearest to mean bbox # nearest to mean bbox
Mask(label=3, z_order=3, image=np.array([ Mask(label=2, z_order=3, image=np.array([
[0, 0, 0, 0], [0, 0, 0, 0],
[0, 1, 1, 1], [0, 1, 1, 1],
[0, 1, 1, 1], [0, 1, 1, 1],
@ -365,3 +366,86 @@ class TestMultimerge(TestCase):
self.assertEqual(3, len([e for e in merger.errors self.assertEqual(3, len([e for e in merger.errors
if isinstance(e, WrongGroupError)]), merger.errors if isinstance(e, WrongGroupError)]), merger.errors
) )
def test_can_merge_classes(self):
    """Datasets with overlapping label lists are merged by label name."""
    # labels: a=0, b=1
    first = Dataset.from_iterable([
        DatasetItem(1, annotations=[
            Label(0),
            Label(1),
            Bbox(0, 0, 1, 1, label=1),
        ]),
    ], categories=['a', 'b'])

    # labels: b=0, c=1
    second = Dataset.from_iterable([
        DatasetItem(1, annotations=[
            Label(0),
            Label(1),
            Bbox(0, 0, 1, 1, label=0),
            Bbox(0, 0, 1, 1, label=1),
        ]),
    ], categories=['b', 'c'])

    # labels: a=0, b=1, c=2
    expected = Dataset.from_iterable([
        DatasetItem(1, annotations=[
            Label(0),
            Label(1),
            Label(2),
            Bbox(0, 0, 1, 1, label=1),
            Bbox(0, 0, 1, 1, label=2),
        ]),
    ], categories=['a', 'b', 'c'])

    merge = IntersectMerge()
    actual = merge([first, second])

    compare_datasets(self, expected, actual, ignored_attrs={'score'})
def test_can_merge_categories(self):
    """Label, points and mask categories are merged together,
    with points and mask entries re-keyed to the merged label ids."""
    # labels: a=0, b=1
    first = Dataset.from_iterable([
        DatasetItem(1, annotations=[ Label(0), ]),
    ], categories={
        AnnotationType.label: LabelCategories.from_iterable(['a', 'b']),
        AnnotationType.points: PointsCategories.from_iterable([
            (0, ['l0', 'l1']),
            (1, ['l2', 'l3']),
        ]),
        AnnotationType.mask: MaskCategories({
            0: (0, 1, 2),
            1: (1, 2, 3),
        }),
    })

    # labels: c=0, b=1
    second = Dataset.from_iterable([
        DatasetItem(1, annotations=[ Label(0), ]),
    ], categories={
        AnnotationType.label: LabelCategories.from_iterable(['c', 'b']),
        AnnotationType.points: PointsCategories.from_iterable([
            (0, []),
            (1, ['l2', 'l3']),
        ]),
        AnnotationType.mask: MaskCategories({
            0: (0, 2, 4),
            1: (1, 2, 3),
        }),
    })

    # labels: a=0, b=1, c=2
    expected = Dataset.from_iterable([
        DatasetItem(1, annotations=[ Label(0), Label(2), ]),
    ], categories={
        AnnotationType.label: LabelCategories.from_iterable(['a', 'b', 'c']),
        AnnotationType.points: PointsCategories.from_iterable([
            (0, ['l0', 'l1']),
            (1, ['l2', 'l3']),
            (2, []),
        ]),
        AnnotationType.mask: MaskCategories({
            0: (0, 1, 2),
            1: (1, 2, 3),
            2: (0, 2, 4),
        }),
    })

    merge = IntersectMerge()
    actual = merge([first, second])

    compare_datasets(self, expected, actual, ignored_attrs={'score'})
Loading…
Cancel
Save