wavlm-large / s3prl_s3prl_main /test /test_balanced_weighted_sampler.py
lmzjms's picture
Upload 1162 files
0b32ad6 verified
raw
history blame
780 Bytes
from collections import Counter
import pytest
from s3prl.dataio.sampler import BalancedWeightedSampler
@pytest.mark.parametrize("duplicate", [10000, 100000])
def test_balanced_weighted_sampler(duplicate: int):
    """Check that BalancedWeightedSampler draws both classes about equally.

    The label list is deliberately imbalanced (three "a" to one "b"). Every
    batch yielded by the sampler is consumed, the drawn labels are tallied,
    and the gap between the two class counts -- expressed as a fraction of
    all drawn samples -- must stay below 5%.

    Args:
        duplicate: Forwarded unchanged to the sampler's ``duplicate``
            argument. NOTE(review): presumably scales how many samples one
            pass draws -- confirm against the sampler implementation.
    """
    labels = ["a", "a", "b", "a"]
    batch_size = 5
    max_imbalance = 0.05

    sampler = BalancedWeightedSampler(
        labels, batch_size=batch_size, duplicate=duplicate, seed=0
    )
    batches = list(sampler)
    assert len(batches[0]) == batch_size

    # Tally whole labels. Feeding a string to Counter.update() would count
    # its individual characters, which is only right for 1-character labels.
    counter = Counter(labels[idx] for batch in batches for idx in batch)

    # Normalise by what was actually drawn so the ratio is a true fraction
    # in [0, 1] regardless of how many samples the sampler emits.
    total_drawn = sum(counter.values())
    diff_ratio = abs(counter["a"] - counter["b"]) / total_drawn

    assert diff_ratio < max_imbalance, (
        f"class imbalance {diff_ratio:.4f} is not below {max_imbalance} "
        f"(counts: {dict(counter)})"
    )