"""
The probing model following the HEAR Benchmark
Authors:
* Hear Team 2021
* Leo 2022
"""
from typing import List, Optional, Tuple

import torch

import s3prl.nn.pooling as pooling

__all__ = ["HearFullyConnectedPrediction"]


class HearFullyConnectedPrediction(torch.nn.Module):
"""
    The specific prediction head used in the HEAR Benchmark.
Modified from: https://github.com/hearbenchmark/hear-eval-kit/blob/855964977238e89dfc76394aa11c37010edb6f20/heareval/predictions/task_predictions.py#L142
    Args:
        input_size (int): dimension of the input features
        output_size (int): dimension of the output predictions
        hidden_size (int): hidden size shared by all hidden layers. Default: 1024
        hidden_layers (int): number of hidden layers, each of :code:`hidden_size`. Default: 2
        norm_after_activation (bool): whether to apply the normalization layer after
            the activation instead of before it. Default: False
        dropout (float): dropout ratio. Default: 0.1
        initialization (str): initialization method name available in :obj:`torch.nn.init`.
            Default: xavier_uniform_
        hidden_norm (str): normalization module name available in :obj:`torch.nn`.
            Default: BatchNorm1d
        pooling_type (str): the pooling class name in :obj:`s3prl.nn.pooling`,
            e.g. MeanPooling. If None, predictions are made per frame. Default: None
        pooling_conf (dict): the arguments for initializing the pooling class.
            Default: None (treated as an empty dict)
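
    Example:
        A minimal sketch of the per-frame path (shapes here are illustrative,
        not taken from the source):

        >>> model = HearFullyConnectedPrediction(input_size=768, output_size=10)
        >>> x = torch.randn(4, 100, 768)
        >>> x_len = torch.full((4,), 100, dtype=torch.long)
        >>> y, y_len = model(x, x_len)
        >>> y.shape
        torch.Size([4, 100, 10])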
"""
def __init__(
self,
input_size: int,
output_size: int,
hidden_size: int = 1024,
hidden_layers: int = 2,
norm_after_activation: bool = False,
dropout: float = 0.1,
initialization: str = "xavier_uniform_",
hidden_norm: str = "BatchNorm1d",
pooling_type: str = None,
pooling_conf: dict = None,
):
super().__init__()
self._input_size = input_size
self._output_size = output_size
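        # Resolve the initializer function and the normalization layer class by name
        # (e.g. torch.nn.init.xavier_uniform_ and torch.nn.BatchNorm1d)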
initialization = getattr(torch.nn.init, initialization)
hidden_norm = getattr(torch.nn, hidden_norm)
curdim = input_size
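        # When pooling is requested, the time axis is collapsed first and the
        # pooled feature size becomes the input width of the MLP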
if pooling_type is not None:
pooling_cls = getattr(pooling, pooling_type)
self.pooling = pooling_cls(input_size, **(pooling_conf or {}))
curdim = self.pooling.output_size
hidden_modules: List[torch.nn.Module] = []
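        # Track the nonlinearity feeding each Linear so calculate_gain() matches it:
        # the first layer sees the raw (linear) input, later layers see ReLU outputs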
last_activation = "linear"
if hidden_layers:
for i in range(hidden_layers):
linear = torch.nn.Linear(curdim, hidden_size)
initialization(
linear.weight,
gain=torch.nn.init.calculate_gain(last_activation),
)
hidden_modules.append(linear)
if not norm_after_activation:
hidden_modules.append(hidden_norm(hidden_size))
hidden_modules.append(torch.nn.Dropout(dropout))
hidden_modules.append(torch.nn.ReLU())
if norm_after_activation:
hidden_modules.append(hidden_norm(hidden_size))
curdim = hidden_size
last_activation = "relu"
self.hidden = torch.nn.Sequential(*hidden_modules)
else:
self.hidden = torch.nn.Identity() # type: ignore
self.projection = torch.nn.Linear(curdim, output_size)
initialization(
self.projection.weight, gain=torch.nn.init.calculate_gain(last_activation)
        )

    @property
    def input_size(self) -> int:
        return self._input_size

    @property
    def output_size(self) -> int:
        return self._output_size

    def forward(self, x, x_len) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x (torch.FloatTensor): (batch_size, seq_len, input_size)
x_len (torch.LongTensor): (batch_size, )
Returns:
tuple:
1. y (torch.FloatTensor)
2. y_len (torch.LongTensor)
            If :code:`pooling_type` is None, :code:`y` is (batch_size, seq_len, output_size) and :code:`y_len` is (batch_size, );
            otherwise, :code:`y` is (batch_size, output_size) and :code:`y_len` is (batch_size, ) filled with all ones.
"""
        if hasattr(self, "pooling"):
            # Pooling collapses the time axis; every pooled example then has length 1.
            # dtype is set to long so y_len matches the documented LongTensor type.
            x = self.pooling(x, x_len)
            x_len = x.new_ones(len(x), dtype=torch.long)
        # hidden_norm layers such as BatchNorm1d expect (N, C) inputs, so the
        # (batch, time) axes are flattened for the per-frame (3-dim) case
        shape = x.shape
        if len(shape) == 3:
            bs, ts, hidden_size = x.shape
            x = x.reshape(bs * ts, hidden_size)
x = self.hidden(x)
x = self.projection(x)
        if len(shape) == 3:
            # Restore the (batch, time) axes for per-frame predictions
            x = x.reshape(bs, ts, -1)
return x, x_len
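

if __name__ == "__main__":
    # Minimal smoke test (not part of the original file): exercises both the
    # per-frame path and a pooled path. MeanPooling is assumed to exist in
    # s3prl.nn.pooling with an (input_size) constructor, per the docstring above.
    frame_model = HearFullyConnectedPrediction(input_size=768, output_size=10)
    x = torch.randn(4, 100, 768)
    x_len = torch.full((4,), 100, dtype=torch.long)
    y, y_len = frame_model(x, x_len)
    assert y.shape == (4, 100, 10) and y_len.shape == (4,)

    pooled_model = HearFullyConnectedPrediction(
        input_size=768, output_size=10, pooling_type="MeanPooling"
    )
    y, y_len = pooled_model(x, x_len)
    assert y.shape == (4, 10) and bool((y_len == 1).all())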