"""
The probing model following Hear Benchmark
Authors:
* Hear Team 2021
* Leo 2022
"""
from typing import List, Optional, Tuple

import torch

import s3prl.nn.pooling as pooling

__all__ = ["HearFullyConnectedPrediction"]


class HearFullyConnectedPrediction(torch.nn.Module):
"""
The specific prediction head used in the Hear Benchmark.
Modified from: https://github.com/hearbenchmark/hear-eval-kit/blob/855964977238e89dfc76394aa11c37010edb6f20/heareval/predictions/task_predictions.py#L142
Args:
input_size (int): input_size
output_size (int): output_size
hidden_size (int): hidden size across all layers. Default: 1024
hidden_layers (int): number of hidden layers, all in :code:`hidden_size`. Default: 2
norm_after_activation (bool): whether to norm after activation. Default: False
dropout (float): dropout ratio. Default: 0.1
initialization (str): initialization method name available in :obj:`torch.nn.init`
hidden_norm (str): normalization method name available in :obj:`torch.nn`
pooling_type (str): the pooling class name in :obj:`s3prl.nn.pooling`. Default: MeanPooling
pooling_conf (dict): the arguments for initializing the pooling class.
Default: empty dict
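
    Example (a minimal sketch; the sizes below are illustrative)::

        model = HearFullyConnectedPrediction(
            input_size=768, output_size=10, pooling_type="MeanPooling"
        )
        y, y_len = model(torch.randn(4, 50, 768), torch.full((4,), 50))
        # y: (4, 10), y_len: all 1s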
"""

    def __init__(
        self,
        input_size: int,
        output_size: int,
        hidden_size: int = 1024,
        hidden_layers: int = 2,
        norm_after_activation: bool = False,
        dropout: float = 0.1,
        initialization: str = "xavier_uniform_",
        hidden_norm: str = "BatchNorm1d",
        pooling_type: Optional[str] = None,
        pooling_conf: Optional[dict] = None,
    ):
        super().__init__()
        self._input_size = input_size
        self._output_size = output_size

        # Resolve the initializer and the normalization layer from their names
        initialization = getattr(torch.nn.init, initialization)
        hidden_norm = getattr(torch.nn, hidden_norm)

        # When a pooling module is configured, it runs first and determines
        # the input dimension of the hidden layers
        curdim = input_size
        if pooling_type is not None:
            pooling_cls = getattr(pooling, pooling_type)
            self.pooling = pooling_cls(input_size, **(pooling_conf or {}))
            curdim = self.pooling.output_size

        hidden_modules: List[torch.nn.Module] = []
        last_activation = "linear"
        if hidden_layers:
            for i in range(hidden_layers):
                linear = torch.nn.Linear(curdim, hidden_size)
                initialization(
                    linear.weight,
                    gain=torch.nn.init.calculate_gain(last_activation),
                )
                hidden_modules.append(linear)
                if not norm_after_activation:
                    hidden_modules.append(hidden_norm(hidden_size))
                hidden_modules.append(torch.nn.Dropout(dropout))
                hidden_modules.append(torch.nn.ReLU())
                if norm_after_activation:
                    hidden_modules.append(hidden_norm(hidden_size))
                curdim = hidden_size
                last_activation = "relu"
            self.hidden = torch.nn.Sequential(*hidden_modules)
        else:
            self.hidden = torch.nn.Identity()  # type: ignore

        self.projection = torch.nn.Linear(curdim, output_size)
        initialization(
            self.projection.weight, gain=torch.nn.init.calculate_gain(last_activation)
        )

    @property
    def input_size(self) -> int:
        return self._input_size

    @property
    def output_size(self) -> int:
        return self._output_size

    def forward(self, x, x_len) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x (torch.FloatTensor): (batch_size, seq_len, input_size)
            x_len (torch.LongTensor): (batch_size, )

        Returns:
            tuple:

            1. y (torch.FloatTensor)
            2. y_len (torch.LongTensor)

            If :code:`pooling_type` is None, :code:`y` is (batch_size, seq_len, output_size)
            and :code:`y_len` is the input :code:`x_len`; if not None, :code:`y` is
            (batch_size, output_size) and :code:`y_len` is (batch_size, ) of all 1s.
            See the sketch at the bottom of this file for concrete shapes.
        """
        if hasattr(self, "pooling"):
            x = self.pooling(x, x_len)
            # After pooling, each sequence collapses to a single vector
            x_len = x.new_ones(len(x), dtype=torch.long)

        # The hidden layers (e.g. BatchNorm1d) expect 2-D (N, C) inputs, so
        # flatten the time axis first and restore it after the projection
        shape = x.shape
        if len(shape) == 3:
            bs, ts, hidden_size = x.shape
            x = x.reshape(bs * ts, hidden_size)

        x = self.hidden(x)
        x = self.projection(x)

        if len(shape) == 3:
            x = x.reshape(bs, ts, -1)

        return x, x_len
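

if __name__ == "__main__":
    # A minimal smoke-test sketch, not part of the upstream module. The
    # tensor sizes are illustrative, and it assumes s3prl.nn.pooling
    # provides the MeanPooling class named in the class docstring.
    batch_size, seq_len, feat_dim = 4, 50, 768
    x = torch.randn(batch_size, seq_len, feat_dim)
    x_len = torch.full((batch_size,), seq_len, dtype=torch.long)

    # Frame-wise prediction: no pooling, so the time axis is kept
    frame_model = HearFullyConnectedPrediction(feat_dim, output_size=10)
    y, y_len = frame_model(x, x_len)
    assert y.shape == (batch_size, seq_len, 10)

    # Utterance-level prediction: pooling collapses the time axis
    utter_model = HearFullyConnectedPrediction(
        feat_dim, output_size=10, pooling_type="MeanPooling"
    )
    y, y_len = utter_model(x, x_len)
    assert y.shape == (batch_size, 10)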