Fix remainder logic for subset splitting (#1222)

main
zhiltsov-max 6 years ago committed by GitHub
parent 97195238b0
commit 0d873a3de6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@@ -325,8 +325,12 @@ class RandomSplit(Transform, CliPlugin):
def __init__(self, extractor, splits, seed=None):
super().__init__(extractor)
total_ratio = sum((s[1] for s in splits), 0)
if not total_ratio == 1:
assert 0 < len(splits), "Expected at least one split"
assert all(0.0 <= r and r <= 1.0 for _, r in splits), \
"Ratios are expected to be in the range [0; 1], but got %s" % splits
total_ratio = sum(s[1] for s in splits)
if not abs(total_ratio - 1.0) <= 1e-7:
raise Exception(
"Sum of ratios is expected to be 1, got %s, which is %s" %
(splits, total_ratio))
@@ -336,7 +340,6 @@ class RandomSplit(Transform, CliPlugin):
random.seed(seed)
random.shuffle(indices)
parts = []
s = 0
for subset, ratio in splits:
@@ -350,7 +353,7 @@ class RandomSplit(Transform, CliPlugin):
for boundary, subset in self._parts:
if index < boundary:
return subset
return subset
return subset # all the possible remainder goes to the last split
def __iter__(self):
for i, item in enumerate(self._extractor):

@@ -534,15 +534,10 @@ class DatasetTest(TestCase):
class DatasetItemTest(TestCase):
def test_ctor_requires_id(self):
has_error = False
try:
with self.assertRaises(Exception):
# pylint: disable=no-value-for-parameter
DatasetItem()
# pylint: enable=no-value-for-parameter
except AssertionError:
has_error = True
self.assertTrue(has_error)
@staticmethod
def test_ctors_with_image():

@@ -342,18 +342,22 @@ class TransformsTest(TestCase):
self.assertEqual(4, len(actual.get_subset('train')))
self.assertEqual(3, len(actual.get_subset('test')))
def test_random_split_gives_error_on_non1_ratios(self):
def test_random_split_gives_error_on_wrong_ratios(self):
class SrcExtractor(Extractor):
def __iter__(self):
return iter([DatasetItem(id=1)])
has_error = False
try:
with self.assertRaises(Exception):
transforms.RandomSplit(SrcExtractor(), splits=[
('train', 0.5),
('test', 0.7),
])
except Exception:
has_error = True
self.assertTrue(has_error)
with self.assertRaises(Exception):
transforms.RandomSplit(SrcExtractor(), splits=[])
with self.assertRaises(Exception):
transforms.RandomSplit(SrcExtractor(), splits=[
('train', -0.5),
('test', 1.5),
])

Loading…
Cancel
Save