You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

122 lines
4.2 KiB
Python

# Copyright (C) 2019 Intel Corporation
#
# SPDX-License-Identifier: MIT
import inspect
import os
import os.path as osp
import shutil
import tempfile
from datumaro.components.extractor import AnnotationType
from datumaro.util import find
def current_function_name(depth=1):
return inspect.getouterframes(inspect.currentframe())[depth].function
class FileRemover:
def __init__(self, path, is_dir=False, ignore_errors=False):
self.path = path
self.is_dir = is_dir
self.ignore_errors = ignore_errors
def __enter__(self):
return self.path
# pylint: disable=redefined-builtin
def __exit__(self, type=None, value=None, traceback=None):
if self.is_dir:
shutil.rmtree(self.path, ignore_errors=self.ignore_errors)
else:
os.remove(self.path)
# pylint: enable=redefined-builtin
class TestDir(FileRemover):
def __init__(self, path=None, ignore_errors=False):
if path is None:
path = osp.abspath('temp_%s-' % current_function_name(2))
path = tempfile.mkdtemp(dir=os.getcwd(), prefix=path)
else:
os.makedirs(path, exist_ok=ignore_errors)
super().__init__(path, is_dir=True, ignore_errors=ignore_errors)
def compare_categories(test, expected, actual):
test.assertEqual(
sorted(expected, key=lambda t: t.value),
sorted(actual, key=lambda t: t.value)
)
if AnnotationType.label in expected:
test.assertEqual(
expected[AnnotationType.label].items,
actual[AnnotationType.label].items,
)
if AnnotationType.mask in expected:
test.assertEqual(
expected[AnnotationType.mask].colormap,
actual[AnnotationType.mask].colormap,
)
if AnnotationType.points in expected:
test.assertEqual(
expected[AnnotationType.points].items,
actual[AnnotationType.points].items,
)
def _compare_annotations(expected, actual, ignored_attrs=None):
if not ignored_attrs:
return expected == actual
a_attr = expected.attributes
b_attr = actual.attributes
expected.attributes = {k:v for k,v in a_attr.items() if k not in ignored_attrs}
actual.attributes = {k:v for k,v in b_attr.items() if k not in ignored_attrs}
r = expected == actual
expected.attributes = a_attr
actual.attributes = b_attr
return r
def compare_datasets(test, expected, actual, ignored_attrs=None):
compare_categories(test, expected.categories(), actual.categories())
test.assertEqual(sorted(expected.subsets()), sorted(actual.subsets()))
test.assertEqual(len(expected), len(actual))
for item_a in expected:
item_b = find(actual, lambda x: x.id == item_a.id and \
x.subset == item_a.subset)
test.assertFalse(item_b is None, item_a.id)
test.assertEqual(item_a.attributes, item_b.attributes)
test.assertEqual(len(item_a.annotations), len(item_b.annotations))
for ann_a in item_a.annotations:
# We might find few corresponding items, so check them all
ann_b_matches = [x for x in item_b.annotations
if x.type == ann_a.type]
test.assertFalse(len(ann_b_matches) == 0, 'ann id: %s' % ann_a.id)
ann_b = find(ann_b_matches, lambda x:
_compare_annotations(x, ann_a, ignored_attrs=ignored_attrs))
if ann_b is None:
test.assertEqual(ann_a, ann_b,
'ann %s, candidates %s' % (ann_a, ann_b_matches))
item_b.annotations.remove(ann_b) # avoid repeats
def compare_datasets_strict(test, expected, actual):
# Compares datasets for strong equality
test.assertEqual(expected.categories(), actual.categories())
test.assertListEqual(sorted(expected.subsets()), sorted(actual.subsets()))
test.assertEqual(len(expected), len(actual))
for subset_name in expected.subsets():
e_subset = expected.get_subset(subset_name)
a_subset = actual.get_subset(subset_name)
test.assertEqual(len(e_subset), len(a_subset))
for idx, (item_a, item_b) in enumerate(zip(e_subset, a_subset)):
test.assertEqual(item_a, item_b,
'%s:\n%s\nvs.\n%s\n' % \
(idx, item_a, item_b))