|
|
|
|
@ -325,8 +325,12 @@ class RandomSplit(Transform, CliPlugin):
|
|
|
|
|
def __init__(self, extractor, splits, seed=None):
|
|
|
|
|
super().__init__(extractor)
|
|
|
|
|
|
|
|
|
|
total_ratio = sum((s[1] for s in splits), 0)
|
|
|
|
|
if not total_ratio == 1:
|
|
|
|
|
assert 0 < len(splits), "Expected at least one split"
|
|
|
|
|
assert all(0.0 <= r and r <= 1.0 for _, r in splits), \
|
|
|
|
|
"Ratios are expected to be in the range [0; 1], but got %s" % splits
|
|
|
|
|
|
|
|
|
|
total_ratio = sum(s[1] for s in splits)
|
|
|
|
|
if not abs(total_ratio - 1.0) <= 1e-7:
|
|
|
|
|
raise Exception(
|
|
|
|
|
"Sum of ratios is expected to be 1, got %s, which is %s" %
|
|
|
|
|
(splits, total_ratio))
|
|
|
|
|
@ -336,7 +340,6 @@ class RandomSplit(Transform, CliPlugin):
|
|
|
|
|
|
|
|
|
|
random.seed(seed)
|
|
|
|
|
random.shuffle(indices)
|
|
|
|
|
|
|
|
|
|
parts = []
|
|
|
|
|
s = 0
|
|
|
|
|
for subset, ratio in splits:
|
|
|
|
|
@ -350,7 +353,7 @@ class RandomSplit(Transform, CliPlugin):
|
|
|
|
|
for boundary, subset in self._parts:
|
|
|
|
|
if index < boundary:
|
|
|
|
|
return subset
|
|
|
|
|
return subset
|
|
|
|
|
return subset # all the possible remainder goes to the last split
|
|
|
|
|
|
|
|
|
|
def __iter__(self):
|
|
|
|
|
for i, item in enumerate(self._extractor):
|
|
|
|
|
|