You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
230 lines
8.4 KiB
Python
230 lines
8.4 KiB
Python
from collections import namedtuple
|
|
import cv2
|
|
import numpy as np
|
|
|
|
from unittest import TestCase
|
|
|
|
from datumaro.components.extractor import LabelObject, BboxObject
|
|
from datumaro.components.launcher import Launcher
|
|
from datumaro.components.algorithms.rise import RISE
|
|
|
|
|
|
class RiseTest(TestCase):
    """Tests for the RISE saliency algorithm.

    Each test drives RISE with a synthetic launcher whose predictions depend
    only on known image regions (ROIs). It then asserts that the mean heatmap
    value inside those regions is higher than the mean value outside them.
    """

    def test_rise_can_be_applied_to_classification_model(self):
        class TestLauncher(Launcher):
            """Fake classifier whose decision depends on a single ROI.

            Predicts class 0 when the channel-0 sum over the ROI exceeds
            half of the ROI area, and class 1 otherwise.
            """

            def __init__(self, class_count, roi, **kwargs):
                self.class_count = class_count
                # ROI as [row_start, row_end, col_start, col_end].
                self.roi = roi

            def launch(self, inputs):
                for inp in inputs:
                    yield self._process(inp)

            def _process(self, image):
                roi = self.roi
                roi_area = (roi[1] - roi[0]) * (roi[3] - roi[2])
                if 0.5 * roi_area < np.sum(image[roi[0]:roi[1], roi[2]:roi[3], 0]):
                    cls = 0
                else:
                    cls = 1

                # The predicted class gets a fixed score. The remaining
                # probability mass is split evenly among the other classes.
                cls_conf = 0.5
                other_conf = (1.0 - cls_conf) / (self.class_count - 1)

                return [
                    LabelObject(i, attributes={
                        'score': cls_conf if cls == i else other_conf })
                    for i in range(self.class_count)
                ]

        roi = [70, 90, 7, 90]
        model = TestLauncher(class_count=3, roi=roi)

        rise = RISE(model, max_samples=(7 * 7) ** 2, mask_width=7, mask_height=7)

        image = np.ones((100, 100, 3))
        heatmaps = next(rise.apply(image))

        self.assertEqual(1, len(heatmaps))

        heatmap = heatmaps[0]
        self.assertEqual(image.shape[:2], heatmap.shape)

        # The mean saliency inside the ROI must exceed the mean saliency
        # over the rest of the image.
        h_sum = np.sum(heatmap)
        h_area = np.prod(heatmap.shape)
        roi_sum = np.sum(heatmap[roi[0]:roi[1], roi[2]:roi[3]])
        roi_area = (roi[1] - roi[0]) * (roi[3] - roi[2])
        roi_den = roi_sum / roi_area
        hrest_den = (h_sum - roi_sum) / (h_area - roi_area)
        self.assertLess(hrest_den, roi_den)

    def test_rise_can_be_applied_to_detection_model(self):
        ROI = namedtuple('ROI',
            ['threshold', 'x', 'y', 'w', 'h', 'label'])

        class TestLauncher(Launcher):
            """Fake detector driven by a fixed list of ROIs.

            The first image it processes is used as the reference. For each
            later image, every ROI is scored by the ratio of its current
            pixel sum to its reference sum. An ROI whose ratio exceeds its
            threshold is reported as an exact box. Non-reference runs also
            emit ``fp_count`` randomly jittered boxes per ROI. These get the
            ROI's label when it passes the threshold, and a rotating label
            otherwise.
            """

            def __init__(self, rois, class_count, fp_count=4, pixel_jitter=20, **kwargs):
                self.rois = rois
                # Per-ROI pixel sums of the reference (first) image.
                # Filled lazily by _process.
                self.roi_base_sums = [None] * len(rois)
                self.class_count = class_count
                self.fp_count = fp_count
                self.pixel_jitter = pixel_jitter

            @staticmethod
            def roi_value(roi, image):
                """Return the sum of all image values inside ``roi``."""
                return np.sum(
                    image[roi.y:roi.y + roi.h, roi.x:roi.x + roi.w, :])

            def launch(self, inputs):
                for inp in inputs:
                    yield self._process(inp)

            def _process(self, image):
                detections = []
                for i, roi in enumerate(self.rois):
                    roi_sum = self.roi_value(roi, image)
                    roi_base_sum = self.roi_base_sums[i]
                    first_run = roi_base_sum is None
                    if first_run:
                        roi_base_sum = roi_sum
                        self.roi_base_sums[i] = roi_base_sum

                    # Fraction of the reference ROI sum still present in
                    # this image. It is used both as the score and as the
                    # value tested against the ROI's threshold.
                    cls_conf = roi_sum / roi_base_sum

                    if roi.threshold < cls_conf:
                        cls = roi.label
                        detections.append(
                            BboxObject(roi.x, roi.y, roi.w, roi.h,
                                label=cls, attributes={'score': cls_conf})
                        )

                    # The reference run reports only the exact ROI boxes,
                    # with no jittered copies.
                    if first_run:
                        continue

                    for j in range(self.fp_count):
                        if roi.threshold < cls_conf:
                            cls = roi.label
                        else:
                            cls = (i + j) % self.class_count
                        box = [roi.x, roi.y, roi.w, roi.h]
                        offset = (np.random.rand(4) - 0.5) * self.pixel_jitter
                        detections.append(
                            BboxObject(*(box + offset),
                                label=cls, attributes={'score': cls_conf})
                        )

                return detections

        rois = [
            ROI(0.3, 10, 40, 30, 10, 0),
            ROI(0.5, 70, 90, 7, 10, 0),
            ROI(0.7, 5, 20, 40, 60, 2),
            ROI(0.9, 30, 20, 10, 40, 1),
        ]
        model = TestLauncher(class_count=3, rois=rois)

        rise = RISE(model, max_samples=(7 * 7) ** 2, mask_width=7, mask_height=7)

        image = np.ones((100, 100, 3))
        heatmaps = next(rise.apply(image))

        # The test expects one heatmap per distinct class (indexed by
        # label), followed by one heatmap per ROI (in ROI order).
        heatmaps_class_count = len({roi.label for roi in rois})
        self.assertEqual(heatmaps_class_count + len(rois), len(heatmaps))

        for c in range(heatmaps_class_count):
            # Mark every pixel covered by at least one ROI of class c.
            class_roi = np.zeros(image.shape[:2])
            for roi in rois:
                if roi.label != c:
                    continue
                class_roi[roi.y:roi.y + roi.h, roi.x:roi.x + roi.w] \
                    += roi.threshold

            heatmap = heatmaps[c]

            # Densities are averaged over non-zero heatmap pixels only.
            roi_pixels = heatmap[class_roi != 0]
            h_sum = np.sum(roi_pixels)
            h_area = np.sum(roi_pixels != 0)
            h_den = h_sum / h_area

            rest_pixels = heatmap[class_roi == 0]
            r_sum = np.sum(rest_pixels)
            r_area = np.sum(rest_pixels != 0)
            r_den = r_sum / r_area

            self.assertLess(r_den, h_den)

        for i, roi in enumerate(rois):
            heatmap = heatmaps[heatmaps_class_count + i]
            h_sum = np.sum(heatmap)
            h_area = np.prod(heatmap.shape)
            roi_sum = np.sum(heatmap[roi.y:roi.y + roi.h, roi.x:roi.x + roi.w])
            roi_area = roi.h * roi.w
            roi_den = roi_sum / roi_area
            hrest_den = (h_sum - roi_sum) / (h_area - roi_area)
            self.assertLess(hrest_den, roi_den)

    @staticmethod
    def DISABLED_test_roi_nms():
        """Interactive visual check of ``RISE.nms``.

        Draws the noisy detections and then the NMS survivors into OpenCV
        windows, blocking on a key press each time. The name does not start
        with ``test``, so the default unittest loader does not collect it.
        """
        ROI = namedtuple('ROI',
            ['conf', 'x', 'y', 'w', 'h', 'label'])

        class_count = 3
        noisy_count = 3
        rois = [
            ROI(0.3, 10, 40, 30, 10, 0),
            ROI(0.5, 70, 90, 7, 10, 0),
            ROI(0.7, 5, 20, 40, 60, 2),
            ROI(0.9, 30, 20, 10, 40, 1),
        ]
        pixel_jitter = 10

        # Each ROI yields one exact box plus noisy_count jittered,
        # lower-confidence boxes with rotating labels.
        detections = []
        for i, roi in enumerate(rois):
            detections.append(
                BboxObject(roi.x, roi.y, roi.w, roi.h,
                    label=roi.label, attributes={'score': roi.conf})
            )

            for j in range(noisy_count):
                cls_conf = roi.conf * j / noisy_count
                cls = (i + j) % class_count
                box = [roi.x, roi.y, roi.w, roi.h]
                offset = (np.random.rand(4) - 0.5) * pixel_jitter
                detections.append(
                    BboxObject(*(box + offset),
                        label=cls, attributes={'score': cls_conf})
                )

        image = np.zeros((100, 100, 3))
        for i, det in enumerate(detections):
            roi = ROI(det.attributes['score'], *det.get_bbox(), det.label)
            p1 = (int(roi.x), int(roi.y))
            p2 = (int(roi.x + roi.w), int(roi.y + roi.h))
            # The exact (non-jittered) boxes get a non-zero green channel.
            c = (0, 1 * (i % (1 + noisy_count) == 0), 1)
            cv2.rectangle(image, p1, p2, c)
            cv2.putText(image, 'd%s-%s-%.2f' % (i, roi.label, roi.conf),
                p1, cv2.FONT_HERSHEY_SIMPLEX, 0.25, c)
        cv2.imshow('nms_image', image)
        cv2.waitKey(0)

        nms_boxes = RISE.nms(detections, iou_thresh=0.25)
        print(len(detections), len(nms_boxes))

        for i, det in enumerate(nms_boxes):
            roi = ROI(det.attributes['score'], *det.get_bbox(), det.label)
            p1 = (int(roi.x), int(roi.y))
            p2 = (int(roi.x + roi.w), int(roi.y + roi.h))
            c = (0, 1, 0)
            cv2.rectangle(image, p1, p2, c)
            cv2.putText(image, 'p%s-%s-%.2f' % (i, roi.label, roi.conf),
                p1, cv2.FONT_HERSHEY_SIMPLEX, 0.25, c)
        cv2.imshow('nms_image', image)
        cv2.waitKey(0)