|
|
|
|
@ -15,7 +15,8 @@ from attr import attrib, attrs
|
|
|
|
|
from unittest import TestCase
|
|
|
|
|
|
|
|
|
|
from datumaro.components.cli_plugin import CliPlugin
|
|
|
|
|
from datumaro.components.extractor import AnnotationType, Bbox, Label
|
|
|
|
|
from datumaro.components.extractor import (AnnotationType, Bbox, Label,
|
|
|
|
|
LabelCategories, PointsCategories, MaskCategories)
|
|
|
|
|
from datumaro.components.project import Dataset
|
|
|
|
|
from datumaro.util import find, filter_dict
|
|
|
|
|
from datumaro.util.attrs_util import ensure_cls, default_if_none
|
|
|
|
|
@ -53,7 +54,8 @@ def merge_categories(sources):
|
|
|
|
|
for cat_type, source_cat in source.items():
|
|
|
|
|
if not categories[cat_type] == source_cat:
|
|
|
|
|
raise NotImplementedError(
|
|
|
|
|
"Merging different categories is not implemented yet")
|
|
|
|
|
"Merging of datasets with different categories is "
|
|
|
|
|
"only allowed in 'merge' command.")
|
|
|
|
|
return categories
|
|
|
|
|
|
|
|
|
|
class MergingStrategy(CliPlugin):
|
|
|
|
|
@ -180,7 +182,8 @@ class IntersectMerge(MergingStrategy):
|
|
|
|
|
_categories = attrib(init=False) # merged categories
|
|
|
|
|
|
|
|
|
|
def __call__(self, datasets):
|
|
|
|
|
self._categories = merge_categories(d.categories() for d in datasets)
|
|
|
|
|
self._categories = self._merge_categories(
|
|
|
|
|
[d.categories() for d in datasets])
|
|
|
|
|
merged = Dataset(categories=self._categories)
|
|
|
|
|
|
|
|
|
|
self._check_groups_definition()
|
|
|
|
|
@ -283,6 +286,126 @@ class IntersectMerge(MergingStrategy):
|
|
|
|
|
|
|
|
|
|
return matches, item_map
|
|
|
|
|
|
|
|
|
|
def _merge_label_categories(self, sources):
    """
    Merges the label categories of several sources into a single one.

    When every source carries equal label categories (sources that come
    before the first one with label categories are not compared), that
    shared object is returned as is. Otherwise, a new LabelCategories
    is built as a union of the source labels, matched by name:
    a missing parent is taken from the other side and the attributes
    are united.

    Parameters:
        sources - a sequence of category mappings, one per source.
            It is iterated twice, so it must not be a one-shot iterator.
            Label categories are read with `.get(AnnotationType.label)`.

    Returns:
        The merged label categories, or None if there is nothing to merge
        (no sources, or no source has label categories).

    Raises:
        ValueError - two sources define different non-empty parents
            for a label with the same name.
    """
    from copy import copy

    # Fast path: nothing to merge if all the sources agree
    common = None
    for src_categories in sources:
        src_cat = src_categories.get(AnnotationType.label)
        if common is None:
            common = src_cat
        elif common != src_cat:
            break
    else:
        return common

    dst_cat = LabelCategories()
    for src_id, src_categories in enumerate(sources):
        src_cat = src_categories.get(AnnotationType.label)
        if src_cat is None:
            continue

        for src_label in src_cat.items:
            dst_label = dst_cat.find(src_label.name)[1]
            if dst_label is None:
                # Store a copy of the attributes: the in-place union below
                # would otherwise modify the categories of the source
                # this label was first taken from.
                dst_cat.add(src_label.name, src_label.parent,
                    copy(src_label.attributes))
                continue

            if dst_label != src_label:
                if src_label.parent and dst_label.parent and \
                        src_label.parent != dst_label.parent:
                    raise ValueError("Can't merge label category "
                        "%s (from #%s): "
                        "parent label conflict: %s vs. %s" % \
                        (src_label.name, src_id,
                            src_label.parent, dst_label.parent)
                    )
                # An empty parent on one side is not a conflict
                dst_label.parent = dst_label.parent or src_label.parent
                dst_label.attributes |= src_label.attributes

    return dst_cat
