Add random split transform (#1213)

main
zhiltsov-max 6 years ago committed by GitHub
parent 3604f0c5ce
commit da69a40b96
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -5,6 +5,7 @@
import logging as log
import os.path as osp
import random
import pycocotools.mask as mask_utils
@ -295,6 +296,66 @@ class MapSubsets(Transform, CliPlugin):
return self.wrap_item(item,
subset=self._mapping.get(item.subset, item.subset))
class RandomSplit(Transform, CliPlugin):
"""
Joins all subsets into one and splits the result into few parts.
It is expected that item ids are unique and subset ratios sum up to 1.|n
|n
Example:|n
|s|s%(prog)s --subset train:.67 --subset test:.33
"""
@staticmethod
def _split_arg(s):
parts = s.split(':')
if len(parts) != 2:
import argparse
raise argparse.ArgumentTypeError()
return (parts[0], float(parts[1]))
@classmethod
def build_cmdline_parser(cls, **kwargs):
parser = super().build_cmdline_parser(**kwargs)
parser.add_argument('-s', '--subset', action='append',
type=cls._split_arg, dest='splits',
help="Subsets in the form of: '<subset>:<ratio>' (repeatable)")
parser.add_argument('--seed', type=int, help="Random seed")
return parser
def __init__(self, extractor, splits, seed=None):
super().__init__(extractor)
total_ratio = sum((s[1] for s in splits), 0)
if not total_ratio == 1:
raise Exception(
"Sum of ratios is expected to be 1, got %s, which is %s" %
(splits, total_ratio))
dataset_size = len(extractor)
indices = list(range(dataset_size))
random.seed(seed)
random.shuffle(indices)
parts = []
s = 0
for subset, ratio in splits:
s += ratio
boundary = int(s * dataset_size)
parts.append((boundary, subset))
self._parts = parts
def _find_split(self, index):
for boundary, subset in self._parts:
if index < boundary:
return subset
return subset
def __iter__(self):
for i, item in enumerate(self._extractor):
yield self.wrap_item(item, subset=self._find_split(i))
class IdFromImageName(Transform, CliPlugin):
def transform_item(self, item):
name = item.id

@ -320,3 +320,40 @@ class TransformsTest(TestCase):
actual = transforms.BoxesToMasks(SrcExtractor())
compare_datasets(self, DstExtractor(), actual)
def test_random_split(self):
class SrcExtractor(Extractor):
def __iter__(self):
return iter([
DatasetItem(id=1, subset="a"),
DatasetItem(id=2, subset="a"),
DatasetItem(id=3, subset="b"),
DatasetItem(id=4, subset="b"),
DatasetItem(id=5, subset="b"),
DatasetItem(id=6, subset=""),
DatasetItem(id=7, subset=""),
])
actual = transforms.RandomSplit(SrcExtractor(), splits=[
('train', 4.0 / 7.0),
('test', 3.0 / 7.0),
])
self.assertEqual(4, len(actual.get_subset('train')))
self.assertEqual(3, len(actual.get_subset('test')))
def test_random_split_gives_error_on_non1_ratios(self):
class SrcExtractor(Extractor):
def __iter__(self):
return iter([DatasetItem(id=1)])
has_error = False
try:
transforms.RandomSplit(SrcExtractor(), splits=[
('train', 0.5),
('test', 0.7),
])
except Exception:
has_error = True
self.assertTrue(has_error)
Loading…
Cancel
Save