You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
251 lines
8.9 KiB
Python
251 lines
8.9 KiB
Python
import numpy as np
|
|
|
|
from datumaro.components.extractor import (DatasetItem, Label, Bbox,
|
|
Caption, Mask, Points)
|
|
from datumaro.components.project import Dataset
|
|
from datumaro.components.operations import DistanceComparator, ExactComparator
|
|
|
|
from unittest import TestCase
|
|
|
|
|
|
class DistanceComparatorTest(TestCase):
    """Tests for the geometry/label matching helpers of ``DistanceComparator``."""

    @staticmethod
    def _row_of_boxes(indices, label_of):
        """Return ``Bbox(10 * i, 10, 10, 10)`` for every ``i`` in *indices*.

        ``label_of`` is called with ``i`` to produce the label of that box.
        """
        return [Bbox(i * 10, 10, 10, 10, label=label_of(i)) for i in indices]

    def test_no_bbox_diff_with_same_item(self):
        """An item compared with itself matches every box, with equal labels."""
        box_count = 3
        item = DatasetItem(id=0,
            annotations=self._row_of_boxes(range(box_count), lambda i: i))

        threshold = 0.5
        comparator = DistanceComparator(iou_threshold=threshold)
        matched, mispredicted, only_a, only_b = \
            comparator.match_boxes(item, item)

        self.assertEqual(0, len(mispredicted))
        self.assertEqual(0, len(only_a))
        self.assertEqual(0, len(only_b))
        self.assertEqual(len(item.annotations), len(matched))
        for left, right in matched:
            self.assertLess(threshold, left.iou(right))
            self.assertEqual(left.label, right.label)

    def test_can_find_bbox_with_wrong_label(self):
        """Overlapping boxes whose labels differ are reported as mispredictions."""
        box_count = 3
        class_count = 2
        item1 = DatasetItem(id=1,
            annotations=self._row_of_boxes(range(box_count), lambda i: i))
        # Same geometry as item1, but every label is shifted by one.
        item2 = DatasetItem(id=2, annotations=self._row_of_boxes(
            range(box_count), lambda i: (i + 1) % class_count))

        threshold = 0.5
        comparator = DistanceComparator(iou_threshold=threshold)
        matched, mispredicted, only_a, only_b = \
            comparator.match_boxes(item1, item2)

        self.assertEqual(len(item1.annotations), len(mispredicted))
        self.assertEqual(0, len(only_a))
        self.assertEqual(0, len(only_b))
        self.assertEqual(0, len(matched))
        for left, right in mispredicted:
            self.assertLess(threshold, left.iou(right))
            self.assertEqual((left.label + 1) % class_count, right.label)

    def test_can_find_missing_boxes(self):
        """Boxes present on only one side end up in that side's leftovers."""
        box_count = 3
        class_count = 2
        # item1 gets the even positions, item2 the odd ones.
        item1 = DatasetItem(id=1, annotations=self._row_of_boxes(
            range(0, box_count, 2), lambda i: i))
        item2 = DatasetItem(id=2, annotations=self._row_of_boxes(
            range(1, box_count, 2), lambda i: (i + 1) % class_count))

        threshold = 0.5
        comparator = DistanceComparator(iou_threshold=threshold)
        matched, mispredicted, only_a, only_b = \
            comparator.match_boxes(item1, item2)

        self.assertEqual(0, len(mispredicted))
        self.assertEqual(len(item1.annotations), len(only_a))
        self.assertEqual(len(item2.annotations), len(only_b))
        self.assertEqual(0, len(matched))

    def test_no_label_diff_with_same_item(self):
        """An item compared with itself matches all of its labels."""
        label_count = 3
        item = DatasetItem(id=1,
            annotations=[Label(i) for i in range(label_count)])

        matched, only_a, only_b = \
            DistanceComparator().match_labels(item, item)

        self.assertEqual(0, len(only_a))
        self.assertEqual(0, len(only_b))
        self.assertEqual(len(item.annotations), len(matched))

    def test_can_find_wrong_label(self):
        """Label sets {0, 1, 2} and {2, 3, 4} share exactly one label."""
        item1 = DatasetItem(id=1, annotations=[Label(i) for i in (0, 1, 2)])
        item2 = DatasetItem(id=2, annotations=[Label(i) for i in (2, 3, 4)])

        matched, only_a, only_b = \
            DistanceComparator().match_labels(item1, item2)

        self.assertEqual(2, len(only_a))
        self.assertEqual(2, len(only_b))
        self.assertEqual(1, len(matched))

    def test_can_match_points(self):
        """One close pair of point sets matches; the distant pair does not."""
        item1 = DatasetItem(id=1, annotations=[
            Points([1, 2, 2, 0, 1, 1], label=0),
            Points([3, 5, 5, 7, 5, 3], label=0),
        ])
        item2 = DatasetItem(id=2, annotations=[
            Points([1.5, 2, 2, 0.5, 1, 1.5], label=0),
            Points([5, 7, 7, 7, 7, 5], label=0),
        ])

        matched, mismatched, only_a, only_b = \
            DistanceComparator().match_points(item1, item2)

        self.assertEqual(1, len(only_a))
        self.assertEqual(1, len(only_b))
        self.assertEqual(1, len(matched))
        self.assertEqual(0, len(mismatched))