|
|
|
|
|
|
|
|
|
|
def _merge_point_categories(self, sources, label_cat):
    """
    Collects the point categories of all the sources into a single
    PointsCategories, re-indexed by the label ids of 'label_cat'.

    Sources lacking either label or point categories are skipped.
    The labels are matched by name through 'label_cat.find()'.

    Returns:
        The merged point categories, or None if none were collected.

    Raises:
        ValueError - two sources provide unequal point categories
            for the same merged label.
    """
    merged = PointsCategories()

    for source_idx, categories in enumerate(sources):
        labels = categories.get(AnnotationType.label)
        points = categories.get(AnnotationType.points)
        if labels is None or points is None:
            continue

        for label_idx, point_cat in points.items.items():
            label_name = labels.items[label_idx].name
            target_idx = label_cat.find(label_name)[0]

            existing = merged.items.get(target_idx)
            if existing is None:
                merged.add(target_idx, point_cat.labels, point_cat.joints)
                continue

            # Already registered by another source - must be the same
            if existing != point_cat:
                raise ValueError("Can't merge point category for label "
                    "%s (from #%s): %s vs. %s" % \
                    (label_name, source_idx, point_cat, existing)
                )

    return None if len(merged.items) == 0 else merged
|
|
|
|
|
|
|
|
|
|
def _merge_mask_categories(self, sources, label_cat):
    """
    Collects the mask colormaps of all the sources into a single
    MaskCategories, re-indexed by the label ids of 'label_cat'.

    Sources lacking either label or mask categories are skipped.
    The labels are matched by name through 'label_cat.find()'.

    Returns:
        The merged mask categories, or None if the colormap is empty.

    Raises:
        ValueError - two sources provide unequal colormap entries
            for the same merged label.
    """
    merged = MaskCategories()

    for source_idx, categories in enumerate(sources):
        labels = categories.get(AnnotationType.label)
        masks = categories.get(AnnotationType.mask)
        if labels is None or masks is None:
            continue

        for label_idx, color in masks.colormap.items():
            label_name = labels.items[label_idx].name
            target_idx = label_cat.find(label_name)[0]

            existing = merged.colormap.get(target_idx)
            if existing is None:
                merged.colormap[target_idx] = color
                continue

            # Already registered by another source - must be the same
            if existing != color:
                raise ValueError("Can't merge mask category for label "
                    "%s (from #%s): %s vs. %s" % \
                    (label_name, source_idx, color, existing)
                )

    return None if len(merged.colormap) == 0 else merged
|
|
|
|
|
|
|
|
|
|
def _merge_categories(self, sources):
    """
    Builds the merged categories mapping for the given sources.

    The label categories are merged first; the point and mask categories
    are merged against the resulting labels and are only included
    when the corresponding merge returns something.

    Returns:
        A dict keyed by AnnotationType. It is empty when the label
        categories merge returns None.
    """
    label_cat = self._merge_label_categories(sources)
    if label_cat is None:
        # Points and masks are indexed by labels, nothing to do without them
        return {}

    merged = {AnnotationType.label: label_cat}

    points_cat = self._merge_point_categories(sources, label_cat)
    if points_cat is not None:
        merged[AnnotationType.points] = points_cat

    mask_cat = self._merge_mask_categories(sources, label_cat)
    if mask_cat is not None:
        merged[AnnotationType.mask] = mask_cat

    return merged
