[Datumaro] Diff with exact annotation matching (#1989)

* Add exact diff command

* Update changelog

* fix

* fix merge

* Add image matching, add test

* Add point matching test

* linter

* Update CHANGELOG.md

Co-authored-by: Nikita Manovich <nikita.manovich@intel.com>
main
Maxim Zhiltsov 6 years ago committed by GitHub
parent ae6ec40135
commit 98c06a342a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added password reset functionality (<https://github.com/opencv/cvat/pull/2058>)
- Ability to work with data on the fly (https://github.com/opencv/cvat/pull/2007)
- Annotation in process outline color wheel (<https://github.com/opencv/cvat/pull/2084>)
- [Datumaro] CLI command for dataset equality comparison (<https://github.com/opencv/cvat/pull/1989>)
### Changed
- UI models (like DEXTR) were redesigned to be more interactive (<https://github.com/opencv/cvat/pull/2054>)

@ -4,25 +4,26 @@
# SPDX-License-Identifier: MIT
import argparse
from enum import Enum
import json
import logging as log
import os
import os.path as osp
import shutil
from enum import Enum
from datumaro.components.project import Project, Environment, \
PROJECT_DEFAULT_CONFIG as DEFAULT_CONFIG
from datumaro.components.comparator import Comparator
from datumaro.components.cli_plugin import CliPlugin
from datumaro.components.dataset_filter import DatasetItemEncoder
from datumaro.components.extractor import AnnotationType
from datumaro.components.cli_plugin import CliPlugin
from datumaro.components.operations import \
compute_image_statistics, compute_ann_statistics
from datumaro.components.operations import (DistanceComparator,
ExactComparator, compute_ann_statistics, compute_image_statistics, mean_std)
from datumaro.components.project import \
PROJECT_DEFAULT_CONFIG as DEFAULT_CONFIG
from datumaro.components.project import Environment, Project
from ...util import (CliException, MultilineFormatter, add_subparser,
make_file_name)
from ...util.project import generate_next_file_name, load_project
from .diff import DiffVisualizer
from ...util import add_subparser, CliException, MultilineFormatter, \
make_file_name
from ...util.project import load_project, generate_next_file_name
def build_create_parser(parser_ctor=argparse.ArgumentParser):
@ -503,12 +504,12 @@ def merge_command(args):
def build_diff_parser(parser_ctor=argparse.ArgumentParser):
parser = parser_ctor(help="Compare projects",
description="""
Compares two projects.|n
Compares two projects, match annotations by distance.|n
|n
Examples:|n
- Compare two projects, consider bboxes matching if their IoU > 0.7,|n
- Compare two projects, match boxes if IoU > 0.7,|n
|s|s|s|sprint results to Tensorboard:
|s|sdiff path/to/other/project -o diff/ -f tensorboard --iou-thresh 0.7
|s|sdiff path/to/other/project -o diff/ -v tensorboard --iou-thresh 0.7
""",
formatter_class=MultilineFormatter)
@ -516,7 +517,7 @@ def build_diff_parser(parser_ctor=argparse.ArgumentParser):
help="Directory of the second project to be compared")
parser.add_argument('-o', '--output-dir', dest='dst_dir', default=None,
help="Directory to save comparison results (default: do not save)")
parser.add_argument('-f', '--format',
parser.add_argument('-v', '--visualizer',
default=DiffVisualizer.DEFAULT_FORMAT,
choices=[f.name for f in DiffVisualizer.Format],
help="Output format (default: %(default)s)")
@ -536,9 +537,7 @@ def diff_command(args):
first_project = load_project(args.project_dir)
second_project = load_project(args.other_project_dir)
comparator = Comparator(
iou_threshold=args.iou_thresh,
conf_threshold=args.conf_thresh)
comparator = DistanceComparator(iou_threshold=args.iou_thresh)
dst_dir = args.dst_dir
if dst_dir:
@ -556,7 +555,7 @@ def diff_command(args):
dst_dir_existed = osp.exists(dst_dir)
try:
visualizer = DiffVisualizer(save_dir=dst_dir, comparator=comparator,
output_format=args.format)
output_format=args.visualizer)
visualizer.save_dataset_diff(
first_project.make_dataset(),
second_project.make_dataset())
@ -567,6 +566,73 @@ def diff_command(args):
return 0
def build_ediff_parser(parser_ctor=argparse.ArgumentParser):
parser = parser_ctor(help="Compare projects for equality",
description="""
Compares two projects for equality.|n
|n
Examples:|n
- Compare two projects, exclude annotation group |n
|s|s|sand the 'is_crowd' attribute from comparison:|n
|s|sediff other/project/ -if group -ia is_crowd
""",
formatter_class=MultilineFormatter)
parser.add_argument('other_project_dir',
help="Directory of the second project to be compared")
parser.add_argument('-iia', '--ignore-item-attr', action='append',
help="Ignore item attribute (repeatable)")
parser.add_argument('-ia', '--ignore-attr', action='append',
help="Ignore annotation attribute (repeatable)")
parser.add_argument('-if', '--ignore-field',
action='append', default=['id', 'group'],
help="Ignore annotation field (repeatable, default: %(default)s)")
parser.add_argument('--match-images', action='store_true',
help='Match dataset items by images instead of ids')
parser.add_argument('--all', action='store_true',
help="Include matches in the output")
parser.add_argument('-p', '--project', dest='project_dir', default='.',
help="Directory of the first project to be compared (default: current dir)")
parser.set_defaults(command=ediff_command)
return parser
def ediff_command(args):
first_project = load_project(args.project_dir)
second_project = load_project(args.other_project_dir)
comparator = ExactComparator(
match_images=args.match_images,
ignored_fields=args.ignore_field,
ignored_attrs=args.ignore_attr,
ignored_item_attrs=args.ignore_item_attr)
matches, mismatches, a_extra, b_extra, errors = \
comparator.compare_datasets(
first_project.make_dataset(), second_project.make_dataset())
output = {
"mismatches": mismatches,
"a_extra_items": sorted(a_extra),
"b_extra_items": sorted(b_extra),
"errors": errors,
}
if args.all:
output["matches"] = matches
output_file = generate_next_file_name('diff', ext='.json')
with open(output_file, 'w') as f:
json.dump(output, f, indent=4, sort_keys=True)
print("Found:")
print("The first project has %s unmatched items" % len(a_extra))
print("The second project has %s unmatched items" % len(b_extra))
print("%s item conflicts" % len(errors))
print("%s matching annotations" % len(matches))
print("%s mismatching annotations" % len(mismatches))
log.info("Output has been saved to '%s'" % output_file)
return 0
def build_transform_parser(parser_ctor=argparse.ArgumentParser):
builtins = sorted(Environment().transforms.items)
@ -753,6 +819,7 @@ def build_parser(parser_ctor=argparse.ArgumentParser):
add_subparser(subparsers, 'extract', build_extract_parser)
add_subparser(subparsers, 'merge', build_merge_parser)
add_subparser(subparsers, 'diff', build_diff_parser)
add_subparser(subparsers, 'ediff', build_ediff_parser)
add_subparser(subparsers, 'transform', build_transform_parser)
add_subparser(subparsers, 'info', build_info_parser)
add_subparser(subparsers, 'stats', build_stats_parser)

@ -217,7 +217,7 @@ class DiffVisualizer:
_, mispred, a_unmatched, b_unmatched = diff
if 0 < len(a_unmatched) + len(b_unmatched) + len(mispred):
img_a = item_a.image.copy()
img_a = item_a.image.data.copy()
img_b = img_a.copy()
for a_bbox, b_bbox in mispred:
self.draw_bbox(img_a, a_bbox, (0, 255, 0))

@ -1,113 +0,0 @@
# Copyright (C) 2019 Intel Corporation
#
# SPDX-License-Identifier: MIT
from itertools import zip_longest
import numpy as np
from datumaro.components.extractor import AnnotationType, LabelCategories
class Comparator:
def __init__(self,
iou_threshold=0.5, conf_threshold=0.9):
self.iou_threshold = iou_threshold
self.conf_threshold = conf_threshold
@staticmethod
def iou(box_a, box_b):
return box_a.iou(box_b)
# pylint: disable=no-self-use
def compare_dataset_labels(self, extractor_a, extractor_b):
a_label_cat = extractor_a.categories().get(AnnotationType.label)
b_label_cat = extractor_b.categories().get(AnnotationType.label)
if not a_label_cat and not b_label_cat:
return None
if not a_label_cat:
a_label_cat = LabelCategories()
if not b_label_cat:
b_label_cat = LabelCategories()
mismatches = []
for a_label, b_label in zip_longest(a_label_cat.items, b_label_cat.items):
if a_label != b_label:
mismatches.append((a_label, b_label))
return mismatches
# pylint: enable=no-self-use
def compare_item_labels(self, item_a, item_b):
conf_threshold = self.conf_threshold
a_labels = set([ann.label for ann in item_a.annotations \
if ann.type is AnnotationType.label and \
conf_threshold < ann.attributes.get('score', 1)])
b_labels = set([ann.label for ann in item_b.annotations \
if ann.type is AnnotationType.label and \
conf_threshold < ann.attributes.get('score', 1)])
a_unmatched = a_labels - b_labels
b_unmatched = b_labels - a_labels
matches = a_labels & b_labels
return matches, a_unmatched, b_unmatched
def compare_item_bboxes(self, item_a, item_b):
iou_threshold = self.iou_threshold
conf_threshold = self.conf_threshold
a_boxes = [ann for ann in item_a.annotations \
if ann.type is AnnotationType.bbox and \
conf_threshold < ann.attributes.get('score', 1)]
b_boxes = [ann for ann in item_b.annotations \
if ann.type is AnnotationType.bbox and \
conf_threshold < ann.attributes.get('score', 1)]
a_boxes.sort(key=lambda ann: 1 - ann.attributes.get('score', 1))
b_boxes.sort(key=lambda ann: 1 - ann.attributes.get('score', 1))
# a_matches: indices of b_boxes matched to a bboxes
# b_matches: indices of a_boxes matched to b bboxes
a_matches = -np.ones(len(a_boxes), dtype=int)
b_matches = -np.ones(len(b_boxes), dtype=int)
iou_matrix = np.array([
[self.iou(a, b) for b in b_boxes] for a in a_boxes
])
# matches: boxes we succeeded to match completely
# mispred: boxes we succeeded to match, having label mismatch
matches = []
mispred = []
for a_idx, a_bbox in enumerate(a_boxes):
if len(b_boxes) == 0:
break
matched_b = a_matches[a_idx]
iou_max = max(iou_matrix[a_idx, matched_b], iou_threshold)
for b_idx, b_bbox in enumerate(b_boxes):
if 0 <= b_matches[b_idx]: # assign a_bbox with max conf
continue
iou = iou_matrix[a_idx, b_idx]
if iou < iou_max:
continue
iou_max = iou
matched_b = b_idx
if matched_b < 0:
continue
a_matches[a_idx] = matched_b
b_matches[matched_b] = a_idx
b_bbox = b_boxes[matched_b]
if a_bbox.label == b_bbox.label:
matches.append( (a_bbox, b_bbox) )
else:
mispred.append( (a_bbox, b_bbox) )
# *_umatched: boxes of (*) we failed to match
a_unmatched = [a_boxes[i] for i, m in enumerate(a_matches) if m < 0]
b_unmatched = [b_boxes[i] for i, m in enumerate(b_matches) if m < 0]
return matches, mispred, a_unmatched, b_unmatched

@ -46,7 +46,7 @@ class Annotation:
@attrs
class Categories:
attributes = attrib(factory=set, validator=default_if_none(set),
kw_only=True)
kw_only=True, eq=False)
@attrs
class LabelCategories(Categories):
@ -137,6 +137,8 @@ class MaskCategories(Categories):
def __eq__(self, other):
if not super().__eq__(other):
return False
if not isinstance(other, __class__):
return False
for label_id, my_color in self.colormap.items():
other_color = other.colormap.get(label_id)
if not np.array_equal(my_color, other_color):
@ -179,6 +181,8 @@ class Mask(Annotation):
def __eq__(self, other):
if not super().__eq__(other):
return False
if not isinstance(other, __class__):
return False
return \
(self.label == other.label) and \
(self.z_order == other.z_order) and \

@ -5,18 +5,20 @@
from collections import OrderedDict
from copy import deepcopy
import hashlib
import logging as log
import attr
import cv2
import numpy as np
from attr import attrib, attrs
from unittest import TestCase
from datumaro.components.cli_plugin import CliPlugin
from datumaro.components.extractor import AnnotationType, Bbox, Label
from datumaro.components.project import Dataset
from datumaro.util import find
from datumaro.util.attrs_util import ensure_cls
from datumaro.util import find, filter_dict
from datumaro.util.attrs_util import ensure_cls, default_if_none
from datumaro.util.annotation_util import (segment_iou, bbox_iou,
mean_bbox, OKS, find_instances, max_bbox, smooth_line)
@ -585,7 +587,7 @@ class MaskMatcher(_ShapeMatcher):
@attrs(kw_only=True)
class PointsMatcher(_ShapeMatcher):
sigma = attrib(converter=list, default=None)
sigma = attrib(type=list, default=None)
instance_map = attrib(converter=dict)
def distance(self, a, b):
@ -1003,3 +1005,341 @@ def compute_ann_statistics(dataset):
} for c, (bin_min, bin_max) in zip(hist, zip(bins[:-1], bins[1:]))]
return stats
@attrs
class DistanceComparator:
iou_threshold = attrib(converter=float, default=0.5)
@staticmethod
def match_datasets(a, b):
a_items = set((item.id, item.subset) for item in a)
b_items = set((item.id, item.subset) for item in b)
matches = a_items & b_items
a_unmatched = a_items - b_items
b_unmatched = b_items - a_items
return matches, a_unmatched, b_unmatched
@staticmethod
def match_classes(a, b):
a_label_cat = a.categories().get(AnnotationType.label, LabelCategories())
b_label_cat = b.categories().get(AnnotationType.label, LabelCategories())
a_labels = set(c.name for c in a_label_cat)
b_labels = set(c.name for c in b_label_cat)
matches = a_labels & b_labels
a_unmatched = a_labels - b_labels
b_unmatched = b_labels - a_labels
return matches, a_unmatched, b_unmatched
def match_annotations(self, item_a, item_b):
return { t: self._match_ann_type(t, item_a, item_b) }
def _match_ann_type(self, t, *args):
# pylint: disable=no-value-for-parameter
if t == AnnotationType.label:
return self.match_labels(*args)
elif t == AnnotationType.bbox:
return self.match_boxes(*args)
elif t == AnnotationType.polygon:
return self.match_polygons(*args)
elif t == AnnotationType.mask:
return self.match_masks(*args)
elif t == AnnotationType.points:
return self.match_points(*args)
elif t == AnnotationType.polyline:
return self.match_lines(*args)
# pylint: enable=no-value-for-parameter
else:
raise NotImplementedError("Unexpected annotation type %s" % t)
@staticmethod
def _get_ann_type(t, item):
return get_ann_type(item.annotations, t)
def match_labels(self, item_a, item_b):
a_labels = set(a.label for a in
self._get_ann_type(AnnotationType.label, item_a))
b_labels = set(a.label for a in
self._get_ann_type(AnnotationType.label, item_b))
matches = a_labels & b_labels
a_unmatched = a_labels - b_labels
b_unmatched = b_labels - a_labels
return matches, a_unmatched, b_unmatched
def _match_segments(self, t, item_a, item_b):
a_boxes = self._get_ann_type(t, item_a)
b_boxes = self._get_ann_type(t, item_b)
return match_segments(a_boxes, b_boxes, dist_thresh=self.iou_threshold)
def match_polygons(self, item_a, item_b):
return self._match_segments(AnnotationType.polygon, item_a, item_b)
def match_masks(self, item_a, item_b):
return self._match_segments(AnnotationType.mask, item_a, item_b)
def match_boxes(self, item_a, item_b):
return self._match_segments(AnnotationType.bbox, item_a, item_b)
def match_points(self, item_a, item_b):
a_points = self._get_ann_type(AnnotationType.points, item_a)
b_points = self._get_ann_type(AnnotationType.points, item_b)
instance_map = {}
for s in [item_a.annotations, item_b.annotations]:
s_instances = find_instances(s)
for inst in s_instances:
inst_bbox = max_bbox(inst)
for ann in inst:
instance_map[id(ann)] = [inst, inst_bbox]
matcher = PointsMatcher(instance_map=instance_map)
return match_segments(a_points, b_points,
dist_thresh=self.iou_threshold, distance=matcher.distance)
def match_lines(self, item_a, item_b):
a_lines = self._get_ann_type(AnnotationType.polyline, item_a)
b_lines = self._get_ann_type(AnnotationType.polyline, item_b)
matcher = LineMatcher()
return match_segments(a_lines, b_lines,
dist_thresh=self.iou_threshold, distance=matcher.distance)
def match_items_by_id(a, b):
a_items = set((item.id, item.subset) for item in a)
b_items = set((item.id, item.subset) for item in b)
matches = a_items & b_items
matches = [([m], [m]) for m in matches]
a_unmatched = a_items - b_items
b_unmatched = b_items - a_items
return matches, a_unmatched, b_unmatched
def match_items_by_image_hash(a, b):
def _hash(item):
if not item.image.has_data:
log.warning("Image (%s, %s) has no image "
"data, counted as unmatched", item.id, item.subset)
return None
return hashlib.md5(item.image.data.tobytes()).hexdigest()
def _build_hashmap(source):
d = {}
for item in source:
h = _hash(item)
if h is None:
h = str(id(item)) # anything unique
d.setdefault(h, []).append((item.id, item.subset))
return d
a_hash = _build_hashmap(a)
b_hash = _build_hashmap(b)
a_items = set(a_hash)
b_items = set(b_hash)
matches = a_items & b_items
a_unmatched = a_items - b_items
b_unmatched = b_items - a_items
matches = [(a_hash[h], b_hash[h]) for h in matches]
a_unmatched = set(i for h in a_unmatched for i in a_hash[h])
b_unmatched = set(i for h in b_unmatched for i in b_hash[h])
return matches, a_unmatched, b_unmatched
@attrs
class ExactComparator:
match_images = attrib(kw_only=True, type=bool, default=False)
ignored_fields = attrib(kw_only=True,
factory=set, validator=default_if_none(set))
ignored_attrs = attrib(kw_only=True,
factory=set, validator=default_if_none(set))
ignored_item_attrs = attrib(kw_only=True,
factory=set, validator=default_if_none(set))
_test = attrib(init=False, type=TestCase)
errors = attrib(init=False, type=list)
def __attrs_post_init__(self):
self._test = TestCase()
self._test.maxDiff = None
def _match_items(self, a, b):
if self.match_images:
return match_items_by_image_hash(a, b)
else:
return match_items_by_id(a, b)
def _compare_categories(self, a, b):
test = self._test
errors = self.errors
try:
test.assertEqual(
sorted(a, key=lambda t: t.value),
sorted(b, key=lambda t: t.value)
)
except AssertionError as e:
errors.append({'type': 'categories', 'message': str(e)})
if AnnotationType.label in a:
try:
test.assertEqual(
a[AnnotationType.label].items,
b[AnnotationType.label].items,
)
except AssertionError as e:
errors.append({'type': 'labels', 'message': str(e)})
if AnnotationType.mask in a:
try:
test.assertEqual(
a[AnnotationType.mask].colormap,
b[AnnotationType.mask].colormap,
)
except AssertionError as e:
errors.append({'type': 'colormap', 'message': str(e)})
if AnnotationType.points in a:
try:
test.assertEqual(
a[AnnotationType.points].items,
b[AnnotationType.points].items,
)
except AssertionError as e:
errors.append({'type': 'points', 'message': str(e)})
def _compare_annotations(self, a, b):
ignored_fields = self.ignored_fields
ignored_attrs = self.ignored_attrs
a_fields = { k: None for k in vars(a) if k in ignored_fields }
b_fields = { k: None for k in vars(b) if k in ignored_fields }
if 'attributes' not in ignored_fields:
a_fields['attributes'] = filter_dict(a.attributes, ignored_attrs)
b_fields['attributes'] = filter_dict(b.attributes, ignored_attrs)
result = a.wrap(**a_fields) == b.wrap(**b_fields)
return result
def _compare_items(self, item_a, item_b):
test = self._test
a_id = (item_a.id, item_a.subset)
b_id = (item_b.id, item_b.subset)
matched = []
unmatched = []
errors = []
try:
test.assertEqual(
filter_dict(item_a.attributes, self.ignored_item_attrs),
filter_dict(item_b.attributes, self.ignored_item_attrs)
)
except AssertionError as e:
errors.append({'type': 'item_attr',
'a_item': a_id, 'b_item': b_id, 'message': str(e)})
b_annotations = item_b.annotations[:]
for ann_a in item_a.annotations:
ann_b_candidates = [x for x in item_b.annotations
if x.type == ann_a.type]
ann_b = find(enumerate(self._compare_annotations(ann_a, x)
for x in ann_b_candidates), lambda x: x[1])
if ann_b is None:
unmatched.append({
'item': a_id, 'source': 'a', 'ann': str(ann_a),
})
continue
else:
ann_b = ann_b_candidates[ann_b[0]]
b_annotations.remove(ann_b) # avoid repeats
matched.append({'a_item': a_id, 'b_item': b_id,
'a': str(ann_a), 'b': str(ann_b)})
for ann_b in b_annotations:
unmatched.append({'item': b_id, 'source': 'b', 'ann': str(ann_b)})
return matched, unmatched, errors
def compare_datasets(self, a, b):
self.errors = []
errors = self.errors
self._compare_categories(a.categories(), b.categories())
matched = []
unmatched = []
matches, a_unmatched, b_unmatched = self._match_items(a, b)
if a.categories().get(AnnotationType.label) != \
b.categories().get(AnnotationType.label):
return matched, unmatched, a_unmatched, b_unmatched, errors
_dist = lambda s: len(s[1]) + len(s[2])
for a_ids, b_ids in matches:
# build distance matrix
match_status = {} # (a_id, b_id): [matched, unmatched, errors]
a_matches = { a_id: None for a_id in a_ids }
b_matches = { b_id: None for b_id in b_ids }
for a_id in a_ids:
item_a = a.get(*a_id)
candidates = {}
for b_id in b_ids:
item_b = b.get(*b_id)
i_m, i_um, i_err = self._compare_items(item_a, item_b)
candidates[b_id] = [i_m, i_um, i_err]
if len(i_um) == 0:
a_matches[a_id] = b_id
b_matches[b_id] = a_id
matched.extend(i_m)
errors.extend(i_err)
break
match_status[a_id] = candidates
# assign
for a_id in a_ids:
if len(b_ids) == 0:
break
# find the closest, ignore already assigned
matched_b = a_matches[a_id]
if matched_b is not None:
continue
min_dist = -1
for b_id in b_ids:
if b_matches[b_id] is not None:
continue
d = _dist(match_status[a_id][b_id])
if d < min_dist and 0 <= min_dist:
continue
min_dist = d
matched_b = b_id
if matched_b is None:
continue
a_matches[a_id] = matched_b
b_matches[matched_b] = a_id
m = match_status[a_id][matched_b]
matched.extend(m[0])
unmatched.extend(m[1])
errors.extend(m[2])
a_unmatched |= set(a_id for a_id, m in a_matches.items() if not m)
b_unmatched |= set(b_id for b_id, m in b_matches.items() if not m)
return matched, unmatched, a_unmatched, b_unmatched, errors

@ -88,3 +88,6 @@ def str_to_bool(s):
return False
else:
raise ValueError("Can't convert value '%s' to bool" % s)
def filter_dict(d, exclude_keys):
return { k: v for k, v in d.items() if k not in exclude_keys }

@ -100,8 +100,7 @@ def compare_datasets(test, expected, actual, ignored_attrs=None):
ann_b = find(ann_b_matches, lambda x:
_compare_annotations(x, ann_a, ignored_attrs=ignored_attrs))
if ann_b is None:
test.assertEqual(ann_a, ann_b,
'ann %s, candidates %s' % (ann_a, ann_b_matches))
test.fail('ann %s, candidates %s' % (ann_a, ann_b_matches))
item_b.annotations.remove(ann_b) # avoid repeats
def compare_datasets_strict(test, expected, actual):

@ -1,123 +1,96 @@
from unittest import TestCase
import numpy as np
from datumaro.components.extractor import (DatasetItem, Label, Bbox,
Caption, Mask, Points)
from datumaro.components.project import Dataset
from datumaro.components.operations import DistanceComparator, ExactComparator
from datumaro.components.extractor import DatasetItem, Label, Bbox
from datumaro.components.comparator import Comparator
from unittest import TestCase
class DiffTest(TestCase):
class DistanceComparatorTest(TestCase):
def test_no_bbox_diff_with_same_item(self):
detections = 3
anns = [
Bbox(i * 10, 10, 10, 10, label=i,
attributes={'score': (1.0 + i) / detections}) \
for i in range(detections)
Bbox(i * 10, 10, 10, 10, label=i)
for i in range(detections)
]
item = DatasetItem(id=0, annotations=anns)
iou_thresh = 0.5
conf_thresh = 0.5
comp = Comparator(
iou_threshold=iou_thresh, conf_threshold=conf_thresh)
comp = DistanceComparator(iou_threshold=iou_thresh)
result = comp.compare_item_bboxes(item, item)
result = comp.match_boxes(item, item)
matches, mispred, a_greater, b_greater = result
self.assertEqual(0, len(mispred))
self.assertEqual(0, len(a_greater))
self.assertEqual(0, len(b_greater))
self.assertEqual(len([it for it in item.annotations \
if conf_thresh < it.attributes['score']]),
len(matches))
self.assertEqual(len(item.annotations), len(matches))
for a_bbox, b_bbox in matches:
self.assertLess(iou_thresh, a_bbox.iou(b_bbox))
self.assertEqual(a_bbox.label, b_bbox.label)
self.assertLess(conf_thresh, a_bbox.attributes['score'])
self.assertLess(conf_thresh, b_bbox.attributes['score'])
def test_can_find_bbox_with_wrong_label(self):
detections = 3
class_count = 2
item1 = DatasetItem(id=1, annotations=[
Bbox(i * 10, 10, 10, 10, label=i,
attributes={'score': (1.0 + i) / detections}) \
for i in range(detections)
Bbox(i * 10, 10, 10, 10, label=i)
for i in range(detections)
])
item2 = DatasetItem(id=2, annotations=[
Bbox(i * 10, 10, 10, 10, label=(i + 1) % class_count,
attributes={'score': (1.0 + i) / detections}) \
for i in range(detections)
Bbox(i * 10, 10, 10, 10, label=(i + 1) % class_count)
for i in range(detections)
])
iou_thresh = 0.5
conf_thresh = 0.5
comp = Comparator(
iou_threshold=iou_thresh, conf_threshold=conf_thresh)
comp = DistanceComparator(iou_threshold=iou_thresh)
result = comp.compare_item_bboxes(item1, item2)
result = comp.match_boxes(item1, item2)
matches, mispred, a_greater, b_greater = result
self.assertEqual(len([it for it in item1.annotations \
if conf_thresh < it.attributes['score']]),
len(mispred))
self.assertEqual(len(item1.annotations), len(mispred))
self.assertEqual(0, len(a_greater))
self.assertEqual(0, len(b_greater))
self.assertEqual(0, len(matches))
for a_bbox, b_bbox in mispred:
self.assertLess(iou_thresh, a_bbox.iou(b_bbox))
self.assertEqual((a_bbox.label + 1) % class_count, b_bbox.label)
self.assertLess(conf_thresh, a_bbox.attributes['score'])
self.assertLess(conf_thresh, b_bbox.attributes['score'])
def test_can_find_missing_boxes(self):
detections = 3
class_count = 2
item1 = DatasetItem(id=1, annotations=[
Bbox(i * 10, 10, 10, 10, label=i,
attributes={'score': (1.0 + i) / detections}) \
for i in range(detections) if i % 2 == 0
Bbox(i * 10, 10, 10, 10, label=i)
for i in range(detections) if i % 2 == 0
])
item2 = DatasetItem(id=2, annotations=[
Bbox(i * 10, 10, 10, 10, label=(i + 1) % class_count,
attributes={'score': (1.0 + i) / detections}) \
for i in range(detections) if i % 2 == 1
Bbox(i * 10, 10, 10, 10, label=(i + 1) % class_count)
for i in range(detections) if i % 2 == 1
])
iou_thresh = 0.5
conf_thresh = 0.5
comp = Comparator(
iou_threshold=iou_thresh, conf_threshold=conf_thresh)
comp = DistanceComparator(iou_threshold=iou_thresh)
result = comp.compare_item_bboxes(item1, item2)
result = comp.match_boxes(item1, item2)
matches, mispred, a_greater, b_greater = result
self.assertEqual(0, len(mispred))
self.assertEqual(len([it for it in item1.annotations \
if conf_thresh < it.attributes['score']]),
len(a_greater))
self.assertEqual(len([it for it in item2.annotations \
if conf_thresh < it.attributes['score']]),
len(b_greater))
self.assertEqual(len(item1.annotations), len(a_greater))
self.assertEqual(len(item2.annotations), len(b_greater))
self.assertEqual(0, len(matches))
def test_no_label_diff_with_same_item(self):
detections = 3
anns = [
Label(i, attributes={'score': (1.0 + i) / detections}) \
for i in range(detections)
]
anns = [ Label(i) for i in range(detections) ]
item = DatasetItem(id=1, annotations=anns)
conf_thresh = 0.5
comp = Comparator(conf_threshold=conf_thresh)
result = comp.compare_item_labels(item, item)
result = DistanceComparator().match_labels(item, item)
matches, a_greater, b_greater = result
self.assertEqual(0, len(a_greater))
self.assertEqual(0, len(b_greater))
self.assertEqual(len([it for it in item.annotations \
if conf_thresh < it.attributes['score']]),
len(matches))
self.assertEqual(len(item.annotations), len(matches))
def test_can_find_wrong_label(self):
item1 = DatasetItem(id=1, annotations=[
@ -131,12 +104,148 @@ class DiffTest(TestCase):
Label(4),
])
conf_thresh = 0.5
comp = Comparator(conf_threshold=conf_thresh)
result = comp.compare_item_labels(item1, item2)
result = DistanceComparator().match_labels(item1, item2)
matches, a_greater, b_greater = result
self.assertEqual(2, len(a_greater))
self.assertEqual(2, len(b_greater))
self.assertEqual(1, len(matches))
self.assertEqual(1, len(matches))
def test_can_match_points(self):
item1 = DatasetItem(id=1, annotations=[
Points([1, 2, 2, 0, 1, 1], label=0),
Points([3, 5, 5, 7, 5, 3], label=0),
])
item2 = DatasetItem(id=2, annotations=[
Points([1.5, 2, 2, 0.5, 1, 1.5], label=0),
Points([5, 7, 7, 7, 7, 5], label=0),
])
result = DistanceComparator().match_points(item1, item2)
matches, mismatches, a_greater, b_greater = result
self.assertEqual(1, len(a_greater))
self.assertEqual(1, len(b_greater))
self.assertEqual(1, len(matches))
self.assertEqual(0, len(mismatches))
class ExactComparatorTest(TestCase):
def test_class_comparison(self):
a = Dataset.from_iterable([], categories=['a', 'b', 'c'])
b = Dataset.from_iterable([], categories=['b', 'c'])
comp = ExactComparator()
_, _, _, _, errors = comp.compare_datasets(a, b)
self.assertEqual(1, len(errors), errors)
def test_item_comparison(self):
a = Dataset.from_iterable([
DatasetItem(id=1, subset='train'),
DatasetItem(id=2, subset='test', attributes={'x': 1}),
], categories=['a', 'b', 'c'])
b = Dataset.from_iterable([
DatasetItem(id=2, subset='test'),
DatasetItem(id=3),
], categories=['a', 'b', 'c'])
comp = ExactComparator()
_, _, a_extra_items, b_extra_items, errors = comp.compare_datasets(a, b)
self.assertEqual({('1', 'train')}, a_extra_items)
self.assertEqual({('3', '')}, b_extra_items)
self.assertEqual(1, len(errors), errors)
def test_annotation_comparison(self):
a = Dataset.from_iterable([
DatasetItem(id=1, annotations=[
Caption('hello'), # unmatched
Caption('world', group=5),
Label(2, attributes={ 'x': 1, 'y': '2', }),
Bbox(1, 2, 3, 4, label=4, z_order=1, attributes={
'score': 1.0,
}),
Bbox(5, 6, 7, 8, group=5),
Points([1, 2, 2, 0, 1, 1], label=0, z_order=4),
Mask(label=3, z_order=2, image=np.ones((2, 3))),
]),
], categories=['a', 'b', 'c', 'd'])
b = Dataset.from_iterable([
DatasetItem(id=1, annotations=[
Caption('world', group=5),
Label(2, attributes={ 'x': 1, 'y': '2', }),
Bbox(1, 2, 3, 4, label=4, z_order=1, attributes={
'score': 1.0,
}),
Bbox(5, 6, 7, 8, group=5),
Bbox(5, 6, 7, 8, group=5), # unmatched
Points([1, 2, 2, 0, 1, 1], label=0, z_order=4),
Mask(label=3, z_order=2, image=np.ones((2, 3))),
]),
], categories=['a', 'b', 'c', 'd'])
comp = ExactComparator()
matched, unmatched, _, _, errors = comp.compare_datasets(a, b)
self.assertEqual(6, len(matched), matched)
self.assertEqual(2, len(unmatched), unmatched)
self.assertEqual(0, len(errors), errors)
def test_image_comparison(self):
a = Dataset.from_iterable([
DatasetItem(id=11, image=np.ones((5, 4, 3)), annotations=[
Bbox(5, 6, 7, 8),
]),
DatasetItem(id=12, image=np.ones((5, 4, 3)), annotations=[
Bbox(1, 2, 3, 4),
Bbox(5, 6, 7, 8),
]),
DatasetItem(id=13, image=np.ones((5, 4, 3)), annotations=[
Bbox(9, 10, 11, 12), # mismatch
]),
DatasetItem(id=14, image=np.zeros((5, 4, 3)), annotations=[
Bbox(1, 2, 3, 4),
Bbox(5, 6, 7, 8),
], attributes={ 'a': 1 }),
DatasetItem(id=15, image=np.zeros((5, 5, 3)), annotations=[
Bbox(1, 2, 3, 4),
Bbox(5, 6, 7, 8),
]),
], categories=['a', 'b', 'c', 'd'])
b = Dataset.from_iterable([
DatasetItem(id=21, image=np.ones((5, 4, 3)), annotations=[
Bbox(5, 6, 7, 8),
]),
DatasetItem(id=22, image=np.ones((5, 4, 3)), annotations=[
Bbox(1, 2, 3, 4),
Bbox(5, 6, 7, 8),
]),
DatasetItem(id=23, image=np.ones((5, 4, 3)), annotations=[
Bbox(10, 10, 11, 12), # mismatch
]),
DatasetItem(id=24, image=np.zeros((5, 4, 3)), annotations=[
Bbox(6, 6, 7, 8), # 1 ann missing, mismatch
], attributes={ 'a': 2 }),
DatasetItem(id=25, image=np.zeros((4, 4, 3)), annotations=[
Bbox(6, 6, 7, 8),
]),
], categories=['a', 'b', 'c', 'd'])
comp = ExactComparator(match_images=True)
matched_ann, unmatched_ann, a_unmatched, b_unmatched, errors = \
comp.compare_datasets(a, b)
self.assertEqual(3, len(matched_ann), matched_ann)
self.assertEqual(5, len(unmatched_ann), unmatched_ann)
self.assertEqual(1, len(a_unmatched), a_unmatched)
self.assertEqual(1, len(b_unmatched), b_unmatched)
self.assertEqual(1, len(errors), errors)
Loading…
Cancel
Save