You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

251 lines
8.9 KiB
Python

import numpy as np
from datumaro.components.extractor import (DatasetItem, Label, Bbox,
Caption, Mask, Points)
from datumaro.components.project import Dataset
from datumaro.components.operations import DistanceComparator, ExactComparator
from unittest import TestCase
class DistanceComparatorTest(TestCase):
def test_no_bbox_diff_with_same_item(self):
detections = 3
anns = [
Bbox(i * 10, 10, 10, 10, label=i)
for i in range(detections)
]
item = DatasetItem(id=0, annotations=anns)
iou_thresh = 0.5
comp = DistanceComparator(iou_threshold=iou_thresh)
result = comp.match_boxes(item, item)
matches, mispred, a_greater, b_greater = result
self.assertEqual(0, len(mispred))
self.assertEqual(0, len(a_greater))
self.assertEqual(0, len(b_greater))
self.assertEqual(len(item.annotations), len(matches))
for a_bbox, b_bbox in matches:
self.assertLess(iou_thresh, a_bbox.iou(b_bbox))
self.assertEqual(a_bbox.label, b_bbox.label)
def test_can_find_bbox_with_wrong_label(self):
detections = 3
class_count = 2
item1 = DatasetItem(id=1, annotations=[
Bbox(i * 10, 10, 10, 10, label=i)
for i in range(detections)
])
item2 = DatasetItem(id=2, annotations=[
Bbox(i * 10, 10, 10, 10, label=(i + 1) % class_count)
for i in range(detections)
])
iou_thresh = 0.5
comp = DistanceComparator(iou_threshold=iou_thresh)
result = comp.match_boxes(item1, item2)
matches, mispred, a_greater, b_greater = result
self.assertEqual(len(item1.annotations), len(mispred))
self.assertEqual(0, len(a_greater))
self.assertEqual(0, len(b_greater))
self.assertEqual(0, len(matches))
for a_bbox, b_bbox in mispred:
self.assertLess(iou_thresh, a_bbox.iou(b_bbox))
self.assertEqual((a_bbox.label + 1) % class_count, b_bbox.label)
def test_can_find_missing_boxes(self):
detections = 3
class_count = 2
item1 = DatasetItem(id=1, annotations=[
Bbox(i * 10, 10, 10, 10, label=i)
for i in range(detections) if i % 2 == 0
])
item2 = DatasetItem(id=2, annotations=[
Bbox(i * 10, 10, 10, 10, label=(i + 1) % class_count)
for i in range(detections) if i % 2 == 1
])
iou_thresh = 0.5
comp = DistanceComparator(iou_threshold=iou_thresh)
result = comp.match_boxes(item1, item2)
matches, mispred, a_greater, b_greater = result
self.assertEqual(0, len(mispred))
self.assertEqual(len(item1.annotations), len(a_greater))
self.assertEqual(len(item2.annotations), len(b_greater))
self.assertEqual(0, len(matches))
def test_no_label_diff_with_same_item(self):
detections = 3
anns = [ Label(i) for i in range(detections) ]
item = DatasetItem(id=1, annotations=anns)
result = DistanceComparator().match_labels(item, item)
matches, a_greater, b_greater = result
self.assertEqual(0, len(a_greater))
self.assertEqual(0, len(b_greater))
self.assertEqual(len(item.annotations), len(matches))
def test_can_find_wrong_label(self):
item1 = DatasetItem(id=1, annotations=[
Label(0),
Label(1),
Label(2),
])
item2 = DatasetItem(id=2, annotations=[
Label(2),
Label(3),
Label(4),
])
result = DistanceComparator().match_labels(item1, item2)
matches, a_greater, b_greater = result
self.assertEqual(2, len(a_greater))
self.assertEqual(2, len(b_greater))
self.assertEqual(1, len(matches))
def test_can_match_points(self):
item1 = DatasetItem(id=1, annotations=[
Points([1, 2, 2, 0, 1, 1], label=0),
Points([3, 5, 5, 7, 5, 3], label=0),
])
item2 = DatasetItem(id=2, annotations=[
Points([1.5, 2, 2, 0.5, 1, 1.5], label=0),
Points([5, 7, 7, 7, 7, 5], label=0),
])
result = DistanceComparator().match_points(item1, item2)
matches, mismatches, a_greater, b_greater = result
self.assertEqual(1, len(a_greater))
self.assertEqual(1, len(b_greater))
self.assertEqual(1, len(matches))
self.assertEqual(0, len(mismatches))
class ExactComparatorTest(TestCase):
def test_class_comparison(self):
a = Dataset.from_iterable([], categories=['a', 'b', 'c'])
b = Dataset.from_iterable([], categories=['b', 'c'])
comp = ExactComparator()
_, _, _, _, errors = comp.compare_datasets(a, b)
self.assertEqual(1, len(errors), errors)
def test_item_comparison(self):
a = Dataset.from_iterable([
DatasetItem(id=1, subset='train'),
DatasetItem(id=2, subset='test', attributes={'x': 1}),
], categories=['a', 'b', 'c'])
b = Dataset.from_iterable([
DatasetItem(id=2, subset='test'),
DatasetItem(id=3),
], categories=['a', 'b', 'c'])
comp = ExactComparator()
_, _, a_extra_items, b_extra_items, errors = comp.compare_datasets(a, b)
self.assertEqual({('1', 'train')}, a_extra_items)
self.assertEqual({('3', '')}, b_extra_items)
self.assertEqual(1, len(errors), errors)
def test_annotation_comparison(self):
a = Dataset.from_iterable([
DatasetItem(id=1, annotations=[
Caption('hello'), # unmatched
Caption('world', group=5),
Label(2, attributes={ 'x': 1, 'y': '2', }),
Bbox(1, 2, 3, 4, label=4, z_order=1, attributes={
'score': 1.0,
}),
Bbox(5, 6, 7, 8, group=5),
Points([1, 2, 2, 0, 1, 1], label=0, z_order=4),
Mask(label=3, z_order=2, image=np.ones((2, 3))),
]),
], categories=['a', 'b', 'c', 'd'])
b = Dataset.from_iterable([
DatasetItem(id=1, annotations=[
Caption('world', group=5),
Label(2, attributes={ 'x': 1, 'y': '2', }),
Bbox(1, 2, 3, 4, label=4, z_order=1, attributes={
'score': 1.0,
}),
Bbox(5, 6, 7, 8, group=5),
Bbox(5, 6, 7, 8, group=5), # unmatched
Points([1, 2, 2, 0, 1, 1], label=0, z_order=4),
Mask(label=3, z_order=2, image=np.ones((2, 3))),
]),
], categories=['a', 'b', 'c', 'd'])
comp = ExactComparator()
matched, unmatched, _, _, errors = comp.compare_datasets(a, b)
self.assertEqual(6, len(matched), matched)
self.assertEqual(2, len(unmatched), unmatched)
self.assertEqual(0, len(errors), errors)
def test_image_comparison(self):
a = Dataset.from_iterable([
DatasetItem(id=11, image=np.ones((5, 4, 3)), annotations=[
Bbox(5, 6, 7, 8),
]),
DatasetItem(id=12, image=np.ones((5, 4, 3)), annotations=[
Bbox(1, 2, 3, 4),
Bbox(5, 6, 7, 8),
]),
DatasetItem(id=13, image=np.ones((5, 4, 3)), annotations=[
Bbox(9, 10, 11, 12), # mismatch
]),
DatasetItem(id=14, image=np.zeros((5, 4, 3)), annotations=[
Bbox(1, 2, 3, 4),
Bbox(5, 6, 7, 8),
], attributes={ 'a': 1 }),
DatasetItem(id=15, image=np.zeros((5, 5, 3)), annotations=[
Bbox(1, 2, 3, 4),
Bbox(5, 6, 7, 8),
]),
], categories=['a', 'b', 'c', 'd'])
b = Dataset.from_iterable([
DatasetItem(id=21, image=np.ones((5, 4, 3)), annotations=[
Bbox(5, 6, 7, 8),
]),
DatasetItem(id=22, image=np.ones((5, 4, 3)), annotations=[
Bbox(1, 2, 3, 4),
Bbox(5, 6, 7, 8),
]),
DatasetItem(id=23, image=np.ones((5, 4, 3)), annotations=[
Bbox(10, 10, 11, 12), # mismatch
]),
DatasetItem(id=24, image=np.zeros((5, 4, 3)), annotations=[
Bbox(6, 6, 7, 8), # 1 ann missing, mismatch
], attributes={ 'a': 2 }),
DatasetItem(id=25, image=np.zeros((4, 4, 3)), annotations=[
Bbox(6, 6, 7, 8),
]),
], categories=['a', 'b', 'c', 'd'])
comp = ExactComparator(match_images=True)
matched_ann, unmatched_ann, a_unmatched, b_unmatched, errors = \
comp.compare_datasets(a, b)
self.assertEqual(3, len(matched_ann), matched_ann)
self.assertEqual(5, len(unmatched_ann), unmatched_ann)
self.assertEqual(1, len(a_unmatched), a_unmatched)
self.assertEqual(1, len(b_unmatched), b_unmatched)
self.assertEqual(1, len(errors), errors)