lmzjms's picture
Upload 1162 files
0b32ad6 verified
raw
history blame
642 Bytes
import pytest
import torch
from s3prl.nn.common import UtteranceLevel
from s3prl.nn.pooling import (
AttentiveStatisticsPooling,
MeanPooling,
SelfAttentivePooling,
TemporalStatisticsPooling,
)
@pytest.mark.parametrize(
    "pooling_type",
    [
        # Take the names from the imported classes instead of hand-typed
        # strings: the imports are then actually used, and the parametrized
        # values cannot drift from the classes exported by s3prl.nn.pooling.
        MeanPooling.__name__,
        TemporalStatisticsPooling.__name__,
        AttentiveStatisticsPooling.__name__,
        SelfAttentivePooling.__name__,
    ],
)
def test_utterance_level_with_pooling(pooling_type: str):
    """Check UtteranceLevel's output shape for each pooling type.

    A random batch of 32 sequences (100 frames x 256 features) is passed
    through an ``UtteranceLevel`` model built with the given pooling type
    name. The output must have shape (32, 64): one 64-dim vector per
    sequence in the batch.
    """
    batch_size, max_frames, input_size, output_size = 32, 100, 256, 64
    model = UtteranceLevel(
        input_size,
        output_size,
        [128],  # presumably hidden layer sizes — TODO confirm against UtteranceLevel
        "ReLU",
        None,
        pooling_type,
        None,
    )
    features = torch.randn(batch_size, max_frames, input_size)
    # Second argument looks like per-sequence valid lengths (assumed): each
    # sequence gets a distinct length 1..batch_size, all shorter than
    # max_frames, so short (even single-frame) inputs are exercised.
    lengths = torch.arange(batch_size) + 1
    output = model(features, lengths)
    assert output.shape == (batch_size, output_size)