[Datumaro] Mean and std for dataset (#1734)
* Add meanstd * Add stats cli * Update changelog Co-authored-by: Nikita Manovich <40690625+nmanovic@users.noreply.github.com>main
parent
3fee4cfcab
commit
12f78559d2
@ -0,0 +1,82 @@
|
||||
|
||||
# Copyright (C) 2020 Intel Corporation
|
||||
#
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
def mean_std(dataset):
|
||||
"""
|
||||
Computes unbiased mean and std. dev. for dataset images, channel-wise.
|
||||
"""
|
||||
# Use an online algorithm to:
|
||||
# - handle different image sizes
|
||||
# - avoid cancellation problem
|
||||
|
||||
stats = np.empty((len(dataset), 2, 3), dtype=np.double)
|
||||
counts = np.empty(len(dataset), dtype=np.uint32)
|
||||
|
||||
mean = lambda i, s: s[i][0]
|
||||
var = lambda i, s: s[i][1]
|
||||
|
||||
for i, item in enumerate(dataset):
|
||||
counts[i] = np.prod(item.image.size)
|
||||
|
||||
image = item.image.data
|
||||
if len(image.shape) == 2:
|
||||
image = image[:, :, np.newaxis]
|
||||
else:
|
||||
image = image[:, :, :3]
|
||||
# opencv is much faster than numpy here
|
||||
cv2.meanStdDev(image.astype(np.double) / 255,
|
||||
mean=mean(i, stats), stddev=var(i, stats))
|
||||
|
||||
# make variance unbiased
|
||||
np.multiply(np.square(stats[:, 1]),
|
||||
(counts / (counts - 1))[:, np.newaxis],
|
||||
out=stats[:, 1])
|
||||
|
||||
_, mean, var = StatsCounter().compute_stats(stats, counts, mean, var)
|
||||
return mean * 255, np.sqrt(var) * 255
|
||||
|
||||
class StatsCounter:
|
||||
# Implements online parallel computation of sample variance
|
||||
# https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
|
||||
|
||||
# Needed do avoid catastrophic cancellation in floating point computations
|
||||
@staticmethod
|
||||
def pairwise_stats(count_a, mean_a, var_a, count_b, mean_b, var_b):
|
||||
delta = mean_b - mean_a
|
||||
m_a = var_a * (count_a - 1)
|
||||
m_b = var_b * (count_b - 1)
|
||||
M2 = m_a + m_b + delta ** 2 * count_a * count_b / (count_a + count_b)
|
||||
return (
|
||||
count_a + count_b,
|
||||
mean_a * 0.5 + mean_b * 0.5,
|
||||
M2 / (count_a + count_b - 1)
|
||||
)
|
||||
|
||||
# stats = float array of shape N, 2 * d, d = dimensions of values
|
||||
# count = integer array of shape N
|
||||
# mean_accessor = function(idx, stats) to retrieve element mean
|
||||
# variance_accessor = function(idx, stats) to retrieve element variance
|
||||
# Recursively computes total count, mean and variance, does O(log(N)) calls
|
||||
@staticmethod
|
||||
def compute_stats(stats, counts, mean_accessor, variance_accessor):
|
||||
m = mean_accessor
|
||||
v = variance_accessor
|
||||
n = len(stats)
|
||||
if n == 1:
|
||||
return counts[0], m(0, stats), v(0, stats)
|
||||
if n == 2:
|
||||
return __class__.pairwise_stats(
|
||||
counts[0], m(0, stats), v(0, stats),
|
||||
counts[1], m(1, stats), v(1, stats)
|
||||
)
|
||||
h = n // 2
|
||||
return __class__.pairwise_stats(
|
||||
*__class__.compute_stats(stats[:h], counts[:h], m, v),
|
||||
*__class__.compute_stats(stats[h:], counts[h:], m, v)
|
||||
)
|
||||
@ -0,0 +1,31 @@
|
||||
import numpy as np
|
||||
|
||||
from datumaro.components.extractor import Extractor, DatasetItem
|
||||
from datumaro.components.operations import mean_std
|
||||
|
||||
from unittest import TestCase
|
||||
|
||||
|
||||
class TestOperations(TestCase):
|
||||
def test_mean_std(self):
|
||||
expected_mean = [100, 50, 150]
|
||||
expected_std = [20, 50, 10]
|
||||
|
||||
class TestExtractor(Extractor):
|
||||
def __iter__(self):
|
||||
return iter([
|
||||
DatasetItem(id=1, image=np.random.normal(
|
||||
expected_mean, expected_std,
|
||||
size=(w, h, 3))
|
||||
)
|
||||
for i, (w, h) in enumerate([
|
||||
(3000, 100), (800, 600), (400, 200), (700, 300)
|
||||
])
|
||||
])
|
||||
|
||||
actual_mean, actual_std = mean_std(TestExtractor())
|
||||
|
||||
for em, am in zip(expected_mean, actual_mean):
|
||||
self.assertAlmostEqual(em, am, places=0)
|
||||
for estd, astd in zip(expected_std, actual_std):
|
||||
self.assertAlmostEqual(estd, astd, places=0)
|
||||
Loading…
Reference in New Issue