from typing import Optional, Tuple, Union, List, Dict

import torch
import torch.nn as nn
from transformers.modeling_outputs import TokenClassifierOutput
from transformers.models.bert import (
    BertConfig, BertModel, BertPreTrainedModel
)
from transformers.models.roberta import (
    RobertaConfig, RobertaModel, RobertaPreTrainedModel
)
from transformers.models.deberta_v2 import (
    DebertaV2Config, DebertaV2Model, DebertaV2PreTrainedModel
)
from transformers.models.modernbert.modeling_modernbert import (
    ModernBertConfig, ModernBertModel, ModernBertPreTrainedModel, ModernBertPredictionHead
)


def fixed_cross_entropy(
    source: torch.Tensor,
    target: torch.Tensor,
    num_items_in_batch: Optional[Union[int, torch.Tensor]] = None,
    ignore_index: int = -100,
    **kwargs,
) -> torch.Tensor:
    """Cross-entropy loss with optional normalisation by an external item count.

    Args:
        source: Unnormalised class scores; every caller in this file passes
            ``(N, num_labels)``.
        target: Class indices; every caller in this file passes ``(N,)``.
        num_items_in_batch: If None, the loss is averaged over the targets that
            are not ``ignore_index`` (PyTorch ``"mean"`` reduction). Otherwise the
            per-target losses are summed and divided by this count, which may be
            an ``int`` or a tensor (moved to the loss's device when needed).
        ignore_index: Target value that contributes neither loss nor gradient.
        **kwargs: Accepted for call-site compatibility and ignored.

    Returns:
        A scalar loss tensor.
    """
    reduction = "sum" if num_items_in_batch is not None else "mean"
    loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction)
    if reduction == "sum":
        if not isinstance(num_items_in_batch, torch.Tensor):
            num_items_in_batch = torch.tensor(num_items_in_batch, device=loss.device, dtype=loss.dtype)
        elif num_items_in_batch.device != loss.device:
            num_items_in_batch = num_items_in_batch.to(loss.device)
        loss = loss / num_items_in_batch
    return loss


