[Datumaro] Add merge command with segment intersection (#1695)

* Add multi source merge

* update changelog

* cli update

* linter

* fixes and tests

* fix test

* fix test

* relax type requirements in annotations

* fix polylines

* Make groups more stable

* Add group checks

* add group check test
Maxim Zhiltsov 6 years ago committed by GitHub
parent 90cc36eb1c
commit 17a5554cd5

@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Siammask tracker as DL serverless function (<https://github.com/opencv/cvat/pull/1988>)
- [Datumaro] Added model info and source info commands (<https://github.com/opencv/cvat/pull/1973>)
- [Datumaro] Dataset statistics (<https://github.com/opencv/cvat/pull/1668>)
- [Datumaro] Multi-dataset merge (<https://github.com/opencv/cvat/pull/1695>)
### Changed
- Shape coordinates are rounded to 2 digits in dumped annotations (<https://github.com/opencv/cvat/pull/1970>)

@ -68,6 +68,7 @@ def make_parser():
('remove', commands.remove, "Remove source from project"),
('export', commands.export, "Export project"),
('explain', commands.explain, "Run Explainable AI algorithm for model"),
('merge', commands.merge, "Merge datasets"),
('convert', commands.convert, "Convert dataset"),
]

@ -3,4 +3,4 @@
#
# SPDX-License-Identifier: MIT
from . import add, create, explain, export, remove, convert
from . import add, create, explain, export, remove, merge, convert

@ -0,0 +1,124 @@
# Copyright (C) 2020 Intel Corporation
#
# SPDX-License-Identifier: MIT
import argparse
import json
import logging as log
import os
import os.path as osp
from collections import OrderedDict
from datumaro.components.project import Project
from datumaro.components.operations import (IntersectMerge,
QualityError, MergeError)
from ..util import at_least, MultilineFormatter, CliException
from ..util.project import generate_next_file_name, load_project
def build_parser(parser_ctor=argparse.ArgumentParser):
parser = parser_ctor(help="Merge few projects",
description="""
Merges multiple datasets into one. This can be useful if you
have several sets of annotations and wish to merge them,
taking into consideration potential overlaps and conflicts.
This command can try to find a common ground by voting or
return a list of conflicts.|n
|n
Examples:|n
- Merge annotations from 3 (or more) annotators:|n
|s|smerge project1/ project2/ project3/|n
- Check groups of the merged dataset for consistency:|n
|s|s|slook for groups consisting of 'person', 'hand', 'head', 'foot'|n
|s|smerge project1/ project2/ -g 'person,hand?,head,foot?'
""",
formatter_class=MultilineFormatter)
def _group(s):
return s.split(',')
parser.add_argument('project', nargs='+', action=at_least(2),
help="Path to a project (repeatable)")
parser.add_argument('-iou', '--iou-thresh', default=0.25, type=float,
help="IoU match threshold for segments (default: %(default)s)")
parser.add_argument('-oconf', '--output-conf-thresh',
default=0.0, type=float,
help="Confidence threshold for output "
"annotations (default: %(default)s)")
parser.add_argument('--quorum', default=0, type=int,
help="Minimum count for a label and attribute voting "
"results to be counted (default: %(default)s)")
parser.add_argument('-g', '--groups', action='append', type=_group,
default=[],
help="A comma-separated list of labels in "
"annotation groups to check. '?' postfix can be added to a label to"
"make it optional in the group (repeatable)")
parser.add_argument('-o', '--output-dir', dest='dst_dir', default=None,
help="Output directory (default: current project's dir)")
parser.add_argument('--overwrite', action='store_true',
help="Overwrite existing files in the save directory")
parser.set_defaults(command=merge_command)
return parser
def merge_command(args):
source_projects = [load_project(p) for p in args.project]
dst_dir = args.dst_dir
if dst_dir:
if not args.overwrite and osp.isdir(dst_dir) and os.listdir(dst_dir):
raise CliException("Directory '%s' already exists "
"(pass --overwrite to overwrite)" % dst_dir)
else:
dst_dir = generate_next_file_name('merged')
source_datasets = []
for p in source_projects:
log.debug("Loading project '%s' dataset", p.config.project_name)
source_datasets.append(p.make_dataset())
merger = IntersectMerge(conf=IntersectMerge.Conf(
pairwise_dist=args.iou_thresh, groups=args.groups,
output_conf_thresh=args.output_conf_thresh, quorum=args.quorum
))
merged_dataset = merger(source_datasets)
merged_project = Project()
output_dataset = merged_project.make_dataset()
output_dataset.define_categories(merged_dataset.categories())
merged_dataset = output_dataset.update(merged_dataset)
merged_dataset.save(save_dir=dst_dir)
report_path = osp.join(dst_dir, 'merge_report.json')
save_merge_report(merger, report_path)
dst_dir = osp.abspath(dst_dir)
log.info("Merge results have been saved to '%s'" % dst_dir)
log.info("Report has been saved to '%s'" % report_path)
return 0
def save_merge_report(merger, path):
item_errors = OrderedDict()
source_errors = OrderedDict()
all_errors = []
for e in merger.errors:
if isinstance(e, QualityError):
item_errors[str(e.item_id)] = item_errors.get(str(e.item_id), 0) + 1
elif isinstance(e, MergeError):
for s in e.sources:
source_errors[s] = source_errors.get(s, 0) + 1
item_errors[str(e.item_id)] = item_errors.get(str(e.item_id), 0) + 1
all_errors.append(str(e))
errors = OrderedDict([
('Item errors', item_errors),
('Source errors', source_errors),
('All errors', all_errors),
])
with open(path, 'w') as f:
json.dump(errors, f, indent=4)
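The same merge can be driven from Python without the CLI wrapper above. The sketch below is editorial and illustrative rather than part of the diff; the project paths and label names are assumptions, while Project.load, IntersectMerge and its Conf are the ones used or introduced in this patch.

from datumaro.components.operations import IntersectMerge
from datumaro.components.project import Project

# Hypothetical on-disk projects; replace with real paths.
datasets = [Project.load(p).make_dataset()
    for p in ('project1/', 'project2/', 'project3/')]

merger = IntersectMerge(conf=IntersectMerge.Conf(
    pairwise_dist=0.25,            # IoU threshold, mirrors --iou-thresh
    quorum=2,                      # a label/attribute needs at least 2 votes
    groups=[['person', 'hand?']],  # 'hand' is optional; labels must exist in the sources
))
merged = merger(datasets)

for e in merger.errors:  # quality and merge conflicts collected during merging
    print(e)

The merged dataset can then be saved through a Project, as merge_command does above.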

@ -37,6 +37,28 @@ class MultilineFormatter(argparse.HelpFormatter):
multiline_text += formatted_paragraph
return multiline_text
def required_count(nmin=0, nmax=0):
assert 0 <= nmin and 0 <= nmax and (nmin or nmax)
class RequiredCount(argparse.Action):
def __call__(self, parser, args, values, option_string=None):
k = len(values)
if (nmin and k < nmin) or (nmax and nmax < k):
msg = "Argument '%s' requires" % self.dest
if nmin and nmax:
msg += " from %s to %s arguments" % (nmin, nmax)
elif nmin:
msg += " at least %s arguments" % nmin
else:
msg += " no more %s arguments" % nmax
raise argparse.ArgumentTypeError(msg)
setattr(args, self.dest, values)
return RequiredCount
def at_least(n):
return required_count(n, 0)
def make_file_name(s):
# adapted from
# https://docs.djangoproject.com/en/2.1/_modules/django/utils/text/#slugify

@ -9,6 +9,7 @@ import numpy as np
from math import ceil
from datumaro.components.extractor import AnnotationType
from datumaro.util.annotation_util import nms
def flatmatvec(mat):
@ -51,24 +52,6 @@ class RISE:
bboxes.append(r)
return labels, bboxes
@staticmethod
def nms(boxes, iou_thresh=0.5):
indices = np.argsort([b.attributes['score'] for b in boxes])
ious = np.array([[a.iou(b) for b in boxes] for a in boxes])
predictions = []
while len(indices) != 0:
i = len(indices) - 1
pred_idx = indices[i]
to_remove = [i]
predictions.append(boxes[pred_idx])
for i, box_idx in enumerate(indices[:i]):
if iou_thresh < ious[pred_idx, box_idx]:
to_remove.append(i)
indices = np.delete(indices, to_remove)
return predictions
def normalize_hmaps(self, heatmaps, counts):
eps = np.finfo(heatmaps.dtype).eps
mhmaps = flatmatvec(heatmaps)
@ -106,7 +89,7 @@ class RISE:
result_bboxes = [b for b in result_bboxes \
if self.det_conf_thresh <= b.attributes['score']]
if 0 < self.nms_thresh:
result_bboxes = self.nms(result_bboxes, self.nms_thresh)
result_bboxes = nms(result_bboxes, self.nms_thresh)
predicted_labels = set()
if len(result_labels) != 0:
@ -194,7 +177,7 @@ class RISE:
result_bboxes = [b for b in result_bboxes \
if self.det_conf_thresh <= b.attributes['score']]
if 0 < self.nms_thresh:
result_bboxes = self.nms(result_bboxes, self.nms_thresh)
result_bboxes = nms(result_bboxes, self.nms_thresh)
for detection in result_bboxes:
for pred_idx, pred in enumerate(predicted_bboxes):
@ -202,7 +185,7 @@ class RISE:
continue
iou = pred.iou(detection)
assert 0 <= iou and iou <= 1
assert iou == -1 or 0 <= iou and iou <= 1
if iou < iou_thresh:
continue

@ -13,6 +13,7 @@ from attr import attrs, attrib
from datumaro.util.image import Image
from datumaro.util.attrs_util import not_empty, default_if_none
AnnotationType = Enum('AnnotationType',
[
'label',
@ -28,9 +29,9 @@ _COORDINATE_ROUNDING_DIGITS = 2
@attrs
class Annotation:
id = attrib(converter=int, default=0, kw_only=True)
attributes = attrib(converter=dict, factory=dict, kw_only=True)
group = attrib(converter=int, default=0, kw_only=True)
id = attrib(default=0, validator=default_if_none(int), kw_only=True)
attributes = attrib(factory=dict, validator=default_if_none(dict), kw_only=True)
group = attrib(default=0, validator=default_if_none(int), kw_only=True)
def __attrs_post_init__(self):
assert isinstance(self.type, AnnotationType)
@ -92,7 +93,7 @@ class LabelCategories(Categories):
self._indices = indices
def add(self, name, parent=None, attributes=None):
assert name not in self._indices
assert name not in self._indices, name
if attributes is None:
attributes = set()
else:
@ -110,7 +111,7 @@ class LabelCategories(Categories):
def find(self, name):
index = self._indices.get(name)
if index:
if index is not None:
return index, self.items[index]
return index, None
@ -148,7 +149,7 @@ class Mask(Annotation):
_image = attrib()
label = attrib(converter=attr.converters.optional(int),
default=None, kw_only=True)
z_order = attrib(converter=int, default=0, kw_only=True)
z_order = attrib(default=0, validator=default_if_none(int), kw_only=True)
@property
def image(self):
@ -274,31 +275,13 @@ class CompiledMask:
def lazy_extract(self, instance_id):
return lambda: self.extract(instance_id)
def compute_iou(bbox_a, bbox_b):
aX, aY, aW, aH = bbox_a
bX, bY, bW, bH = bbox_b
in_right = min(aX + aW, bX + bW)
in_left = max(aX, bX)
in_top = max(aY, bY)
in_bottom = min(aY + aH, bY + bH)
in_w = max(0, in_right - in_left)
in_h = max(0, in_bottom - in_top)
intersection = in_w * in_h
a_area = aW * aH
b_area = bW * bH
union = a_area + b_area - intersection
return intersection / max(1.0, union)
@attrs
class _Shape(Annotation):
points = attrib(converter=lambda x:
[round(p, _COORDINATE_ROUNDING_DIGITS) for p in x])
label = attrib(converter=attr.converters.optional(int),
default=None, kw_only=True)
z_order = attrib(converter=int, default=0, kw_only=True)
z_order = attrib(default=0, validator=default_if_none(int), kw_only=True)
def get_area(self):
raise NotImplementedError()
@ -386,7 +369,8 @@ class Bbox(_Shape):
]
def iou(self, other):
return compute_iou(self.get_bbox(), other.get_bbox())
from datumaro.util.annotation_util import bbox_iou
return bbox_iou(self.get_bbox(), other.get_bbox())
def wrap(item, **kwargs):
d = {'x': item.x, 'y': item.y, 'w': item.w, 'h': item.h}

@ -3,14 +3,782 @@
#
# SPDX-License-Identifier: MIT
import logging as log
from collections import OrderedDict
from copy import deepcopy
import logging as log
import attr
import cv2
import numpy as np
from attr import attrib, attrs
from datumaro.components.cli_plugin import CliPlugin
from datumaro.components.extractor import AnnotationType, Bbox, Label
from datumaro.components.project import Dataset
from datumaro.util import find
from datumaro.util.attrs_util import ensure_cls
from datumaro.util.annotation_util import (segment_iou, bbox_iou,
mean_bbox, OKS, find_instances, max_bbox, smooth_line)
def get_ann_type(anns, t):
return [a for a in anns if a.type == t]
def match_annotations_equal(a, b):
matches = []
a_unmatched = a[:]
b_unmatched = b[:]
for a_ann in a:
for b_ann in b_unmatched:
if a_ann != b_ann:
continue
matches.append((a_ann, b_ann))
a_unmatched.remove(a_ann)
b_unmatched.remove(b_ann)
break
return matches, a_unmatched, b_unmatched
def merge_annotations_equal(a, b):
matches, a_unmatched, b_unmatched = match_annotations_equal(a, b)
return [ann_a for (ann_a, _) in matches] + a_unmatched + b_unmatched
def merge_categories(sources):
categories = {}
for source in sources:
categories.update(source)
for source in sources:
for cat_type, source_cat in source.items():
if not categories[cat_type] == source_cat:
raise NotImplementedError(
"Merging different categories is not implemented yet")
return categories
class MergingStrategy(CliPlugin):
@classmethod
def merge(cls, sources, **options):
instance = cls(**options)
return instance(sources)
def __init__(self, **options):
super().__init__(**options)
self.__dict__['_sources'] = None
def __call__(self, sources):
raise NotImplementedError()
@attrs
class DatasetError:
item_id = attrib()
@attrs
class QualityError(DatasetError):
pass
@attrs
class TooCloseError(QualityError):
a = attrib()
b = attrib()
distance = attrib()
def __str__(self):
return "Item %s: annotations are too close: %s, %s, distance = %s" % \
(self.item_id, self.a, self.b, self.distance)
@attrs
class WrongGroupError(QualityError):
found = attrib(converter=set)
expected = attrib(converter=set)
group = attrib(converter=list)
def __str__(self):
return "Item %s: annotation group has wrong labels: " \
"found %s, expected %s, group %s" % \
(self.item_id, self.found, self.expected, self.group)
@attrs
class MergeError(DatasetError):
sources = attrib(converter=set)
@attrs
class NoMatchingAnnError(MergeError):
ann = attrib()
def __str__(self):
return "Item %s: can't find matching annotation " \
"in sources %s, annotation is %s" % \
(self.item_id, self.sources, self.ann)
@attrs
class NoMatchingItemError(MergeError):
def __str__(self):
return "Item %s: can't find matching item in sources %s" % \
(self.item_id, self.sources)
@attrs
class FailedLabelVotingError(MergeError):
votes = attrib()
ann = attrib(default=None)
def __str__(self):
return "Item %s: label voting failed%s, votes %s, sources %s" % \
(self.item_id, ' for ann %s' % self.ann if self.ann else '',
self.votes, self.sources)
@attrs
class FailedAttrVotingError(MergeError):
attr = attrib()
votes = attrib()
ann = attrib()
def __str__(self):
return "Item %s: attribute voting failed " \
"for ann %s, votes %s, sources %s" % \
(self.item_id, self.ann, self.votes, self.sources)
@attrs
class IntersectMerge(MergingStrategy):
@attrs(repr_ns='IntersectMerge', kw_only=True)
class Conf:
pairwise_dist = attrib(converter=float, default=0.5)
sigma = attrib(converter=list, factory=list)
output_conf_thresh = attrib(converter=float, default=0)
quorum = attrib(converter=int, default=0)
ignored_attributes = attrib(converter=set, factory=set)
def _groups_converter(value):
result = []
for group in value:
rg = set()
for label in group:
optional = label.endswith('?')
name = label if not optional else label[:-1]
rg.add((name, optional))
result.append(rg)
return result
groups = attrib(converter=_groups_converter, factory=list)
close_distance = attrib(converter=float, default=0.75)
conf = attrib(converter=ensure_cls(Conf), factory=Conf)
# Error trackers:
errors = attrib(factory=list, init=False)
def add_item_error(self, error, *args, **kwargs):
self.errors.append(error(self._item_id, *args, **kwargs))
# Indexes:
_dataset_map = attrib(init=False) # id(dataset) -> (dataset, index)
_item_map = attrib(init=False) # id(item) -> (item, id(dataset))
_ann_map = attrib(init=False) # id(ann) -> (ann, id(item))
_item_id = attrib(init=False)
_item = attrib(init=False)
# Misc.
_categories = attrib(init=False) # merged categories
def __call__(self, datasets):
self._categories = merge_categories(d.categories() for d in datasets)
merged = Dataset(categories=self._categories)
self._check_groups_definition()
item_matches, item_map = self.match_items(datasets)
self._item_map = item_map
self._dataset_map = { id(d): (d, i) for i, d in enumerate(datasets) }
for item_id, items in item_matches.items():
self._item_id = item_id
if len(items) < len(datasets):
missing_sources = set(id(s) for s in datasets) - set(items)
missing_sources = [self._dataset_map[s][1]
for s in missing_sources]
self.add_item_error(NoMatchingItemError, missing_sources)
merged.put(self.merge_items(items))
return merged
def get_ann_source(self, ann_id):
return self._item_map[self._ann_map[ann_id][1]][1]
def merge_items(self, items):
self._item = next(iter(items.values()))
self._ann_map = {}
sources = []
for item in items.values():
self._ann_map.update({ id(a): (a, id(item))
for a in item.annotations })
sources.append(item.annotations)
log.debug("Merging item %s: source annotations %s" % \
(self._item_id, list(map(len, sources))))
annotations = self.merge_annotations(sources)
annotations = [a for a in annotations
if self.conf.output_conf_thresh <= a.attributes.get('score', 1)]
return self._item.wrap(annotations=annotations)
def merge_annotations(self, sources):
self._make_mergers(sources)
clusters = self._match_annotations(sources)
joined_clusters = sum(clusters.values(), [])
group_map = self._find_cluster_groups(joined_clusters)
annotations = []
for t, clusters in clusters.items():
for cluster in clusters:
self._check_cluster_sources(cluster)
merged_clusters = self._merge_clusters(t, clusters)
for merged_ann, cluster in zip(merged_clusters, clusters):
attributes = self._find_cluster_attrs(cluster, merged_ann)
attributes = { k: v for k, v in attributes.items()
if k not in self.conf.ignored_attributes }
attributes.update(merged_ann.attributes)
merged_ann.attributes = attributes
new_group_id = find(enumerate(group_map),
lambda e: id(cluster) in e[1][0])
if new_group_id is None:
new_group_id = 0
else:
new_group_id = new_group_id[0] + 1
merged_ann.group = new_group_id
if self.conf.close_distance:
self._check_annotation_distance(t, merged_clusters)
annotations += merged_clusters
if self.conf.groups:
self._check_groups(annotations)
return annotations
@staticmethod
def match_items(datasets):
item_ids = set((item.id, item.subset) for d in datasets for item in d)
item_map = {} # id(item) -> (item, id(dataset))
matches = OrderedDict()
for (item_id, item_subset) in sorted(item_ids, key=lambda e: e[0]):
items = {}
for d in datasets:
try:
item = d.get(item_id, subset=item_subset)
items[id(d)] = item
item_map[id(item)] = (item, id(d))
except KeyError:
pass
matches[(item_id, item_subset)] = items
return matches, item_map
def _match_annotations(self, sources):
all_by_type = {}
for s in sources:
src_by_type = {}
for a in s:
src_by_type.setdefault(a.type, []).append(a)
for k, v in src_by_type.items():
all_by_type.setdefault(k, []).append(v)
clusters = {}
for k, v in all_by_type.items():
clusters.setdefault(k, []).extend(self._match_ann_type(k, v))
return clusters
def _make_mergers(self, sources):
def _make(c, **kwargs):
kwargs.update(attr.asdict(self.conf))
fields = attr.fields_dict(c)
return c(**{ k: v for k, v in kwargs.items() if k in fields },
context=self)
def _for_type(t, **kwargs):
if t is AnnotationType.label:
return _make(LabelMerger, **kwargs)
elif t is AnnotationType.bbox:
return _make(BboxMerger, **kwargs)
elif t is AnnotationType.mask:
return _make(MaskMerger, **kwargs)
elif t is AnnotationType.polygon:
return _make(PolygonMerger, **kwargs)
elif t is AnnotationType.polyline:
return _make(LineMerger, **kwargs)
elif t is AnnotationType.points:
return _make(PointsMerger, **kwargs)
elif t is AnnotationType.caption:
return _make(CaptionsMerger, **kwargs)
else:
raise NotImplementedError("Type %s is not supported" % t)
instance_map = {}
for s in sources:
s_instances = find_instances(s)
for inst in s_instances:
inst_bbox = max_bbox([a for a in inst if a.type in
{AnnotationType.polygon,
AnnotationType.mask, AnnotationType.bbox}
])
for ann in inst:
instance_map[id(ann)] = [inst, inst_bbox]
self._mergers = { t: _for_type(t, instance_map=instance_map)
for t in AnnotationType }
def _match_ann_type(self, t, sources):
return self._mergers[t].match_annotations(sources)
def _merge_clusters(self, t, clusters):
return self._mergers[t].merge_clusters(clusters)
@staticmethod
def _find_cluster_groups(clusters):
cluster_groups = []
visited = set()
for a_idx, cluster_a in enumerate(clusters):
if a_idx in visited:
continue
visited.add(a_idx)
cluster_group = { id(cluster_a) }
# find segment groups in the cluster group
a_groups = set(ann.group for ann in cluster_a)
for cluster_b in clusters[a_idx+1 :]:
b_groups = set(ann.group for ann in cluster_b)
if a_groups & b_groups:
a_groups |= b_groups
# now we know all the segment groups in this cluster group
# so we can find adjacent clusters
for b_idx, cluster_b in enumerate(clusters[a_idx+1 :]):
b_idx = a_idx + 1 + b_idx
b_groups = set(ann.group for ann in cluster_b)
if a_groups & b_groups:
cluster_group.add( id(cluster_b) )
visited.add(b_idx)
if a_groups == {0}:
continue # skip annotations without a group
cluster_groups.append( (cluster_group, a_groups) )
return cluster_groups
def _find_cluster_attrs(self, cluster, ann):
quorum = self.conf.quorum or 0
# TODO: when attribute types are implemented, add linear
# interpolation for contiguous values
attr_votes = {} # name -> { value: score , ... }
for s in cluster:
for name, value in s.attributes.items():
votes = attr_votes.get(name, {})
votes[value] = 1 + votes.get(value, 0)
attr_votes[name] = votes
attributes = {}
for name, votes in attr_votes.items():
winner, count = max(votes.items(), key=lambda e: e[1])
if count < quorum:
if sum(votes.values()) < quorum:
# blame provokers
missing_sources = set(
self.get_ann_source(id(a)) for a in cluster
if a.attributes.get(name) == winner)
else:
# blame outliers
missing_sources = set(
self.get_ann_source(id(a)) for a in cluster
if a.attributes.get(name) != winner)
missing_sources = [self._dataset_map[s][1]
for s in missing_sources]
self.add_item_error(FailedAttrVotingError,
missing_sources, name, votes, ann)
continue
attributes[name] = winner
return attributes
def _check_cluster_sources(self, cluster):
if len(cluster) == len(self._dataset_map):
return
def _has_item(s):
try:
item = self._dataset_map[s][0].get(*self._item_id)
if len(item.annotations) == 0:
return False
return True
except KeyError:
return False
missing_sources = set(self._dataset_map) - \
set(self.get_ann_source(id(a)) for a in cluster)
missing_sources = [self._dataset_map[s][1] for s in missing_sources
if _has_item(s)]
if missing_sources:
self.add_item_error(NoMatchingAnnError, missing_sources, cluster[0])
def _check_annotation_distance(self, t, annotations):
for a_idx, a_ann in enumerate(annotations):
for b_ann in annotations[a_idx+1:]:
d = self._mergers[t].distance(a_ann, b_ann)
if self.conf.close_distance < d:
self.add_item_error(TooCloseError, a_ann, b_ann, d)
def _check_groups(self, annotations):
check_groups = []
for check_group_raw in self.conf.groups:
check_group = set(l[0] for l in check_group_raw)
optional = set(l[0] for l in check_group_raw if l[1])
check_groups.append((check_group, optional))
def _check_group(group_labels, group):
for check_group, optional in check_groups:
common = check_group & group_labels
real_miss = check_group - common - optional
extra = group_labels - check_group
if common and (extra or real_miss):
self.add_item_error(WrongGroupError, group_labels,
check_group, group)
break
groups = find_instances(annotations)
for group in groups:
group_labels = set()
for ann in group:
if not hasattr(ann, 'label'):
continue
label = self._get_label_name(ann.label)
if ann.group:
group_labels.add(label)
else:
_check_group({label}, [ann])
if not group_labels:
continue
_check_group(group_labels, group)
def _get_label_name(self, label_id):
return self._categories[AnnotationType.label].items[label_id].name
def _check_groups_definition(self):
for group in self.conf.groups:
for label, _ in group:
_, entry = self._categories[AnnotationType.label].find(label)
if entry is None:
raise ValueError("Datasets do not contain "
"label '%s', available labels %s" % \
(label, [i.name for i in
self._categories[AnnotationType.label].items])
)
@attrs
class AnnotationMatcher:
def match_annotations(self, sources):
raise NotImplementedError()
@attrs
class LabelMatcher(AnnotationMatcher):
@staticmethod
def distance(a, b):
return a.label == b.label
def match_annotations(self, sources):
return [sum(sources, [])]
@attrs(kw_only=True)
class _ShapeMatcher(AnnotationMatcher):
pairwise_dist = attrib(converter=float, default=0.9)
cluster_dist = attrib(converter=float, default=-1.0)
def match_annotations(self, sources):
distance = self.distance
pairwise_dist = self.pairwise_dist
cluster_dist = self.cluster_dist
if cluster_dist < 0: cluster_dist = pairwise_dist
id_segm = { id(a): (a, id(s)) for s in sources for a in s }
def _is_close_enough(cluster, extra_id):
# check if whole cluster IoU will not be broken
# when this segment is added
b = id_segm[extra_id][0]
for a_id in cluster:
a = id_segm[a_id][0]
if distance(a, b) < cluster_dist:
return False
return True
def _has_same_source(cluster, extra_id):
b = id_segm[extra_id][1]
for a_id in cluster:
a = id_segm[a_id][1]
if a == b:
return True
return False
# match segments in sources, pairwise
adjacent = { i: [] for i in id_segm } # id(sgm) -> [id(adj_sgm1), ...]
for a_idx, src_a in enumerate(sources):
for src_b in sources[a_idx+1 :]:
matches, _, _, _ = match_segments(src_a, src_b,
dist_thresh=pairwise_dist, distance=distance)
for m in matches:
adjacent[id(m[0])].append(id(m[1]))
# join all segments into matching clusters
clusters = []
visited = set()
for cluster_idx in adjacent:
if cluster_idx in visited:
continue
cluster = set()
to_visit = { cluster_idx }
while to_visit:
c = to_visit.pop()
cluster.add(c)
visited.add(c)
for i in adjacent[c]:
if i in visited:
continue
if 0 < cluster_dist and not _is_close_enough(cluster, i):
continue
if _has_same_source(cluster, i):
continue
to_visit.add(i)
clusters.append([id_segm[i][0] for i in cluster])
return clusters
@staticmethod
def distance(a, b):
return segment_iou(a, b)
@attrs
class BboxMatcher(_ShapeMatcher):
pass
@attrs
class PolygonMatcher(_ShapeMatcher):
pass
@attrs
class MaskMatcher(_ShapeMatcher):
pass
@attrs(kw_only=True)
class PointsMatcher(_ShapeMatcher):
sigma = attrib(converter=list, default=None)
instance_map = attrib(converter=dict)
def distance(self, a, b):
a_bbox = self.instance_map[id(a)][1]
b_bbox = self.instance_map[id(b)][1]
if bbox_iou(a_bbox, b_bbox) <= 0:
return 0
bbox = mean_bbox([a_bbox, b_bbox])
return OKS(a, b, sigma=self.sigma, bbox=bbox)
@attrs
class LineMatcher(_ShapeMatcher):
@staticmethod
def distance(a, b):
a_bbox = a.get_bbox()
b_bbox = b.get_bbox()
bbox = max_bbox([a_bbox, b_bbox])
area = bbox[2] * bbox[3]
if not area:
return 1
# compute inter-line area, normalize by common bbox
point_count = max(max(len(a.points) // 2, len(b.points) // 2), 5)
a, sa = smooth_line(a.points, point_count)
b, sb = smooth_line(b.points, point_count)
dists = np.linalg.norm(a - b, axis=1)
dists = (dists[:-1] + dists[1:]) * 0.5
s = np.sum(dists) * 0.5 * (sa + sb) / area
return abs(1 - s)
@attrs
class CaptionsMatcher(AnnotationMatcher):
def match_annotations(self, sources):
raise NotImplementedError()
@attrs(kw_only=True)
class AnnotationMerger:
_context = attrib(type=IntersectMerge, default=None)
def merge_clusters(self, clusters):
raise NotImplementedError()
@attrs(kw_only=True)
class LabelMerger(AnnotationMerger, LabelMatcher):
quorum = attrib(converter=int, default=0)
def merge_clusters(self, clusters):
assert len(clusters) <= 1
if len(clusters) == 0:
return []
votes = {} # label -> score
for label_ann in clusters[0]:
votes[label_ann.label] = 1 + votes.get(label_ann.label, 0)
merged = []
for label, count in votes.items():
if count < self.quorum:
sources = set(self.get_ann_source(id(a)) for a in clusters[0]
if label not in [l.label for l in a])
sources = [self._context._dataset_map[s][1] for s in sources]
self._context.add_item_error(FailedLabelVotingError,
sources, votes)
continue
merged.append(Label(label, attributes={
'score': count / len(self._context._dataset_map)
}))
return merged
@attrs(kw_only=True)
class _ShapeMerger(AnnotationMerger, _ShapeMatcher):
quorum = attrib(converter=int, default=0)
def merge_clusters(self, clusters):
merged = []
for cluster in clusters:
label, label_score = self.find_cluster_label(cluster)
shape, shape_score = self.merge_cluster_shape(cluster)
shape.z_order = max(cluster, key=lambda a: a.z_order).z_order
shape.label = label
shape.attributes['score'] = label_score * shape_score \
if label is not None else shape_score
merged.append(shape)
return merged
def find_cluster_label(self, cluster):
votes = {}
for s in cluster:
state = votes.setdefault(s.label, [0, 0])
state[0] += s.attributes.get('score', 1.0)
state[1] += 1
label, (score, count) = max(votes.items(), key=lambda e: e[1][0])
if count < self.quorum:
self._context.add_item_error(FailedLabelVotingError, votes)
score = score / count if count else None
return label, score
@staticmethod
def _merge_cluster_shape_mean_box_nearest(cluster):
mbbox = Bbox(*mean_bbox(cluster))
dist = (segment_iou(mbbox, s) for s in cluster)
nearest_pos, _ = max(enumerate(dist), key=lambda e: e[1])
return cluster[nearest_pos]
def merge_cluster_shape(self, cluster):
shape = self._merge_cluster_shape_mean_box_nearest(cluster)
shape_score = sum(max(0, self.distance(shape, s))
for s in cluster) / len(cluster)
return shape, shape_score
@attrs
class BboxMerger(_ShapeMerger, BboxMatcher):
pass
@attrs
class PolygonMerger(_ShapeMerger, PolygonMatcher):
pass
@attrs
class MaskMerger(_ShapeMerger, MaskMatcher):
pass
@attrs
class PointsMerger(_ShapeMerger, PointsMatcher):
pass
@attrs
class LineMerger(_ShapeMerger, LineMatcher):
pass
@attrs
class CaptionsMerger(AnnotationMerger, CaptionsMatcher):
pass
def match_segments(a_segms, b_segms, distance='iou', dist_thresh=1.0):
if distance == 'iou':
distance = segment_iou
else:
assert callable(distance)
a_segms.sort(key=lambda ann: 1 - ann.attributes.get('score', 1))
b_segms.sort(key=lambda ann: 1 - ann.attributes.get('score', 1))
# a_matches: indices of b_segms matched to a bboxes
# b_matches: indices of a_segms matched to b bboxes
a_matches = -np.ones(len(a_segms), dtype=int)
b_matches = -np.ones(len(b_segms), dtype=int)
distances = np.array([[distance(a, b) for b in b_segms] for a in a_segms])
# matches: boxes we succeeded to match completely
# mispred: boxes we succeeded to match, having label mismatch
matches = []
mispred = []
for a_idx, a_segm in enumerate(a_segms):
if len(b_segms) == 0:
break
matched_b = a_matches[a_idx]
max_dist = max(distances[a_idx, matched_b], dist_thresh)
for b_idx, b_segm in enumerate(b_segms):
if 0 <= b_matches[b_idx]: # assign a_segm with max conf
continue
d = distances[a_idx, b_idx]
if d < max_dist:
continue
max_dist = d
matched_b = b_idx
if matched_b < 0:
continue
a_matches[a_idx] = matched_b
b_matches[matched_b] = a_idx
b_segm = b_segms[matched_b]
if a_segm.label == b_segm.label:
matches.append( (a_segm, b_segm) )
else:
mispred.append( (a_segm, b_segm) )
from datumaro.components.extractor import AnnotationType
# *_unmatched: boxes of (*) we failed to match
a_unmatched = [a_segms[i] for i, m in enumerate(a_matches) if m < 0]
b_unmatched = [b_segms[i] for i, m in enumerate(b_matches) if m < 0]
return matches, mispred, a_unmatched, b_unmatched
def mean_std(dataset):
"""

@ -351,16 +351,7 @@ class Dataset(Extractor):
@classmethod
def from_extractors(cls, *sources):
# merge categories
# TODO: implement properly with merging and annotations remapping
categories = {}
for source in sources:
categories.update(source.categories())
for source in sources:
for cat_type, source_cat in source.categories().items():
if not categories[cat_type] == source_cat:
raise NotImplementedError(
"Merging different categories is not implemented yet")
categories = cls._merge_categories(s.categories() for s in sources)
dataset = Dataset(categories=categories)
# merge items
@ -457,7 +448,7 @@ class Dataset(Extractor):
@classmethod
def _merge_items(cls, existing_item, current_item, path=None):
return existing_item.wrap(path=path,
image=cls._merge_images(existing_item, current_item),
image=cls._merge_images(existing_item, current_item),
annotations=cls._merge_anno(
existing_item.annotations, current_item.annotations))
@ -489,18 +480,15 @@ class Dataset(Extractor):
@staticmethod
def _merge_anno(a, b):
from itertools import chain
merged = []
for item in chain(a, b):
found = False
for elem in merged:
if elem == item:
found = True
break
if not found:
merged.append(item)
return merged
# TODO: implement properly with merging and annotations remapping
from .operations import merge_annotations_equal
return merge_annotations_equal(a, b)
@staticmethod
def _merge_categories(sources):
# TODO: implement properly with merging and annotations remapping
from .operations import merge_categories
return merge_categories(sources)
class ProjectDataset(Dataset):
def __init__(self, project):
@ -535,14 +523,9 @@ class ProjectDataset(Dataset):
# merge categories
# TODO: implement properly with merging and annotations remapping
categories = {}
for source in self._sources.values():
categories.update(source.categories())
for source in self._sources.values():
for cat_type, source_cat in source.categories().items():
if not categories[cat_type] == source_cat:
raise NotImplementedError(
"Merging different categories is not implemented yet")
categories = self._merge_categories(s.categories()
for s in self._sources.values())
# overwrite with own categories
if own_source is not None and (not categories or len(own_source) != 0):
categories.update(own_source.categories())
self._categories = categories

@ -9,7 +9,7 @@ import_tf() # prevent TF loading and potential interpreter crash
import accuracy_checker.representation as ac
import datumaro.components.extractor as dm
from datumaro.util.annotation_tools import softmax
from datumaro.util.annotation_util import softmax
def import_predictions(predictions):
# Convert Accuracy checker predictions to Datumaro annotations

@ -12,7 +12,7 @@ from itertools import groupby
import pycocotools.mask as mask_utils
import datumaro.util.annotation_tools as anno_tools
import datumaro.util.annotation_util as anno_tools
import datumaro.util.mask_tools as mask_tools
from datumaro.components.converter import Converter
from datumaro.components.extractor import (_COORDINATE_ROUNDING_DIGITS,
@ -202,7 +202,7 @@ class _InstancesConverter(_TaskConverter):
anns = boxes + polygons + masks
leader = anno_tools.find_group_leader(anns)
bbox = anno_tools.compute_bbox(anns)
bbox = anno_tools.max_bbox(anns)
mask = None
polygons = [p.points for p in polygons]

@ -11,7 +11,7 @@ from xml.sax.saxutils import XMLGenerator
from datumaro.components.converter import Converter
from datumaro.components.extractor import DEFAULT_SUBSET_NAME, AnnotationType
from datumaro.util import cast, pairwise
from datumaro.util import cast, pairs
from .format import CvatPath
@ -246,7 +246,7 @@ class _SubsetWriter:
','.join((
"{:.2f}".format(x),
"{:.2f}".format(y)
)) for x, y in pairwise(shape.points))
)) for x, y in pairs(shape.points))
)),
]))

@ -16,7 +16,7 @@ from datumaro.components.extractor import (AnnotationType, DEFAULT_SUBSET_NAME,
)
from datumaro.components.converter import Converter
from datumaro.util.image import encode_image
from datumaro.util.annotation_tools import (compute_bbox,
from datumaro.util.annotation_util import (max_bbox,
find_group_leader, find_instances)
from datumaro.util.mask_tools import merge_masks
from datumaro.util.tf_util import import_tf as _import_tf
@ -111,7 +111,7 @@ class TfDetectionApiConverter(Converter):
anns = boxes + masks
leader = find_group_leader(anns)
bbox = compute_bbox(anns)
bbox = max_bbox(anns)
mask = None
if self._save_masks:

@ -17,7 +17,7 @@ from datumaro.components.extractor import (Transform, AnnotationType,
)
from datumaro.components.cli_plugin import CliPlugin
import datumaro.util.mask_tools as mask_tools
from datumaro.util.annotation_tools import find_group_leader, find_instances
from datumaro.util.annotation_util import find_group_leader, find_instances
class CropCoveredSegments(Transform, CliPlugin):

@ -62,7 +62,7 @@ def to_snake_case(s):
name.append(char)
return ''.join(name)
def pairwise(iterable):
def pairs(iterable):
a = iter(iterable)
return zip(a, a)

@ -1,34 +0,0 @@
# Copyright (C) 2020 Intel Corporation
#
# SPDX-License-Identifier: MIT
from itertools import groupby
import numpy as np
def find_instances(instance_anns):
instance_anns = sorted(instance_anns, key=lambda a: a.group)
ann_groups = []
for g_id, group in groupby(instance_anns, lambda a: a.group):
if not g_id:
ann_groups.extend(([a] for a in group))
else:
ann_groups.append(list(group))
return ann_groups
def find_group_leader(group):
return max(group, key=lambda x: x.get_area())
def compute_bbox(annotations):
boxes = [ann.get_bbox() for ann in annotations]
x0 = min((b[0] for b in boxes), default=0)
y0 = min((b[1] for b in boxes), default=0)
x1 = max((b[0] + b[2] for b in boxes), default=0)
y1 = max((b[1] + b[3] for b in boxes), default=0)
return [x0, y0, x1 - x0, y1 - y0]
def softmax(x):
return np.exp(x) / sum(np.exp(x))

@ -0,0 +1,213 @@
# Copyright (C) 2020 Intel Corporation
#
# SPDX-License-Identifier: MIT
from itertools import groupby
import numpy as np
from datumaro.components.extractor import _Shape, Mask, AnnotationType, RleMask
from datumaro.util.mask_tools import mask_to_rle
def find_instances(instance_anns):
instance_anns = sorted(instance_anns, key=lambda a: a.group)
ann_groups = []
for g_id, group in groupby(instance_anns, lambda a: a.group):
if not g_id:
ann_groups.extend(([a] for a in group))
else:
ann_groups.append(list(group))
return ann_groups
def find_group_leader(group):
return max(group, key=lambda x: x.get_area())
def _get_bbox(ann):
if isinstance(ann, (_Shape, Mask)):
return ann.get_bbox()
else:
return ann
def max_bbox(annotations):
boxes = [_get_bbox(ann) for ann in annotations]
x0 = min((b[0] for b in boxes), default=0)
y0 = min((b[1] for b in boxes), default=0)
x1 = max((b[0] + b[2] for b in boxes), default=0)
y1 = max((b[1] + b[3] for b in boxes), default=0)
return [x0, y0, x1 - x0, y1 - y0]
def mean_bbox(annotations):
le = len(annotations)
boxes = [_get_bbox(ann) for ann in annotations]
mlb = sum(b[0] for b in boxes) / le
mtb = sum(b[1] for b in boxes) / le
mrb = sum(b[0] + b[2] for b in boxes) / le
mbb = sum(b[1] + b[3] for b in boxes) / le
return [mlb, mtb, mrb - mlb, mbb - mtb]
def softmax(x):
return np.exp(x) / sum(np.exp(x))
def nms(segments, iou_thresh=0.5):
"""
Non-maximum suppression: keeps the highest-scoring segments and drops
the ones that overlap them above iou_thresh.
"""
indices = np.argsort([b.attributes['score'] for b in segments])
ious = np.array([[segment_iou(a, b) for b in segments] for a in segments])
predictions = []
while len(indices) != 0:
i = len(indices) - 1
pred_idx = indices[i]
to_remove = [i]
predictions.append(segments[pred_idx])
for i, box_idx in enumerate(indices[:i]):
if iou_thresh < ious[pred_idx, box_idx]:
to_remove.append(i)
indices = np.delete(indices, to_remove)
return predictions
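A quick illustration of the behavior above, assuming the segment_iou-based distance and that each segment carries a 'score' attribute; the boxes are made up for the example.

from datumaro.components.extractor import Bbox

boxes = [
    Bbox(0, 0, 10, 10, attributes={'score': 0.9}),
    Bbox(1, 1, 10, 10, attributes={'score': 0.5}),  # mostly covered by the first box
    Bbox(30, 30, 5, 5, attributes={'score': 0.7}),  # separate object
]
kept = nms(boxes, iou_thresh=0.5)
assert len(kept) == 2  # the low-scoring overlapping box is suppressed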
def bbox_iou(a, b):
"""
IoU computation for the simple case of two bounding boxes.
Returns -1 if there is no intersection, a value in (0; 1] otherwise.
"""
bbox_a = _get_bbox(a)
bbox_b = _get_bbox(b)
aX, aY, aW, aH = bbox_a
bX, bY, bW, bH = bbox_b
in_right = min(aX + aW, bX + bW)
in_left = max(aX, bX)
in_top = max(aY, bY)
in_bottom = min(aY + aH, bY + bH)
in_w = max(0, in_right - in_left)
in_h = max(0, in_bottom - in_top)
intersection = in_w * in_h
if not intersection:
return -1
a_area = aW * aH
b_area = bW * bH
union = a_area + b_area - intersection
return intersection / union
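A worked check of the convention above (the numbers are illustrative; plain [x, y, w, h] lists are accepted because _get_bbox passes non-annotation inputs through):

print(bbox_iou([0, 0, 2, 2], [1, 1, 2, 2]))  # 1 / (4 + 4 - 1) = 1/7 ≈ 0.1429
print(bbox_iou([0, 0, 1, 1], [5, 5, 1, 1]))  # -1: disjoint boxes are flagged, not 0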
def segment_iou(a, b):
"""
Generic IoU computation with masks, polygons, and boxes.
Returns -1 if no intersection, [0; 1] otherwise
"""
from pycocotools import mask as mask_utils
a_bbox = a.get_bbox()
b_bbox = b.get_bbox()
is_bbox = AnnotationType.bbox in [a.type, b.type]
if is_bbox:
a = [a_bbox]
b = [b_bbox]
else:
w = max(a_bbox[0] + a_bbox[2], b_bbox[0] + b_bbox[2])
h = max(a_bbox[1] + a_bbox[3], b_bbox[1] + b_bbox[3])
def _to_rle(ann):
if ann.type == AnnotationType.polygon:
return mask_utils.frPyObjects([ann.points], h, w)
elif isinstance(ann, RleMask):
return [ann._rle]
elif ann.type == AnnotationType.mask:
return mask_utils.frPyObjects([mask_to_rle(ann.image)], h, w)
else:
raise TypeError("Unexpected arguments: %s, %s" % (a, b))
a = _to_rle(a)
b = _to_rle(b)
return float(mask_utils.iou(a, b, [not is_bbox]))
def PDJ(a, b, eps=None, ratio=0.05, bbox=None):
"""
Percentage of Detected Joints metric.
Returns the fraction of point pairs that match within the distance tolerance.
"""
assert eps is not None or ratio is not None
p1 = np.array(a.points).reshape((-1, 2))
p2 = np.array(b.points).reshape((-1, 2))
if len(p1) != len(p2):
return 0
if not eps:
if bbox is None:
bbox = mean_bbox([a, b])
diag = (bbox[2] ** 2 + bbox[3] ** 2) ** 0.5
eps = ratio * diag
dists = np.linalg.norm(p1 - p2, axis=1)
return np.sum(dists < eps) / len(p1)
def OKS(a, b, sigma=None, bbox=None, scale=None):
"""
Object Keypoint Similarity metric.
https://cocodataset.org/#keypoints-eval
"""
p1 = np.array(a.points).reshape((-1, 2))
p2 = np.array(b.points).reshape((-1, 2))
if len(p1) != len(p2):
return 0
if not sigma:
sigma = 0.1
else:
assert len(sigma) == len(p1)
if not scale:
if bbox is None:
bbox = mean_bbox([a, b])
scale = bbox[2] * bbox[3]
dists = np.linalg.norm(p1 - p2, axis=1)
return np.sum(np.exp(-(dists ** 2) / (2 * scale * (2 * sigma) ** 2)))
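As written, the function sums exp(-d_i^2 / (2 * scale * (2 * sigma)^2)) over all keypoints, so identical point sets score the number of keypoints rather than 1 (the COCO reference additionally normalizes by the number of visible points). A tiny illustrative check, using Points from the extractor module:

from datumaro.components.extractor import Points

a = Points([1, 2, 3, 4])   # two keypoints
print(OKS(a, a, scale=1))  # 2.0: every distance is 0, so every term is exp(0) = 1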
def smooth_line(points, segments):
assert 2 <= len(points) // 2 and len(points) % 2 == 0
if len(points) // 2 == segments:
return points
points = list(points)
if len(points) == 2:
points.extend(points)
points = np.array(points).reshape((-1, 2))
lengths = np.linalg.norm(points[1:] - points[:-1], axis=1)
dists = [0]
for l in lengths:
dists.append(dists[-1] + l)
step = dists[-1] / segments
new_points = np.zeros((segments + 1, 2))
new_points[0] = points[0]
old_segment = 0
for new_segment in range(1, segments + 1):
pos = new_segment * step
while dists[old_segment + 1] < pos and old_segment + 2 < len(dists):
old_segment += 1
segment_start = dists[old_segment]
segment_len = lengths[old_segment]
prev_p = points[old_segment]
next_p = points[old_segment + 1]
r = (pos - segment_start) / segment_len
new_points[new_segment] = prev_p * (1 - r) + next_p * r
return new_points, step
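smooth_line resamples a polyline into the requested number of equal-length segments along the original path and returns the new points together with the segment length. An illustrative call on a straight 4-unit line (values hand-checked):

points, step = smooth_line([0, 0, 4, 0], segments=4)
# points: [[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]], step: 1.0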

@ -23,4 +23,12 @@ def default_if_none(conv):
elif not isinstance(value, attribute.type or conv):
value = conv(value)
setattr(inst, attribute.name, value)
return validator
def ensure_cls(c):
def converter(arg):
if isinstance(arg, c):
return arg
else:
return c(**arg)
return converter
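ensure_cls lets an attrs field accept either a ready instance or a plain dict of keyword arguments, which is how IntersectMerge's conf attribute accepts the dict form used in the tests below. A minimal sketch with a hypothetical Conf/Consumer pair:

from attr import attrib, attrs
from datumaro.util.attrs_util import ensure_cls

@attrs
class Conf:
    quorum = attrib(converter=int, default=0)

@attrs
class Consumer:
    conf = attrib(converter=ensure_cls(Conf), factory=Conf)

print(Consumer(conf={'quorum': 3}).conf)   # the dict is converted to Conf(quorum=3)
print(Consumer(conf=Conf(quorum=1)).conf)  # an existing instance passes through unchanged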

@ -65,7 +65,22 @@ def compare_categories(test, expected, actual):
actual[AnnotationType.points].items,
)
def compare_datasets(test, expected, actual):
def _compare_annotations(expected, actual, ignored_attrs=None):
if not ignored_attrs:
return expected == actual
a_attr = expected.attributes
b_attr = actual.attributes
expected.attributes = {k: v for k, v in a_attr.items() if k not in ignored_attrs}
actual.attributes = {k: v for k, v in b_attr.items() if k not in ignored_attrs}
r = expected == actual
expected.attributes = a_attr
actual.attributes = b_attr
return r
def compare_datasets(test, expected, actual, ignored_attrs=None):
compare_categories(test, expected.categories(), actual.categories())
test.assertEqual(sorted(expected.subsets()), sorted(actual.subsets()))
@ -82,8 +97,11 @@ def compare_datasets(test, expected, actual):
if x.type == ann_a.type]
test.assertFalse(len(ann_b_matches) == 0, 'ann id: %s' % ann_a.id)
ann_b = find(ann_b_matches, lambda x: x == ann_a)
test.assertEqual(ann_a, ann_b, 'ann %s, candidates %s' % (ann_a, ann_b_matches))
ann_b = find(ann_b_matches, lambda x:
_compare_annotations(x, ann_a, ignored_attrs=ignored_attrs))
if ann_b is None:
test.assertEqual(ann_a, ann_b,
'ann %s, candidates %s' % (ann_a, ann_b_matches))
item_b.annotations.remove(ann_b) # avoid repeats
def compare_datasets_strict(test, expected, actual):

@ -1,11 +1,14 @@
from unittest import TestCase
import numpy as np
from datumaro.components.extractor import (Extractor, DatasetItem, Label,
Mask, Bbox, Points, Caption)
from datumaro.components.extractor import (Bbox, Caption, DatasetItem,
Extractor, Label, Mask, Points, Polygon, PolyLine)
from datumaro.components.operations import (FailedAttrVotingError,
IntersectMerge, NoMatchingAnnError, NoMatchingItemError, WrongGroupError,
compute_ann_statistics, mean_std)
from datumaro.components.project import Dataset
from datumaro.components.operations import mean_std, compute_ann_statistics
from unittest import TestCase
from datumaro.util.test_utils import compare_datasets
class TestOperations(TestCase):
@ -131,4 +134,234 @@ class TestOperations(TestCase):
actual = compute_ann_statistics(dataset)
self.assertEqual(expected, actual)
class TestMultimerge(TestCase):
def test_can_match_items(self):
# items 1 and 3 are unique, item 2 is common and should be merged
source0 = Dataset.from_iterable([
DatasetItem(1, annotations=[ Label(0), ]),
DatasetItem(2, annotations=[ Label(0), ]),
], categories=['a', 'b'])
source1 = Dataset.from_iterable([
DatasetItem(2, annotations=[ Label(1), ]),
DatasetItem(3, annotations=[ Label(0), ]),
], categories=['a', 'b'])
source2 = Dataset.from_iterable([
DatasetItem(2, annotations=[ Label(0), Bbox(1, 2, 3, 4) ]),
], categories=['a', 'b'])
expected = Dataset.from_iterable([
DatasetItem(1, annotations=[
Label(0, attributes={'score': 1/3}),
]),
DatasetItem(2, annotations=[
Label(0, attributes={'score': 2/3}),
Label(1, attributes={'score': 1/3}),
Bbox(1, 2, 3, 4, attributes={'score': 1.0}),
]),
DatasetItem(3, annotations=[
Label(0, attributes={'score': 1/3}),
]),
], categories=['a', 'b'])
merger = IntersectMerge()
merged = merger([source0, source1, source2])
compare_datasets(self, expected, merged)
self.assertEqual(
[
NoMatchingItemError(item_id=('1', ''), sources={1, 2}),
NoMatchingItemError(item_id=('3', ''), sources={0, 2}),
],
sorted((e for e in merger.errors
if isinstance(e, NoMatchingItemError)),
key=lambda e: e.item_id)
)
self.assertEqual(
[
NoMatchingAnnError(item_id=('2', ''), sources={0, 1},
ann=source2.get('2').annotations[1]),
],
sorted((e for e in merger.errors
if isinstance(e, NoMatchingAnnError)),
key=lambda e: e.item_id)
)
def test_can_match_shapes(self):
source0 = Dataset.from_iterable([
DatasetItem(1, annotations=[
# unique
Bbox(1, 2, 3, 4, label=1),
# common
Mask(label=3, z_order=2, image=np.array([
[0, 0, 0, 0],
[0, 0, 0, 0],
[1, 1, 1, 0],
[1, 1, 1, 0],
])),
Polygon([1, 0, 3, 2, 1, 2]),
# an instance with keypoints
Bbox(4, 5, 2, 4, label=2, z_order=1, group=1),
Points([5, 6], label=0, group=1),
Points([6, 8], label=1, group=1),
PolyLine([1, 1, 2, 1, 3, 1]),
]),
], categories=['a', 'b', 'c'])
source1 = Dataset.from_iterable([
DatasetItem(1, annotations=[
# common
Mask(label=3, image=np.array([
[0, 0, 0, 0],
[0, 1, 1, 1],
[0, 1, 1, 1],
[0, 1, 1, 1],
])),
Polygon([0, 2, 2, 0, 2, 1]),
# an instance with keypoints
Bbox(4, 4, 2, 5, label=2, z_order=1, group=2),
Points([5.5, 6.5], label=0, group=2),
Points([6, 8], label=1, group=2),
PolyLine([1, 1.5, 2, 1.5]),
]),
], categories=['a', 'b', 'c'])
source2 = Dataset.from_iterable([
DatasetItem(1, annotations=[
# common
Mask(label=3, z_order=3, image=np.array([
[0, 0, 1, 1],
[0, 1, 1, 1],
[1, 1, 1, 1],
[1, 1, 1, 0],
])),
Polygon([3, 1, 2, 2, 0, 1]),
# an instance with keypoints, one is missing
Bbox(3, 6, 2, 3, label=2, z_order=4, group=3),
Points([4.5, 5.5], label=0, group=3),
PolyLine([1, 1.25, 3, 1, 4, 2]),
]),
], categories=['a', 'b', 'c'])
expected = Dataset.from_iterable([
DatasetItem(1, annotations=[
# unique
Bbox(1, 2, 3, 4, label=1),
# common
# nearest to mean bbox
Mask(label=3, z_order=3, image=np.array([
[0, 0, 0, 0],
[0, 1, 1, 1],
[0, 1, 1, 1],
[0, 1, 1, 1],
])),
Polygon([1, 0, 3, 2, 1, 2]),
# an instance with keypoints
Bbox(4, 5, 2, 4, label=2, z_order=4, group=1),
Points([5, 6], label=0, group=1),
Points([6, 8], label=1, group=1),
PolyLine([1, 1.25, 3, 1, 4, 2]),
]),
], categories=['a', 'b', 'c'])
merger = IntersectMerge(conf={'quorum': 1, 'pairwise_dist': 0.1})
merged = merger([source0, source1, source2])
compare_datasets(self, expected, merged, ignored_attrs={'score'})
self.assertEqual(
[
NoMatchingAnnError(item_id=('1', ''), sources={2},
ann=source0.get('1').annotations[5]),
NoMatchingAnnError(item_id=('1', ''), sources={1, 2},
ann=source0.get('1').annotations[0]),
],
sorted((e for e in merger.errors
if isinstance(e, NoMatchingAnnError)),
key=lambda e: len(e.sources))
)
def test_attributes(self):
source0 = Dataset.from_iterable([
DatasetItem(1, annotations=[
Label(2, attributes={
'unique': 1,
'common_under_quorum': 2,
'common_over_quorum': 3,
'ignored': 'q',
}),
]),
], categories=['a', 'b', 'c'])
source1 = Dataset.from_iterable([
DatasetItem(1, annotations=[
Label(2, attributes={
'common_under_quorum': 2,
'common_over_quorum': 3,
'ignored': 'q',
}),
]),
], categories=['a', 'b', 'c'])
source2 = Dataset.from_iterable([
DatasetItem(1, annotations=[
Label(2, attributes={
'common_over_quorum': 3,
'ignored': 'q',
}),
]),
], categories=['a', 'b', 'c'])
expected = Dataset.from_iterable([
DatasetItem(1, annotations=[
Label(2, attributes={ 'common_over_quorum': 3 }),
]),
], categories=['a', 'b', 'c'])
merger = IntersectMerge(conf={
'quorum': 3, 'ignored_attributes': {'ignored'}})
merged = merger([source0, source1, source2])
compare_datasets(self, expected, merged, ignored_attrs={'score'})
self.assertEqual(2, len([e for e in merger.errors
if isinstance(e, FailedAttrVotingError)])
)
def test_group_checks(self):
dataset = Dataset.from_iterable([
DatasetItem(1, annotations=[
Bbox(0, 0, 0, 0, label=0, group=1), # misses an optional label
Bbox(0, 0, 0, 0, label=1, group=1),
Bbox(0, 0, 0, 0, label=2, group=2), # misses a mandatory label - error
Bbox(0, 0, 0, 0, label=2, group=2),
Bbox(0, 0, 0, 0, label=4), # misses an optional label
Bbox(0, 0, 0, 0, label=5), # misses a mandatory label - error
Bbox(0, 0, 0, 0, label=0), # misses a mandatory label - error
Bbox(0, 0, 0, 0, label=3), # not listed - not checked
]),
], categories=['a', 'a_g1', 'a_g2_opt', 'b', 'c', 'c_g1_opt'])
merger = IntersectMerge(conf={'groups': [
['a', 'a_g1', 'a_g2_opt?'], ['c', 'c_g1_opt?']
]})
merger([dataset, dataset])
self.assertEqual(3, len([e for e in merger.errors
if isinstance(e, WrongGroupError)]), merger.errors
)