class ExactComparatorTest(TestCase):
    """Tests for the strict dataset comparison done by ``ExactComparator``."""

    def test_class_comparison(self):
        """Datasets with different ``categories`` yield exactly one error."""
        left = Dataset.from_iterable([], categories=['a', 'b', 'c'])
        right = Dataset.from_iterable([], categories=['b', 'c'])

        _, _, _, _, errors = ExactComparator().compare_datasets(left, right)

        self.assertEqual(1, len(errors), errors)

    def test_item_comparison(self):
        """Items unique to one side are listed; the shared item yields one error."""
        left = Dataset.from_iterable([
            DatasetItem(id=1, subset='train'),
            DatasetItem(id=2, subset='test', attributes={'x': 1}),
        ], categories=['a', 'b', 'c'])
        right = Dataset.from_iterable([
            DatasetItem(id=2, subset='test'),
            DatasetItem(id=3),
        ], categories=['a', 'b', 'c'])

        _, _, left_only, right_only, errors = \
            ExactComparator().compare_datasets(left, right)

        # Leftover items are expected as (id, subset) pairs with string ids.
        self.assertEqual({('1', 'train')}, left_only)
        self.assertEqual({('3', '')}, right_only)
        # Item 2 differs between the sides only in its attributes; presumably
        # that difference is the single reported error.
        self.assertEqual(1, len(errors), errors)

    def test_annotation_comparison(self):
        """Identical annotations match; the extra caption and bbox do not."""
        left = Dataset.from_iterable([
            DatasetItem(id=1, annotations=[
                Caption('hello'), # unmatched
                Caption('world', group=5),
                Label(2, attributes={'x': 1, 'y': '2'}),
                Bbox(1, 2, 3, 4, label=4, z_order=1,
                    attributes={'score': 1.0}),
                Bbox(5, 6, 7, 8, group=5),
                Points([1, 2, 2, 0, 1, 1], label=0, z_order=4),
                Mask(label=3, z_order=2, image=np.ones((2, 3))),
            ]),
        ], categories=['a', 'b', 'c', 'd'])
        right = Dataset.from_iterable([
            DatasetItem(id=1, annotations=[
                Caption('world', group=5),
                Label(2, attributes={'x': 1, 'y': '2'}),
                Bbox(1, 2, 3, 4, label=4, z_order=1,
                    attributes={'score': 1.0}),
                Bbox(5, 6, 7, 8, group=5),
                Bbox(5, 6, 7, 8, group=5), # unmatched
                Points([1, 2, 2, 0, 1, 1], label=0, z_order=4),
                Mask(label=3, z_order=2, image=np.ones((2, 3))),
            ]),
        ], categories=['a', 'b', 'c', 'd'])

        matched, unmatched, _, _, errors = \
            ExactComparator().compare_datasets(left, right)

        self.assertEqual(6, len(matched), matched)
        self.assertEqual(2, len(unmatched), unmatched)
        self.assertEqual(0, len(errors), errors)

    def test_image_comparison(self):
        """Compare datasets whose item ids never coincide, with ``match_images``.

        Item ids are disjoint between the two datasets (1x vs 2x), so any
        pairing must come from ``match_images=True``, presumably by comparing
        image content.
        """
        left = Dataset.from_iterable([
            DatasetItem(id=11, image=np.ones((5, 4, 3)), annotations=[
                Bbox(5, 6, 7, 8),
            ]),
            DatasetItem(id=12, image=np.ones((5, 4, 3)), annotations=[
                Bbox(1, 2, 3, 4),
                Bbox(5, 6, 7, 8),
            ]),
            DatasetItem(id=13, image=np.ones((5, 4, 3)), annotations=[
                Bbox(9, 10, 11, 12), # mismatch
            ]),
            DatasetItem(id=14, image=np.zeros((5, 4, 3)), annotations=[
                Bbox(1, 2, 3, 4),
                Bbox(5, 6, 7, 8),
            ], attributes={'a': 1}),
            # This image shape does not occur in the other dataset.
            DatasetItem(id=15, image=np.zeros((5, 5, 3)), annotations=[
                Bbox(1, 2, 3, 4),
                Bbox(5, 6, 7, 8),
            ]),
        ], categories=['a', 'b', 'c', 'd'])
        right = Dataset.from_iterable([
            DatasetItem(id=21, image=np.ones((5, 4, 3)), annotations=[
                Bbox(5, 6, 7, 8),
            ]),
            DatasetItem(id=22, image=np.ones((5, 4, 3)), annotations=[
                Bbox(1, 2, 3, 4),
                Bbox(5, 6, 7, 8),
            ]),
            DatasetItem(id=23, image=np.ones((5, 4, 3)), annotations=[
                Bbox(10, 10, 11, 12), # mismatch
            ]),
            DatasetItem(id=24, image=np.zeros((5, 4, 3)), annotations=[
                Bbox(6, 6, 7, 8), # 1 ann missing, mismatch
            ], attributes={'a': 2}),
            # This image shape does not occur in the other dataset.
            DatasetItem(id=25, image=np.zeros((4, 4, 3)), annotations=[
                Bbox(6, 6, 7, 8),
            ]),
        ], categories=['a', 'b', 'c', 'd'])

        comparator = ExactComparator(match_images=True)
        matched_ann, unmatched_ann, a_unmatched, b_unmatched, errors = (
            comparator.compare_datasets(left, right))

        self.assertEqual(3, len(matched_ann), matched_ann)
        self.assertEqual(5, len(unmatched_ann), unmatched_ann)
        self.assertEqual(1, len(a_unmatched), a_unmatched)
        self.assertEqual(1, len(b_unmatched), b_unmatched)
        self.assertEqual(1, len(errors), errors)