|
from collections import Counter |
|
|
|
import pytest |
|
|
|
from s3prl.dataio.sampler import BalancedWeightedSampler |
|
|
|
|
|
@pytest.mark.parametrize("duplicate", [10000, 100000])
def test_balanced_weighted_sampler(duplicate: int) -> None:
    """Check that BalancedWeightedSampler draws both classes equally often.

    The label list is deliberately imbalanced (three "a" to one "b"). After
    one full pass over the sampler, the number of drawn "a" and "b" labels
    must differ by less than 5% of all drawn samples.

    Args:
        duplicate: Value forwarded to the sampler's ``duplicate`` argument.
    """
    labels = ["a", "a", "b", "a"]
    batch_size = 5
    max_diff_ratio = 0.05

    sampler = BalancedWeightedSampler(
        labels, batch_size=batch_size, duplicate=duplicate, seed=0
    )
    batches = list(sampler)
    assert batches, "sampler yielded no batches"
    assert len(batches[0]) == batch_size

    # Count whole labels. Passing a string to Counter.update() would count
    # its characters instead, which is only right for one-character labels.
    counter = Counter(labels[idx] for batch in batches for idx in batch)

    # Normalise by the number of samples actually drawn so the value is a
    # true fraction in [0, 1], independent of how many batches were yielded.
    total = sum(counter.values())
    diff_ratio = abs(counter["a"] - counter["b"]) / total

    # The bound has to be asserted; a bare comparison is silently discarded.
    assert diff_ratio < max_diff_ratio, (
        f"class imbalance {diff_ratio:.4f} is not below {max_diff_ratio} "
        f"(counts: {dict(counter)})"
    )
|
|