|
import logging |
|
import random |
|
from collections import OrderedDict |
|
|
|
from s3prl.dataio.sampler import SortedBucketingSampler, SortedSliceSampler |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
def test_sorted_slice_sampler():
    """SortedSliceSampler must yield descending-length batches and halve the
    batch size whenever the longest utterance in a batch exceeds max_length."""
    batch_size = 16
    max_length = 16000 * 5
    # Random utterance lengths between 3 and 8 seconds at 16 kHz.
    lengths = [random.randint(16000 * 3, 16000 * 8) for _ in range(1000)]

    sampler = SortedSliceSampler(
        lengths,
        batch_size=batch_size,
        max_length=max_length,
    )
    for epoch in range(5):
        sampler.set_epoch(epoch)
        for batch_ids in sampler:
            batch_lengths = [lengths[sample_id] for sample_id in batch_ids]
            # Each batch is sorted from longest to shortest.
            assert batch_lengths == sorted(batch_lengths, reverse=True)
            if batch_lengths[0] > max_length:
                # Over-long leading sample => the sampler halves the batch.
                assert len(batch_lengths) == batch_size // 2

        # Batches with an irregular size (neither full nor half) should all
        # be distinct sizes.
        other_batch_sizes = []
        for batch in sampler:
            if len(batch) not in (batch_size, batch_size // 2):
                other_batch_sizes.append(len(batch))
        assert len(set(other_batch_sizes)) == len(other_batch_sizes)

        # Slice sampling yields one batch per sample id.
        assert len(sampler) == len(lengths)
|
|
|
|
|
def test_sorted_bucketing_sampler():
    """SortedBucketingSampler must yield descending-length batches, halving
    the batch size whenever the longest utterance exceeds max_length, and
    produce a number of batches consistent with full/half batch sizes."""
    batch_size = 16
    max_length = 16000 * 5
    # Random utterance lengths between 3 and 8 seconds at 16 kHz.
    lengths = [random.randint(16000 * 3, 16000 * 8) for index in range(1000)]

    sampler = SortedBucketingSampler(
        lengths,
        batch_size=batch_size,
        max_length=max_length,
        shuffle=False,
    )
    for epoch in range(5):
        sampler.set_epoch(epoch)
        id2length = lengths
        for batch_ids in sampler:
            batch_lengths = [id2length[idx] for idx in batch_ids]
            # Each batch is sorted from longest to shortest.
            assert sorted(batch_lengths, reverse=True) == batch_lengths
            if batch_lengths[0] > max_length:
                # Over-long leading sample => the sampler halves the batch.
                assert len(batch_lengths) == batch_size // 2

        batch_sizes = [len(batch_indices) for batch_indices in sampler]
        # BUGFIX: the comprehension variable used to be named `batch_size`,
        # shadowing the configured batch size; the membership test then
        # compared each size against itself and the list was always empty,
        # making the assertion below vacuous. Use a distinct name so sizes
        # are compared against the configured full/half batch sizes.
        other_batch_sizes = [
            size
            for size in batch_sizes
            if size not in [batch_size, batch_size // 2]
        ]
        # At most one irregular (tail) batch is allowed.
        assert len(other_batch_sizes) <= 1
        # Batch count lies between all-full and all-half batching.
        assert len(lengths) / 16 < len(sampler) < len(lengths) / 8
|
|