|
|
|
|
|
|
|
|
|
|
def _match_annotations(self, sources):
|
|
|
|
|
all_by_type = {}
|
|
|
|
|
for s in sources:
|
|
|
|
|
@ -473,8 +596,29 @@ class IntersectMerge(MergingStrategy):
|
|
|
|
|
_check_group(group_labels, group)
|
|
|
|
|
|
|
|
|
|
def _get_label_name(self, label_id):
    """
    Returns the name of the label with the given index in the merged
    label categories ('self._categories'), or None for a None index.
    """
    if label_id is not None:
        label_cat = self._categories[AnnotationType.label]
        return label_cat.items[label_id].name
    return None
|
|
|
|
|
|
|
|
|
|
def _get_label_id(self, label):
    """
    Looks the label name up in the merged label categories
    ('self._categories') and returns the first element of
    the 'find()' result.
    """
    label_cat = self._categories[AnnotationType.label]
    found = label_cat.find(label)
    return found[0]
|
|
|
|
|
|
|
|
|
|
def _get_src_label_name(self, ann, label_id):
    """
    Returns the name of the label with the given index in the label
    categories of the dataset the annotation is registered for,
    or None for a None index.

    The dataset is found by following 'self._ann_map' (keyed by id(ann)),
    'self._item_map' and 'self._dataset_map'; failed lookups in these
    maps are not handled here.
    """
    if label_id is None:
        return None

    # annotation -> item -> dataset
    item_id = self._ann_map[id(ann)][1]
    dataset_id = self._item_map[item_id][1]
    dataset = self._dataset_map[dataset_id][0]

    label_cat = dataset.categories()[AnnotationType.label]
    return label_cat.items[label_id].name
|
|
|
|
|
|
|
|
|
|
def _get_any_label_name(self, ann, label_id):
    """
    Resolves a label index into a name, preferring the annotation's
    source dataset categories. If the source lookup fails with KeyError,
    the merged categories are used instead.

    Returns None for a None index.
    """
    if label_id is None:
        return None

    try:
        name = self._get_src_label_name(ann, label_id)
    except KeyError:
        # Fall back to the merged categories
        name = self._get_label_name(label_id)
    return name
|
|
|
|
|
|
|
|
|
|
def _check_groups_definition(self):
|
|
|
|
|
for group in self.conf.groups:
|
|
|
|
|
for label, _ in group:
|
|
|
|
|
@ -486,16 +630,19 @@ class IntersectMerge(MergingStrategy):
|
|
|
|
|
self._categories[AnnotationType.label].items])
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@attrs
|
|
|
|
|
@attrs(kw_only=True)
|
|
|
|
|
class AnnotationMatcher:
|
|
|
|
|
_context = attrib(type=IntersectMerge, default=None)
|
|
|
|
|
|
|
|
|
|
def match_annotations(self, sources):
|
|
|
|
|
raise NotImplementedError()
|
|
|
|
|
|
|
|
|
|
@attrs
|
|
|
|
|
class LabelMatcher(AnnotationMatcher):
|
|
|
|
|
@staticmethod
|
|
|
|
|
def distance(a, b):
|
|
|
|
|
return a.label == b.label
|
|
|
|
|
def distance(self, a, b):
|
|
|
|
|
a_label = self._context._get_any_label_name(a, a.label)
|
|
|
|
|
b_label = self._context._get_any_label_name(b, b.label)
|
|
|
|
|
return a_label == b_label
|
|
|
|
|
|
|
|
|
|
def match_annotations(self, sources):
|
|
|
|
|
return [sum(sources, [])]
|
|
|
|
|
@ -507,6 +654,7 @@ class _ShapeMatcher(AnnotationMatcher):
|
|
|
|
|
|
|
|
|
|
def match_annotations(self, sources):
|
|
|
|
|
distance = self.distance
|
|
|
|
|
label_matcher = self.label_matcher
|
|
|
|
|
pairwise_dist = self.pairwise_dist
|
|
|
|
|
cluster_dist = self.cluster_dist
|
|
|
|
|
|
|
|
|
|
@ -537,9 +685,10 @@ class _ShapeMatcher(AnnotationMatcher):
|
|
|
|
|
for a_idx, src_a in enumerate(sources):
|
|
|
|
|
for src_b in sources[a_idx+1 :]:
|
|
|
|
|
matches, _, _, _ = match_segments(src_a, src_b,
|
|
|
|
|
dist_thresh=pairwise_dist, distance=distance)
|
|
|
|
|
for m in matches:
|
|
|
|
|
adjacent[id(m[0])].append(id(m[1]))
|
|
|
|
|
dist_thresh=pairwise_dist,
|
|
|
|
|
distance=distance, label_matcher=label_matcher)
|
|
|
|
|
for a, b in matches:
|
|
|
|
|
adjacent[id(a)].append(id(b))
|
|
|
|
|
|
|
|
|
|
# join all segments into matching clusters
|
|
|
|
|
clusters = []
|
|
|
|
|
@ -573,6 +722,11 @@ class _ShapeMatcher(AnnotationMatcher):
|
|
|
|
|
def distance(a, b):
|
|
|
|
|
return segment_iou(a, b)
|
|
|
|
|
|
|
|
|
|
def label_matcher(self, a, b):
    """
    Checks if two annotations have the same label name, as resolved
    by the merge context ('self._context._get_any_label_name').
    """
    resolve = self._context._get_any_label_name
    return resolve(a, a.label) == resolve(b, b.label)
