from torch import nn

from s3prl import Output
from s3prl.nn.transformer_mockingjay import (
    ACT2FN,
    TransformerConfig,
    TransformerLayerNorm,
)


class PredictorMockingjay(nn.Module):
    """
    The predictor model for SSL pre-training tasks.
    Currently supports the SSL problems of Mockingjay, TERA, and Audio ALBERT.
    """

    def __init__(self, config, output_dim, input_dim=None, **kwargs):
        """
        Args:
            config (TransformerConfig):
                A `TransformerConfig` instance holding the configuration used to build the model.
                A `dict` may also be given; it is used to initialize a `TransformerConfig`.
            output_dim (int):
                The output dimension of the predictor.
            input_dim (int):
                The input dimension of the predictor. If `None`, the `hidden_size`
                defined in `config` is used.
                Default: None
        """
        super().__init__()
        if isinstance(config, dict):
            config = TransformerConfig(**config)
        self.output_size = output_dim
        # Project the input features to the transformer hidden size.
        if input_dim is None:
            self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        else:
            self.dense = nn.Linear(input_dim, config.hidden_size)
        # `hidden_act` can be the name of an activation in ACT2FN or a callable.
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = TransformerLayerNorm(
            config.hidden_size, eps=config.layer_norm_eps
        )
        self.output = nn.Linear(config.hidden_size, self.output_size)

    def forward(self, inputs, output_states=False):
        """
        Args:
            inputs (s3prl.Output):
                An Output object whose `hidden_states` field is a torch.FloatTensor
                of shape (batch_size, sequence_length, input_dim).
            output_states (bool):
                Whether to also return the predictor's `hidden_states`.
                Default: False

        Return:
            Output (s3prl.Output):
                An Output object that contains `prediction`, and `hidden_states`
                when `output_states` is True.
        """
        hidden_states = inputs.hidden_states
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        prediction = self.output(hidden_states)
        if output_states:
            return Output(hidden_states=hidden_states, prediction=prediction)
        else:
            return Output(prediction=prediction)
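

# A minimal usage sketch, not part of the s3prl API. It assumes that the config
# object only needs the attributes read in __init__ (`hidden_size`, `hidden_act`,
# `layer_norm_eps`), so a SimpleNamespace stands in for a full TransformerConfig,
# and that `Output` accepts arbitrary keyword fields with attribute access, as
# forward() above already relies on.
if __name__ == "__main__":
    from types import SimpleNamespace

    import torch

    config = SimpleNamespace(
        hidden_size=768,
        hidden_act=torch.nn.functional.gelu,  # a callable bypasses the ACT2FN lookup
        layer_norm_eps=1e-12,
    )
    # Predict 80-dimensional frames (e.g. mel-spectrogram targets) from 768-dim states.
    predictor = PredictorMockingjay(config, output_dim=80)

    # Fake upstream features of shape (batch_size, sequence_length, hidden_size).
    features = torch.randn(4, 100, config.hidden_size)
    out = predictor(Output(hidden_states=features), output_states=True)
    print(out.prediction.shape)     # torch.Size([4, 100, 80])
    print(out.hidden_states.shape)  # torch.Size([4, 100, 768])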