from torch import nn

from s3prl import Output
from s3prl.nn.transformer_mockingjay import (
    ACT2FN,
    TransformerConfig,
    TransformerLayerNorm,
)


class PredictorMockingjay(nn.Module):
    """
    The predictor model for SSL pre-training tasks.
    Currently supporting SSL problems of Mockingjay, Tera, and Audio Albert.

    The head applies, in order: a linear projection to ``config.hidden_size``,
    the activation named by ``config.hidden_act``, a layer normalization, and
    a final linear projection to ``output_dim``.
    """

    def __init__(self, config, output_dim, input_dim=None, **kwargs):
        """
        Args:
            config (TransformerConfig or dict):
                A `TransformerConfig` class instance with the configuration to
                build a new model. A `dict` (or any `dict` subclass) is also
                accepted and is expanded as keyword arguments into
                `TransformerConfig`.
            output_dim (int):
                The output dimension of predictor
            input_dim (int):
                The input dimension of predictor, if `None` is given, then use
                the `hidden_size` defined in `config`.
                Default: None
            **kwargs:
                Accepted for call-site compatibility; not used by this class.
        """
        super().__init__()

        # isinstance (rather than an exact type check) so dict subclasses such
        # as OrderedDict are converted as well instead of failing later on
        # attribute access.
        if isinstance(config, dict):
            config = TransformerConfig(**config)

        self.output_size = output_dim

        # Fall back to the transformer hidden size when no explicit input
        # dimension is supplied.
        dense_in_features = config.hidden_size if input_dim is None else input_dim
        self.dense = nn.Linear(dense_in_features, config.hidden_size)

        # `hidden_act` is either a key into ACT2FN or used directly as the
        # activation callable.
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act

        self.LayerNorm = TransformerLayerNorm(
            config.hidden_size, eps=config.layer_norm_eps
        )
        self.output = nn.Linear(config.hidden_size, self.output_size)

    def forward(self, inputs, output_states=False):
        """
        Args:
            inputs (s3prl.Output):
                An object exposing a `hidden_states` attribute, which is fed to
                the first linear layer. Its last dimension must therefore equal
                `input_dim` (or `config.hidden_size` when `input_dim` was not
                given), e.g. a tensor of shape
                [batch_size, sequence_length, input_dim].
            output_states (bool):
                A boolean which controls whether to return the `hidden_states`
                of the predictor (the layer-normalized features that feed the
                output projection).
                Default: False
        Return:
            Output (s3prl.Output):
                An Output module that contains `prediction` and, when
                `output_states` is true, `hidden_states`.
        """
        hidden_states = inputs.hidden_states
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        prediction = self.output(hidden_states)

        if output_states:
            return Output(hidden_states=hidden_states, prediction=prediction)
        return Output(prediction=prediction)