|
|
|
|
|
|
|
|
|
|
@attrs
class BboxMatcher(_ShapeMatcher):
    """A _ShapeMatcher subclass that declares no members of its own."""
|
|
|
|
|
@ -626,8 +780,6 @@ class CaptionsMatcher(AnnotationMatcher):
|
|
|
|
|
|
|
|
|
|
@attrs(kw_only=True)
|
|
|
|
|
class AnnotationMerger:
|
|
|
|
|
_context = attrib(type=IntersectMerge, default=None)
|
|
|
|
|
|
|
|
|
|
def merge_clusters(self, clusters):
|
|
|
|
|
raise NotImplementedError()
|
|
|
|
|
|
|
|
|
|
@ -641,20 +793,22 @@ class LabelMerger(AnnotationMerger, LabelMatcher):
|
|
|
|
|
return []
|
|
|
|
|
|
|
|
|
|
votes = {} # label -> score
|
|
|
|
|
for label_ann in clusters[0]:
|
|
|
|
|
votes[label_ann.label] = 1 + votes.get(label_ann.label, 0)
|
|
|
|
|
for ann in clusters[0]:
|
|
|
|
|
label = self._context._get_src_label_name(ann, ann.label)
|
|
|
|
|
votes[label] = 1 + votes.get(label, 0)
|
|
|
|
|
|
|
|
|
|
merged = []
|
|
|
|
|
for label, count in votes.items():
|
|
|
|
|
if count < self.quorum:
|
|
|
|
|
sources = set(self.get_ann_source(id(a)) for a in clusters[0]
|
|
|
|
|
if label not in [l.label for l in a])
|
|
|
|
|
if label not in [self._context._get_src_label_name(l, l.label)
|
|
|
|
|
for l in a])
|
|
|
|
|
sources = [self._context._dataset_map[s][1] for s in sources]
|
|
|
|
|
self._context.add_item_error(FailedLabelVotingError,
|
|
|
|
|
sources, votes)
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
merged.append(Label(label, attributes={
|
|
|
|
|
merged.append(Label(self._context._get_label_id(label), attributes={
|
|
|
|
|
'score': count / len(self._context._dataset_map)
|
|
|
|
|
}))
|
|
|
|
|
|
|
|
|
|
@ -682,14 +836,17 @@ class _ShapeMerger(AnnotationMerger, _ShapeMatcher):
|
|
|
|
|
def find_cluster_label(self, cluster):
|
|
|
|
|
votes = {}
|
|
|
|
|
for s in cluster:
|
|
|
|
|
state = votes.setdefault(s.label, [0, 0])
|
|
|
|
|
label = self._context._get_src_label_name(s, s.label)
|
|
|
|
|
state = votes.setdefault(label, [0, 0])
|
|
|
|
|
state[0] += s.attributes.get('score', 1.0)
|
|
|
|
|
state[1] += 1
|
|
|
|
|
|
|
|
|
|
label, (score, count) = max(votes.items(), key=lambda e: e[1][0])
|
|
|
|
|
if count < self.quorum:
|
|
|
|
|
self._context.add_item_error(FailedLabelVotingError, votes)
|
|
|
|
|
score = score / count if count else None
|
|
|
|
|
label = None
|
|
|
|
|
score = score / len(self._context._dataset_map)
|
|
|
|
|
label = self._context._get_label_id(label)
|
|
|
|
|
return label, score
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
@ -729,11 +886,10 @@ class LineMerger(_ShapeMerger, LineMatcher):
|
|
|
|
|
class CaptionsMerger(AnnotationMerger, CaptionsMatcher):
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
def match_segments(a_segms, b_segms, distance='iou', dist_thresh=1.0):
|
|
|
|
|
if distance == 'iou':
|
|
|
|
|
distance = segment_iou
|
|
|
|
|
else:
|
|
|
|
|
assert callable(distance)
|
|
|
|
|
def match_segments(a_segms, b_segms, distance=segment_iou, dist_thresh=1.0,
|
|
|
|
|
label_matcher=lambda a, b: a.label == b.label):
|
|
|
|
|
assert callable(distance), distance
|
|
|
|
|
assert callable(label_matcher), label_matcher
|
|
|
|
|
|
|
|
|
|
a_segms.sort(key=lambda ann: 1 - ann.attributes.get('score', 1))
|
|
|
|
|
b_segms.sort(key=lambda ann: 1 - ann.attributes.get('score', 1))
|
|
|
|
|
@ -753,13 +909,16 @@ def match_segments(a_segms, b_segms, distance='iou', dist_thresh=1.0):
|
|
|
|
|
for a_idx, a_segm in enumerate(a_segms):
|
|
|
|
|
if len(b_segms) == 0:
|
|
|
|
|
break
|
|
|
|
|
matched_b = a_matches[a_idx]
|
|
|
|
|
max_dist = max(distances[a_idx, matched_b], dist_thresh)
|
|
|
|
|
for b_idx, b_segm in enumerate(b_segms):
|
|
|
|
|
matched_b = -1
|
|
|
|
|
max_dist = -1
|
|
|
|
|
b_indices = np.argsort([not label_matcher(a_segm, b_segm)
|
|
|
|
|
for b_segm in b_segms],
|
|
|
|
|
kind='stable') # prioritize those with same label, keep score order
|
|
|
|
|
for b_idx in b_indices:
|
|
|
|
|
if 0 <= b_matches[b_idx]: # assign a_segm with max conf
|
|
|
|
|
continue
|
|
|
|
|
d = distances[a_idx, b_idx]
|
|
|
|
|
if d < max_dist:
|
|
|
|
|
if d < dist_thresh or d <= max_dist:
|
|
|
|
|
continue
|
|
|
|
|
max_dist = d
|
|
|
|
|
matched_b = b_idx
|
|
|
|
|
@ -771,7 +930,7 @@ def match_segments(a_segms, b_segms, distance='iou', dist_thresh=1.0):
|
|
|
|
|
|
|
|
|
|
b_segm = b_segms[matched_b]
|
|
|
|
|
|
|
|
|
|
if a_segm.label == b_segm.label:
|
|
|
|
|
if label_matcher(a_segm, b_segm):
|
|
|
|
|
matches.append( (a_segm, b_segm) )
|
|
|
|
|
else:
|
|
|
|
|
mispred.append( (a_segm, b_segm) )
|
|
|
|
|
|