def _token_classification_loss(
    logits: torch.Tensor,
    labels: torch.Tensor,
    num_labels: int,
    num_items_in_batch: Optional[Union[int, torch.Tensor]],
    ignore_index: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Flatten per-token logits/labels and compute the cross-entropy loss.

    Shared by the linear token-classification heads in this module.

    Returns:
        ``(flat_logits, loss)``. ``flat_logits`` is ``logits`` reshaped to
        ``(-1, num_labels)`` and upcast to float32 to avoid precision issues in
        the loss. The heads return these flattened logits whenever labels are
        given, so the shape of the returned logits depends on whether
        ``labels`` was passed.
    """
    flat_logits = logits.reshape(-1, num_labels).float()
    flat_labels = labels.reshape(-1).to(flat_logits.device)
    loss = fixed_cross_entropy(flat_logits, flat_labels, num_items_in_batch, ignore_index)
    return flat_logits, loss


class BertForTokenClassification(BertPreTrainedModel):
    """BERT encoder followed by dropout and a linear per-token classifier."""

    def __init__(self, config: BertConfig):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config, add_pooling_layer=False)
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        num_items_in_batch: Optional[torch.Tensor] = None,
        ignore_index: int = -100,
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        num_items_in_batch (*optional*):
            When given, the summed token loss is divided by this count instead of averaged.
        ignore_index (`int`, defaults to -100):
            Label value excluded from the loss.

        When `labels` is given, the returned logits are flattened to `(batch_size * sequence_length, num_labels)`
        and upcast to float32.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = self.dropout(outputs[0])
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            logits, loss = _token_classification_loss(
                logits, labels, self.num_labels, num_items_in_batch, ignore_index
            )

        if not return_dict:
            # outputs[1] is the (absent) pooler output, so extra outputs start at index 2.
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class CRF(nn.Module):
    """Linear-chain conditional random field (CRF) layer.

    All inputs are time-major:

    * ``emissions`` / ``features``: ``(seq_len, batch_size, num_labels)`` scores;
    * ``tags`` / ``labels``: ``(seq_len, batch_size)`` label indices;
    * ``mask``: ``(seq_len, batch_size)``, True at positions that take part.

    The mask does not have to be a contiguous prefix. ``forward`` and ``decode``
    first move each sequence's valid positions to the front, keeping their order.
    As a result, transitions are scored only between consecutive valid positions,
    and the start/end transitions attach to the first/last valid position.
    """

    def __init__(self, num_labels: int):
        super().__init__()
        self.num_labels = num_labels

        # start_transitions[j]: score of a sequence starting with label j.
        # end_transitions[j]:   score of a sequence ending with label j.
        # transitions[i, j]:    score of moving from label i to label j.
        self.start_transitions = nn.Parameter(torch.empty(num_labels))
        self.end_transitions = nn.Parameter(torch.empty(num_labels))
        self.transitions = nn.Parameter(torch.empty(num_labels, num_labels))

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Draw all transition scores from U(-0.1, 0.1).

        Uses in-place tensor methods instead of ``nn.init`` so that this still
        initialises when a library temporarily replaces the ``nn.init`` functions
        with no-ops. Recent transformers versions do this while instantiating
        models in ``from_pretrained``; verify for the pinned version.
        """
        with torch.no_grad():
            for param in (self.start_transitions, self.end_transitions, self.transitions):
                param.uniform_(-0.1, 0.1)

    @staticmethod
    def _pack_valid(
        features: torch.Tensor,
        mask: torch.Tensor,
        tags: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor]:
        """Move each sequence's valid time steps to the front, preserving order.

        Returns ``(features, tags, mask, lengths)``. The new mask is a prefix
        mask: True for the first ``lengths[b]`` steps of sequence ``b``. When the
        input mask already is a prefix mask, the reordering is the identity.
        """
        seq_len = mask.shape[0]
        # Stable ascending sort of "is masked out" (0 = valid, 1 = masked): valid
        # steps come first and keep their original relative order.
        order = torch.sort((~mask).long(), dim=0, stable=True).indices  # (seq_len, batch_size)
        features = features.gather(0, order.unsqueeze(-1).expand_as(features))
        if tags is not None:
            tags = tags.gather(0, order)
        lengths = mask.sum(dim=0)
        steps = torch.arange(seq_len, device=mask.device).unsqueeze(1)
        packed_mask = steps < lengths.unsqueeze(0)
        return features, tags, packed_mask, lengths

    def _compute_log_denominator(self, features: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Log partition function of each sequence (forward algorithm).

        Expects a prefix mask; time step 0 is always scored. Returns ``(batch_size,)``.
        """
        seq_len = features.shape[0]

        # Score of every label at step 0: start transition + emission.
        log_score = self.start_transitions + features[0]  # (batch_size, num_labels)

        for i in range(1, seq_len):
            # (batch, prev, 1) + (prev, cur) + (batch, 1, cur) -> (batch, prev, cur)
            candidates = log_score.unsqueeze(2) + self.transitions + features[i].unsqueeze(1)
            # Marginalise over the previous label.
            next_score = torch.logsumexp(candidates, dim=1)
            # Masked steps carry the previous score through unchanged.
            log_score = torch.where(mask[i].unsqueeze(1), next_score, log_score)

        log_score = log_score + self.end_transitions
        # Marginalise over the final label.
        return torch.logsumexp(log_score, dim=1)

    def _compute_log_numerator(self, features: torch.Tensor, labels: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Unnormalised score of the given label path of each sequence.

        Expects a prefix mask; time step 0 is always scored. Returns ``(batch_size,)``.
        """
        seq_len, batch_size, _ = features.shape
        batch_idx = torch.arange(batch_size, device=features.device)

        score = self.start_transitions[labels[0]] + features[0, batch_idx, labels[0]]

        for i in range(1, seq_len):
            transition = self.transitions[labels[i - 1], labels[i]]
            emission = features[i, batch_idx, labels[i]]
            # Only valid steps contribute.
            score = score + (transition + emission) * mask[i]

        # With a prefix mask, the last valid step is at index length - 1. Clamp
        # so empty sequences index safely; their loss is discarded in forward().
        last_step = (mask.sum(dim=0) - 1).clamp(min=0)
        last_labels = labels[last_step, batch_idx]
        return score + self.end_transitions[last_labels]

    def forward(self, emissions: torch.Tensor, tags: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Negative log-likelihood of ``tags`` under the CRF.

        Args:
            emissions: ``(seq_len, batch_size, num_labels)`` emission scores.
            tags: ``(seq_len, batch_size)`` gold label indices. Values at masked
                positions are ignored and may be anything, e.g. -100.
            mask: ``(seq_len, batch_size)``; True where the position is scored.
                It may contain holes and need not start at step 0.

        Returns:
            Scalar mean NLL over the sequences that have at least one valid
            position; 0 when no sequence has any.
        """
        if mask.dtype != torch.bool:
            mask = mask.bool()
        # Masked tags are never scored, but they are still used as indices, so
        # replace them with a valid placeholder label.
        tags = tags.masked_fill(~mask, 0)

        emissions, tags, mask, lengths = self._pack_valid(emissions, mask, tags)

        log_numerator = self._compute_log_numerator(emissions, tags, mask)
        log_denominator = self._compute_log_denominator(emissions, mask)
        nll = log_denominator - log_numerator

        has_tokens = lengths > 0
        # Equivalent to nll.mean() when every sequence has a valid position.
        return nll.masked_fill(~has_tokens, 0.0).sum() / has_tokens.sum().clamp(min=1)

    def _viterbi_decode(self, features: torch.Tensor, mask: torch.Tensor) -> List[List[int]]:
        """Viterbi search for the highest-scoring label path of each sequence.

        Expects a prefix mask. Returns, per sequence, one label per valid time
        step; a sequence with no valid step yields an empty list.
        """
        seq_len, batch_size, _ = features.shape

        log_score = self.start_transitions + features[0]  # (batch_size, num_labels)
        # backpointers[i, b, j]: best previous label when sequence b has label j at step i.
        backpointers = torch.zeros(
            (seq_len, batch_size, self.num_labels), dtype=torch.long, device=features.device
        )

        for i in range(1, seq_len):
            candidates = log_score.unsqueeze(2) + self.transitions + features[i].unsqueeze(1)
            next_score, best_prev = candidates.max(dim=1)
            backpointers[i] = best_prev
            log_score = torch.where(mask[i].unsqueeze(1), next_score, log_score)

        log_score = log_score + self.end_transitions

        lengths = mask.sum(dim=0).tolist()
        last_labels = log_score.argmax(dim=1).tolist()

        best_paths: List[List[int]] = []
        for b in range(batch_size):
            length = lengths[b]
            if length == 0:
                best_paths.append([])
                continue
            # One host transfer per sequence instead of one .item() per step.
            pointers = backpointers[:length, b].tolist()
            label = last_labels[b]
            path = [label]
            for i in range(length - 1, 0, -1):
                label = pointers[i][label]
                path.append(label)
            path.reverse()
            best_paths.append(path)
        return best_paths

    def decode(self, emissions: torch.Tensor, mask: torch.Tensor) -> List[List[int]]:
        """Return the most likely label sequence of each batch element.

        ``emissions`` is ``(seq_len, batch_size, num_labels)``. ``mask`` is
        ``(seq_len, batch_size)`` and may be non-bool or non-contiguous. Each
        returned list holds one label per position where ``mask`` is set, in
        order.
        """
        if mask.dtype != torch.bool:
            mask = mask.bool()
        with torch.no_grad():
            emissions, _, mask, _ = self._pack_valid(emissions, mask)
            return self._viterbi_decode(emissions, mask)


class BertCRFForTokenClassification(BertPreTrainedModel):
    """BERT encoder + linear emission layer + CRF for token classification."""

    def __init__(self, config: BertConfig):
        super().__init__(config)
        self.num_labels = config.num_labels

        # Encoder.
        self.bert = BertModel(config, add_pooling_layer=False)

        # Dropout and per-token emission scores.
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # CRF over the emission scores.
        self.crf = CRF(config.num_labels)

        # Initialize weights.
        self.post_init()

    def _init_weights(self, module: nn.Module) -> None:
        """Initialise weights, including the CRF's transition parameters.

        ``BertPreTrainedModel._init_weights`` presumably only covers standard
        layers (Linear/Embedding/LayerNorm). The CRF's raw ``nn.Parameter``s are
        therefore (re)initialised here. Otherwise a model loaded from a plain
        BERT checkpoint could keep uninitialised CRF memory.
        """
        super()._init_weights(module)
        if isinstance(module, CRF):
            module.reset_parameters()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        ignore_index: int = -100,
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        """Forward pass for sequence labelling with a CRF loss.

        Args:
            labels: ``(batch_size, seq_len)`` gold labels used to compute the CRF loss.
            ignore_index: Label value excluded from the CRF. Those positions are
                dropped and their neighbours are linked directly. Usually -100.
                Pass None to score every attended position.
            **kwargs: Accepted and ignored.

        Returns the emission ``logits`` of shape ``(batch_size, seq_len, num_labels)``
        and, when ``labels`` is given, the CRF negative log-likelihood as ``loss``.

        NOTE(review): because this method accepts ``**kwargs``, recent
        transformers ``Trainer`` versions may assume the loss is normalised by
        ``num_items_in_batch`` and skip dividing it by
        ``gradient_accumulation_steps``. This CRF loss ignores that argument;
        confirm the behaviour for the pinned version if using gradient accumulation.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = self.dropout(outputs[0])
        logits = self.classifier(sequence_output)  # (batch_size, seq_len, num_labels)

        loss = None
        if labels is not None:
            # The CRF is time-major: (batch, seq_len, ...) -> (seq_len, batch, ...).
            # Upcast like the cross-entropy heads do before the log-space recursions.
            emissions = logits.transpose(0, 1).float()
            crf_labels = labels.to(emissions.device).transpose(0, 1)

            if attention_mask is not None:
                crf_mask = attention_mask.to(emissions.device).transpose(0, 1).bool()
            else:
                crf_mask = torch.ones(emissions.shape[:2], dtype=torch.bool, device=emissions.device)

            if ignore_index is not None:
                # Exclude positions labelled ignore_index. The CRF substitutes a
                # placeholder tag for every masked position itself.
                crf_mask = crf_mask & (crf_labels != ignore_index)

            loss = self.crf(emissions=emissions, tags=crf_labels, mask=crf_mask)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def decode(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        label_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> List[List[int]]:
        """Return the most likely label sequence of every input in the batch.

        Args:
            label_mask: Optional ``(batch_size, seq_len)`` mask. When given, only
                positions where both it and ``attention_mask`` are set are decoded.
                Passing e.g. ``labels != ignore_index`` decodes the same positions
                the CRF was trained on. ``forward`` drops ignored positions, and
                without this mask every attended position is decoded.
            **kwargs: Forwarded to the BERT encoder.

        Returns:
            One list per sequence with one label index per decoded position, in order.

        Runs under ``torch.no_grad()`` but does not switch the model to eval
        mode, so dropout is still applied if the module is in training mode.
        """
        with torch.no_grad():
            outputs = self.bert(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                return_dict=True,
                **kwargs,
            )

            sequence_output = self.dropout(outputs[0])
            logits = self.classifier(sequence_output)  # (batch_size, seq_len, num_labels)

            # (batch, seq_len, num_labels) -> (seq_len, batch, num_labels)
            emissions = logits.transpose(0, 1).float()

            if attention_mask is not None:
                mask = attention_mask.to(emissions.device).transpose(0, 1).bool()
            else:
                mask = torch.ones(emissions.shape[:2], dtype=torch.bool, device=emissions.device)
            if label_mask is not None:
                mask = mask & label_mask.to(emissions.device).transpose(0, 1).bool()

            return self.crf.decode(emissions, mask)


class RobertaForTokenClassification(RobertaPreTrainedModel):
    """RoBERTa encoder followed by dropout and a linear per-token classifier."""

    def __init__(self, config: RobertaConfig):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.roberta = RobertaModel(config, add_pooling_layer=False)
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        num_items_in_batch: Optional[torch.Tensor] = None,
        ignore_index: int = -100,
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        r"""
        token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.
            This parameter can only be used when the model is initialized with `type_vocab_size` parameter with value
            >= 2. All the value in this tensor should be always < type_vocab_size.

            [What are token type IDs?](../glossary#token-type-ids)
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.

        When `labels` is given, the returned logits are flattened to `(batch_size * sequence_length, num_labels)`
        and upcast to float32.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = self.dropout(outputs[0])
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            logits, loss = _token_classification_loss(
                logits, labels, self.num_labels, num_items_in_batch, ignore_index
            )

        if not return_dict:
            # outputs[1] is the (absent) pooler output, so extra outputs start at index 2.
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class DebertaV2ForTokenClassification(DebertaV2PreTrainedModel):
    """DeBERTa-v2 encoder followed by dropout and a linear per-token classifier."""

    def __init__(self, config: DebertaV2Config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.deberta = DebertaV2Model(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        num_items_in_batch: Optional[torch.Tensor] = None,
        ignore_index: int = -100,
        **kwargs,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.

        When `labels` is given, the returned logits are flattened to `(batch_size * sequence_length, num_labels)`
        and upcast to float32.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.deberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = self.dropout(outputs[0])
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            logits, loss = _token_classification_loss(
                logits, labels, self.num_labels, num_items_in_batch, ignore_index
            )

        if not return_dict:
            # This encoder has no pooler slot, so extra outputs start at index 1.
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
        )


class ModernBertForTokenClassification(ModernBertPreTrainedModel):
    """ModernBERT encoder + prediction head + dropout + linear per-token classifier."""

    def __init__(self, config: ModernBertConfig):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.model = ModernBertModel(config)
        self.head = ModernBertPredictionHead(config)
        self.drop = torch.nn.Dropout(config.classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        batch_size: Optional[int] = None,
        seq_len: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        num_items_in_batch: Optional[torch.Tensor] = None,
        ignore_index: int = -100,
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        r"""
        sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
            perform global attention, while the rest perform local attention. This mask is used to avoid attending to
            far-away tokens in the local attention layers when not using Flash Attention.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
            Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
        cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
            Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
        max_seqlen (`int`, *optional*):
            Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
        batch_size (`int`, *optional*):
            Batch size of the input sequences. Used to pad the output tensors.
        seq_len (`int`, *optional*):
            Sequence length of the input sequences including padding tokens. Used to pad the output tensors.

        When `labels` is given, the returned logits are flattened to `(-1, num_labels)` and upcast to float32.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        self._maybe_set_compile()

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            sliding_window_mask=sliding_window_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            indices=indices,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            batch_size=batch_size,
            seq_len=seq_len,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = self.head(outputs[0])
        last_hidden_state = self.drop(last_hidden_state)
        logits = self.classifier(last_hidden_state)

        loss = None
        if labels is not None:
            logits, loss = _token_classification_loss(
                logits, labels, self.num_labels, num_items_in_batch, ignore_index
            )

        if not return_dict:
            # This encoder has no pooler slot, so extra outputs start at index 1.
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )