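# NOTE: the lines below are web-page header residue from the model hub listing,
# kept as comments so that this module remains valid Python.
# bert-base-uncased-cls / modeling_bert.py
# yangwang825's picture
# Upload BertForSequenceClassification
# aff5c65
# raw
# history blame
# 16.8 kB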
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
from typing import Optional, List, Union, Tuple
from transformers import (
PretrainedConfig,
PreTrainedModel,
AutoTokenizer,
AutoConfig,
AutoModel,
AutoModelForSequenceClassification
)
from transformers.models.bert.modeling_bert import (
BertEmbeddings,
BertEncoder,
load_tf_weights_in_bert
)
from transformers.modeling_outputs import (
BaseModelOutputWithPoolingAndCrossAttentions,
SequenceClassifierOutput
)
from .configuration_bert import BertConfig
class BertPreTrainedModel(PreTrainedModel):
    """
    Shared base class for the BERT models defined in this file.

    It declares the configuration class, the TF checkpoint loader, the
    base-model prefix and gradient-checkpointing support, and provides the
    per-module weight initialisation.
    """

    config_class = BertConfig
    load_tf_weights = load_tf_weights_in_bert
    base_model_prefix = "bert"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights of a single sub-module in place.

        * ``nn.Linear`` / ``nn.Embedding``: weights drawn from a normal
          distribution with std ``config.initializer_range``; a Linear bias
          (if any) and an Embedding padding row (if any) are zeroed.
        * ``nn.LayerNorm``: bias set to 0, weight set to 1.
        * Anything else is left untouched.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if isinstance(module, nn.Linear):
                if module.bias is not None:
                    module.bias.data.zero_()
            elif module.padding_idx is not None:
                # The padding token must keep an all-zero embedding.
                module.weight.data[module.padding_idx].zero_()
            return

        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
class PFSA(nn.Module):
    """
    Parameter-free self-attention.

    Reference: https://openreview.net/pdf?id=isodM5jTA7h

    Every token vector is re-weighted by ``(1 - sigmoid(c)) ** alpha`` where
    ``c`` is the correlation (computed across the feature axis) between the
    token and the mean of all tokens in the sequence. The module has no
    learnable parameters.

    Arguments
    ---------
    input_dim : int
        Feature dimension; stored on the module but not used in the computation.
    alpha : float
        Exponent applied to the attention weights.
    """

    def __init__(self, input_dim, alpha=1):
        super().__init__()
        self.input_dim = input_dim
        self.alpha = alpha

    def forward(self, x, mask=None):
        """
        x: [B, T, F]
        mask: accepted for interface compatibility; currently not used.

        Returns a tensor of shape [B, T, F].
        """
        x = x.transpose(1, 2)[..., None]                    # [B, F, T, 1]
        k = torch.mean(x, dim=[-1, -2], keepdim=True)       # [B, F, 1, 1] mean token

        # Centre both along the feature axis once and reuse the result.
        x_centered = x - x.mean(dim=1, keepdim=True)
        k_centered = k - k.mean(dim=1, keepdim=True)

        # A token (or mean token) that is constant across features has zero
        # variance, which used to give 0 / 0 = NaN below. Clamping the sum of
        # squares *before* the square root keeps the denominator strictly
        # positive (and avoids the infinite slope of sqrt at 0); the numerator
        # is 0 in that case, so the correlation becomes 0. Non-degenerate
        # inputs are unaffected by the clamp.
        tiny = torch.finfo(x.dtype).tiny
        kd = torch.sqrt(k_centered.pow(2).sum(dim=1, keepdim=True).clamp(min=tiny))  # [B, 1, 1, 1]
        qd = torch.sqrt(x_centered.pow(2).sum(dim=1, keepdim=True).clamp(min=tiny))  # [B, 1, T, 1]

        C_qk = (x_centered * k_centered).sum(dim=1, keepdim=True) / (qd * kd)        # [B, 1, T, 1]
        A = (1 - torch.sigmoid(C_qk)) ** self.alpha
        out = x * A
        out = out.squeeze(dim=-1).transpose(1, 2)
        return out
class PURE(nn.Module):
    """
    Principal-component removal (PCR) followed by parameter-free
    self-attention (PFSA) over a sequence of token features.

    Arguments
    ---------
    in_dim : int
        Feature dimension F of the input.
    q : int
        Target rank requested from ``torch.pca_lowrank``.
    r : int
        Number of leading principal components to remove.
    center : bool
        Forwarded to ``torch.pca_lowrank`` as ``center``.
    num_iters : int
        Forwarded to ``torch.pca_lowrank`` as ``niter``.
    return_mean, return_std : bool
        Only consulted by ``get_out_dim``.
    normalize : bool
        If True, apply LeakyReLU + BatchNorm1d to the input and L2-normalise
        the output along the feature axis.
    do_pcr, do_pfsa : bool
        Enable / disable the PCR and PFSA stages.
    alpha : float
        Exponent handed to ``PFSA``.
    *args, **kwargs
        Accepted and ignored.
    """

    def __init__(
        self,
        in_dim,
        q=5,
        r=1,
        center=False,
        num_iters=1,
        return_mean=True,
        return_std=True,
        normalize=False,
        do_pcr=True,
        do_pfsa=True,
        alpha=1,
        *args, **kwargs
    ):
        super().__init__()
        self.in_dim = in_dim
        self.target_rank = q
        self.num_pc_to_remove = r
        self.center = center
        self.num_iters = num_iters
        self.return_mean = return_mean
        self.return_std = return_std
        self.normalize = normalize
        self.do_pcr = do_pcr
        self.do_pfsa = do_pfsa
        self.attention = PFSA(in_dim, alpha=alpha)
        self.eps = 1e-5
        if self.normalize:
            # The activation must NOT run in place: ``forward`` receives a
            # transposed *view* of the caller's tensor, and an in-place
            # LeakyReLU would overwrite the caller's hidden states.
            self.norm = nn.Sequential(OrderedDict([
                ('relu', nn.LeakyReLU(inplace=False)),
                ('bn', nn.BatchNorm1d(in_dim)),
            ]))

    def get_out_dim(self):
        """Return (and cache on ``self.out_dim``) the output dimension:
        ``2 * in_dim`` when both mean and std are requested, else ``in_dim``."""
        if self.return_mean and self.return_std:
            self.out_dim = self.in_dim * 2
        else:
            self.out_dim = self.in_dim
        return self.out_dim

    def _compute_pc(self, x):
        """
        x: (B, T, F)

        Returns the leading principal directions, shape [B, K, F] with
        K = min(num_pc_to_remove, effective rank).
        """
        # torch.pca_lowrank raises ValueError when q > min(T, F); clamp the
        # requested rank so that short sequences (T < q) still work.
        seq_len, feat_dim = x.shape[-2:]
        rank = min(self.target_rank, seq_len, feat_dim)
        _, _, V = torch.pca_lowrank(x, q=rank, center=self.center, niter=self.num_iters)
        pc = V.transpose(1, 2)[:, :self.num_pc_to_remove, :]  # pc: [B, K, F]
        return pc

    def forward(self, x, attention_mask=None, *args, **kwargs):
        """
        PCR -> Attention
        x: (B, F, T)

        Returns a tensor of shape (B, T, F). ``attention_mask`` is handed to
        the attention module; extra positional/keyword arguments are ignored.
        """
        if self.normalize:
            x = self.norm(x)
        xt = x.transpose(1, 2)  # [B, T, F]
        if self.do_pcr:
            pc = self._compute_pc(xt)  # pc: [B, K, F]
            # Subtract the projection onto the removed components:
            # [B, T, F] * [B, F, K] * [B, K, F] = [B, T, F]
            xx = xt - xt @ pc.transpose(1, 2) @ pc
        else:
            xx = xt
        if self.do_pfsa:
            xx = self.attention(xx, attention_mask)
        if self.normalize:
            xx = F.normalize(xx, p=2, dim=2)
        return xx
class BertPooler(nn.Module):
    """
    Pooler that refines the token states with ``PURE`` and then takes an
    attention-mask-weighted mean over the sequence, followed by an optional
    dense layer (``config.affine``) and a tanh activation.

    Reads from ``config``: ``hidden_size``, ``q``, ``r``, ``center``,
    ``num_iters``, ``return_mean``, ``return_std``, ``normalize``, ``do_pcr``,
    ``do_pfsa``, ``alpha`` and ``affine``.
    """

    def __init__(self, config):
        super().__init__()
        self.pure = PURE(
            config.hidden_size,
            q=config.q,
            r=config.r,
            center=config.center,
            num_iters=config.num_iters,
            return_mean=config.return_mean,
            return_std=config.return_std,
            normalize=config.normalize,
            do_pcr=config.do_pcr,
            do_pfsa=config.do_pfsa,
            alpha=config.alpha
        )
        if config.affine:
            self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        else:
            self.dense = nn.Identity()
        self.activation = nn.Tanh()
        self.eps = 1e-5

    def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Pool ``hidden_states`` ([B, T, F]) into one vector per sequence ([B, F])."""
        # Unlike the stock BERT pooler this does NOT take the first token:
        # the token states go through PURE (which expects [B, F, T]) and are
        # then mean-pooled over the positions selected by the attention mask.
        hidden_states = self.pure(hidden_states.transpose(1, 2), attention_mask)
        mean_tensor = self.mean_pooling(hidden_states, attention_mask)
        pooled_output = self.dense(mean_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output

    def _get_gauss_noise(self, shape_of_tensor, device="cpu"):
        """Returns a tensor of epsilon Gaussian noise.

        Arguments
        ---------
        shape_of_tensor : tensor
            It represents the size of tensor for generating Gaussian noise.
        device : str or torch.device
            Device on which the noise is created.

        The noise is min-max rescaled to [0, 1] and then mapped linearly into
        the interval [eps, 9 * eps].
        """
        gnoise = torch.randn(shape_of_tensor, device=device)
        if gnoise.numel() == 0:
            # Nothing to rescale; torch.min / torch.max reject empty tensors.
            return gnoise
        gnoise -= torch.min(gnoise)
        # After the shift the maximum is 0 when all samples are equal (always
        # the case for a single element); clamp so 0 / 0 does not yield NaN.
        gnoise /= torch.max(gnoise).clamp(min=torch.finfo(gnoise.dtype).tiny)
        gnoise = self.eps * ((1 - 9) * gnoise + 9)
        return gnoise

    def add_noise(self, tensor):
        """Add the rescaled Gaussian noise to ``tensor`` in place and return it."""
        tensor += self._get_gauss_noise(tensor.size(), device=tensor.device)
        return tensor

    def mean_pooling(self, token_embeddings, attention_mask):
        """Mean of ``token_embeddings`` ([B, T, F]) over the positions where
        ``attention_mask`` ([B, T]) is non-zero; returns [B, F]."""
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        # The clamp keeps the division defined for rows whose mask is all zero.
        mean = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return mean
class BertModel(BertPreTrainedModel):
    """
    BERT backbone: embeddings -> encoder -> optional ``BertPooler``.

    The pooler receives the final hidden states together with the 2-D
    attention mask, so the pooled output depends on which positions are
    marked as real tokens.
    """

    config_class = BertConfig

    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.config = config

        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = None if not add_pooling_layer else BertPooler(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the word-embedding module."""
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the word-embedding module with ``value``."""
        self.embeddings.word_embeddings = value

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        """
        Run embeddings, encoder and (if present) the pooler.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be given,
        otherwise a ``ValueError`` is raised. Output flags left as ``None``
        fall back to the values stored on ``self.config``.

        Returns a tuple ``(sequence_output, pooled_output, *extras)`` when
        ``return_dict`` is falsy, else a
        ``BaseModelOutputWithPoolingAndCrossAttentions``.
        """
        cfg = self.config

        # Resolve unset flags from the configuration.
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if return_dict is None:
            return_dict = cfg.use_return_dict

        # Caching only applies when the model is configured as a decoder.
        if not cfg.is_decoder:
            use_cache = False
        elif use_cache is None:
            use_cache = cfg.use_cache

        # Validate the inputs and derive shape / device from whichever was given.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")
        if input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        batch_size, seq_length = input_shape
        device = inputs_embeds.device if input_ids is None else input_ids.device

        # past_key_values_length
        past_key_values_length = 0
        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]

        # Default mask: attend to every (past + current) position.
        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)

        # Default token types: reuse the embedding module's buffer when it has
        # one, otherwise fall back to all zeros.
        if token_type_ids is None:
            if hasattr(self.embeddings, "token_type_ids"):
                buffered = self.embeddings.token_type_ids[:, :seq_length]
                token_type_ids = buffered.expand(batch_size, seq_length)
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        encoder_extended_attention_mask = None
        if cfg.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(
                    (encoder_batch_size, encoder_sequence_length), device=device
                )
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, cfg.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]

        # The pooler gets the plain 2-D mask, not the extended one.
        pooled_output = None
        if self.pooler is not None:
            pooled_output = self.pooler(sequence_output, attention_mask)

        if return_dict:
            return BaseModelOutputWithPoolingAndCrossAttentions(
                last_hidden_state=sequence_output,
                pooler_output=pooled_output,
                past_key_values=encoder_outputs.past_key_values,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
                cross_attentions=encoder_outputs.cross_attentions,
            )
        return (sequence_output, pooled_output) + encoder_outputs[1:]
class BertForSequenceClassification(BertPreTrainedModel):
    """
    ``BertModel`` with dropout and a linear classification / regression head
    applied to the pooled output.
    """

    config_class = BertConfig

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.bert = BertModel(config)

        # Prefer the dedicated classifier dropout; fall back to the hidden dropout.
        if config.classifier_dropout is not None:
            classifier_dropout = config.classifier_dropout
        else:
            classifier_dropout = config.hidden_dropout_prob
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Index 1 is the pooled output in both the tuple and the dict-style return.
        logits = self.classifier(self.dropout(outputs[1]))

        loss = None
        if labels is not None:
            # Infer the problem type once and remember it on the config.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    inferred = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    inferred = "single_label_classification"
                else:
                    inferred = "multi_label_classification"
                self.config.problem_type = inferred

            problem_type = self.config.problem_type
            if problem_type == "regression":
                mse = nn.MSELoss()
                if self.num_labels == 1:
                    loss = mse(logits.squeeze(), labels.squeeze())
                else:
                    loss = mse(logits, labels)
            elif problem_type == "single_label_classification":
                cross_entropy = nn.CrossEntropyLoss()
                loss = cross_entropy(logits.view(-1, self.num_labels), labels.view(-1))
            elif problem_type == "multi_label_classification":
                bce = nn.BCEWithLogitsLoss()
                loss = bce(logits, labels)

        if return_dict:
            return SequenceClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        output = (logits,) + outputs[2:]
        if loss is None:
            return output
        return (loss,) + output