from typing import Optional, Tuple, Union, List, Dict, Callable

import torch
import torch.nn as nn

from transformers import PreTrainedModel, Qwen2Config
from transformers.cache_utils import Cache
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_outputs import BaseModelOutputWithPast, TokenClassifierOutput
from transformers.processing_utils import Unpack
from transformers.models.bert import (
    BertConfig, BertModel, BertPreTrainedModel
)
from transformers.models.roberta import (
    RobertaConfig, RobertaModel, RobertaPreTrainedModel
)
from transformers.models.deberta_v2 import (
    DebertaV2Config, DebertaV2Model, DebertaV2PreTrainedModel
)
from transformers.models.modernbert.modeling_modernbert import (
    ModernBertConfig, ModernBertModel, ModernBertPreTrainedModel, ModernBertPredictionHead
)
from transformers.models.qwen2.modeling_qwen2 import (
    Qwen2PreTrainedModel,
    Qwen2Model,
    SlidingWindowCache,
    StaticCache,
)
from transformers.models.qwen3.modeling_qwen3 import (
    Qwen3PreTrainedModel,
    Qwen3Config,
    Qwen3Model,
    Qwen3RMSNorm,
    Qwen3DecoderLayer,
    Qwen3Attention,
    Qwen3RotaryEmbedding,
    Qwen3MLP,
    apply_rotary_pos_emb,
    can_return_tuple,
    eager_attention_forward,
)


def fixed_cross_entropy(
    source: torch.Tensor,
    target: torch.Tensor,
    num_items_in_batch: Optional[int] = None,
    ignore_index: int = -100,
    **kwargs,
) -> torch.Tensor:
    reduction = "sum" if num_items_in_batch is not None else "mean"
    loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction)
    if reduction == "sum":
        if not isinstance(num_items_in_batch, torch.Tensor):
            num_items_in_batch = torch.tensor(num_items_in_batch, device=loss.device, dtype=loss.dtype)
        elif num_items_in_batch.device != loss.device:
            num_items_in_batch = num_items_in_batch.to(loss.device)
        loss = loss / num_items_in_batch
    return loss
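

# Hedged usage sketch (not part of the original module): shows how `num_items_in_batch`
# switches `fixed_cross_entropy` from mean reduction to sum-then-normalize, which keeps
# the loss scale consistent across gradient-accumulation steps. The helper name and all
# shapes below are illustrative only.
def _example_fixed_cross_entropy() -> torch.Tensor:
    logits = torch.randn(8, 5)            # (num_tokens, num_labels)
    labels = torch.randint(0, 5, (8,))    # (num_tokens,)
    labels[0] = -100                      # this position is ignored by the loss
    # Normalize by the number of label-bearing tokens in the (accumulated) batch.
    num_items = (labels != -100).sum()
    return fixed_cross_entropy(logits, labels, num_items_in_batch=num_items)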


class BertForTokenClassification(BertPreTrainedModel):

    def __init__(self, config: BertConfig):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config, add_pooling_layer=False)
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        num_items_in_batch: Optional[torch.Tensor] = None,
        ignore_index: int = -100,
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            # Flatten for the loss without overwriting `logits`, so the returned logits
            # keep their (batch_size, seq_len, num_labels) shape.
            flat_logits = logits.view(-1, self.num_labels).float()
            flat_labels = labels.view(-1).to(flat_logits.device)
            loss = fixed_cross_entropy(flat_logits, flat_labels, num_items_in_batch, ignore_index)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
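

# Hedged usage sketch (not part of the original module): unlike the stock Transformers
# head, this class accepts `num_items_in_batch` and `ignore_index` directly, so a
# trainer's token count can be routed into `fixed_cross_entropy`. The tiny random
# config and the helper name are illustrative, not the original training setup.
def _example_bert_token_classification() -> torch.Tensor:
    config = BertConfig(
        vocab_size=100, hidden_size=32, num_hidden_layers=2, num_attention_heads=2,
        intermediate_size=64, num_labels=3,
    )
    model = BertForTokenClassification(config)
    input_ids = torch.randint(0, 100, (2, 6))
    labels = torch.randint(0, 3, (2, 6))
    labels[:, 0] = -100                   # e.g. special tokens carry no label
    num_items = (labels != -100).sum()
    out = model(input_ids=input_ids, labels=labels, num_items_in_batch=num_items)
    return out.loss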


class CRF(nn.Module):
    """Conditional random field (CRF) layer, based on a more numerically stable implementation."""

    def __init__(self, num_labels: int):
        super().__init__()
        self.num_labels = num_labels

        self.start_transitions = nn.Parameter(torch.empty(num_labels))
        self.end_transitions = nn.Parameter(torch.empty(num_labels))
        self.transitions = nn.Parameter(torch.empty(num_labels, num_labels))

        nn.init.uniform_(self.start_transitions, -0.1, 0.1)
        nn.init.uniform_(self.end_transitions, -0.1, 0.1)
        nn.init.uniform_(self.transitions, -0.1, 0.1)

    def _compute_log_denominator(self, features: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Compute the log of the partition function with the forward algorithm."""
        seq_len, batch_size, _ = features.shape

        # (batch_size, num_labels): scores of all paths ending at step 0
        log_score = self.start_transitions + features[0]

        for i in range(1, seq_len):
            # previous score + transition score + emission score,
            # broadcast to (batch_size, num_labels, num_labels)
            next_score = (
                log_score.unsqueeze(2) +
                self.transitions +
                features[i].unsqueeze(1)
            )

            # Sum over the previous label dimension in log space
            next_score = torch.logsumexp(next_score, dim=1)

            # Only update positions that are inside the sequence
            log_score = torch.where(mask[i].unsqueeze(1), next_score, log_score)

        log_score += self.end_transitions

        return torch.logsumexp(log_score, dim=1)

    def _compute_log_numerator(self, features: torch.Tensor, labels: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Compute the score of the given label sequence."""
        seq_len, batch_size, _ = features.shape

        score = self.start_transitions[labels[0]] + features[0, torch.arange(batch_size), labels[0]]

        for i in range(1, seq_len):
            score += (
                self.transitions[labels[i - 1], labels[i]] +
                features[i, torch.arange(batch_size), labels[i]]
            ) * mask[i]

        # Index of the last valid position for each sequence
        seq_lens = mask.sum(dim=0) - 1

        last_tags = labels[seq_lens.long(), torch.arange(batch_size)]

        score += self.end_transitions[last_tags]

        return score

    def forward(self, emissions: torch.Tensor, tags: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """
        Compute the CRF negative log-likelihood loss.

        Args:
            emissions: (seq_len, batch_size, num_labels) emission scores
            tags: (seq_len, batch_size) gold labels
            mask: (seq_len, batch_size) mask for variable-length sequences

        Returns:
            The CRF negative log-likelihood loss.
        """
        log_numerator = self._compute_log_numerator(emissions, tags, mask)
        log_denominator = self._compute_log_denominator(emissions, mask)

        loss = torch.mean(log_denominator - log_numerator)

        return loss

    def _viterbi_decode(self, features: torch.Tensor, mask: torch.Tensor) -> List[List[int]]:
        """Viterbi decoding: find the most likely label sequence."""
        seq_len, batch_size, _ = features.shape

        log_score = self.start_transitions + features[0]
        backpointers = torch.zeros((seq_len, batch_size, self.num_labels), dtype=torch.long, device=features.device)

        for i in range(1, seq_len):
            next_score = log_score.unsqueeze(2) + self.transitions + features[i].unsqueeze(1)

            # Best previous label for every current label
            next_score, indices = next_score.max(dim=1)

            backpointers[i] = indices

            log_score = torch.where(mask[i].unsqueeze(1), next_score, log_score)

        log_score += self.end_transitions

        # Index of the last valid position for each sequence
        seq_lens = mask.sum(dim=0).long() - 1

        best_paths = []
        for seq_idx in range(batch_size):
            best_label = torch.argmax(log_score[seq_idx]).item()
            best_path = [best_label]

            # Walk the backpointers from the last valid position back to the start
            for i in range(seq_lens[seq_idx], 0, -1):
                best_label = backpointers[i, seq_idx, best_label].item()
                best_path.insert(0, best_label)

            best_paths.append(best_path)

        return best_paths

    def decode(self, emissions: torch.Tensor, mask: torch.Tensor) -> List[List[int]]:
        """Find the most likely label sequence with Viterbi decoding."""
        if mask.dtype != torch.bool:
            mask = mask.bool()

        with torch.no_grad():
            return self._viterbi_decode(emissions, mask)
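

# Hedged usage sketch (not part of the original module): the CRF layer above is
# time-major, i.e. emissions/tags/mask are shaped (seq_len, batch_size, ...), which is
# exactly how BertCRFForTokenClassification feeds it. The helper name and all sizes
# are illustrative.
def _example_crf_usage() -> List[List[int]]:
    num_labels, seq_len, batch_size = 4, 6, 2
    crf = CRF(num_labels)
    emissions = torch.randn(seq_len, batch_size, num_labels)
    tags = torch.randint(0, num_labels, (seq_len, batch_size))
    mask = torch.ones(seq_len, batch_size, dtype=torch.bool)
    mask[4:, 1] = False                   # the second sequence is only 4 tokens long
    nll = crf(emissions, tags, mask)      # scalar negative log-likelihood
    nll.backward()                        # transition parameters receive gradients
    return crf.decode(emissions, mask)    # best label path per sequence (Viterbi)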


class BertCRFForTokenClassification(BertPreTrainedModel):
    """BERT model with a CRF layer on top for token classification."""

    def __init__(self, config: BertConfig):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config, add_pooling_layer=False)

        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.crf = CRF(config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        ignore_index: int = -100,
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], Dict[str, torch.Tensor]]:
        """
        Forward pass for sequence labeling with a CRF.

        Args:
            labels: label sequence used to compute the loss
            ignore_index: label value to ignore, usually -100
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)

        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            # The CRF layer expects time-major tensors: (seq_len, batch_size, ...)
            emissions = logits.transpose(0, 1)

            if attention_mask is not None:
                attention_mask_t = attention_mask.transpose(0, 1).bool()
            else:
                attention_mask_t = torch.ones(emissions.shape[:2], dtype=torch.bool, device=emissions.device)

            if ignore_index is not None:
                # Drop ignored positions (e.g. special tokens, sub-word continuations)
                # from the CRF mask.
                labels_mask = (labels != ignore_index)
                attention_mask_t = attention_mask_t & labels_mask.transpose(0, 1)

                # Replace ignored labels with a valid index (0); they are masked out anyway.
                crf_labels = labels.clone()
                crf_labels[~labels_mask] = 0
                crf_labels_t = crf_labels.transpose(0, 1)
            else:
                crf_labels_t = labels.transpose(0, 1)

            loss = self.crf(
                emissions=emissions,
                tags=crf_labels_t,
                mask=attention_mask_t
            )

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def decode(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> List[List[int]]:
        """
        Decode the most likely label sequence for each input.
        """
        with torch.no_grad():
            outputs = self.bert(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                return_dict=True,
                **kwargs,
            )

            sequence_output = outputs[0]
            sequence_output = self.dropout(sequence_output)

            logits = self.classifier(sequence_output)

            # Time-major emissions for the CRF: (seq_len, batch_size, num_labels)
            emissions = logits.transpose(0, 1)

            if attention_mask is not None:
                mask = attention_mask.transpose(0, 1).bool()
            else:
                mask = torch.ones(emissions.shape[:2], dtype=torch.bool, device=emissions.device)

            best_tags = self.crf.decode(emissions, mask)

        return best_tags
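

# Hedged usage sketch (not part of the original module): builds a randomly initialised
# BertCRFForTokenClassification from a tiny config (no checkpoint download) purely to
# show the training and decoding API. The helper name and config sizes are illustrative.
def _example_bert_crf_usage() -> List[List[int]]:
    config = BertConfig(
        vocab_size=100, hidden_size=32, num_hidden_layers=2, num_attention_heads=2,
        intermediate_size=64, num_labels=5,
    )
    model = BertCRFForTokenClassification(config)
    input_ids = torch.randint(0, 100, (2, 7))
    attention_mask = torch.ones(2, 7, dtype=torch.long)
    labels = torch.randint(0, 5, (2, 7))
    out = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
    out.loss.backward()                   # CRF negative log-likelihood
    return model.decode(input_ids=input_ids, attention_mask=attention_mask)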


class RobertaForTokenClassification(RobertaPreTrainedModel):

    def __init__(self, config: RobertaConfig):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.roberta = RobertaModel(config, add_pooling_layer=False)
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        num_items_in_batch: Optional[torch.Tensor] = None,
        ignore_index: int = -100,
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        r"""
        token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.
            This parameter can only be used when the model is initialized with `type_vocab_size` parameter with value
            >= 2. All the values in this tensor should always be < type_vocab_size.

            [What are token type IDs?](../glossary#token-type-ids)
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            # Flatten for the loss without overwriting `logits`, so the returned logits
            # keep their (batch_size, seq_len, num_labels) shape.
            flat_logits = logits.view(-1, self.num_labels).float()
            flat_labels = labels.view(-1).to(flat_logits.device)
            loss = fixed_cross_entropy(flat_logits, flat_labels, num_items_in_batch, ignore_index)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class DebertaV2ForTokenClassification(DebertaV2PreTrainedModel):

    def __init__(self, config: DebertaV2Config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.deberta = DebertaV2Model(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        num_items_in_batch: Optional[torch.Tensor] = None,
        ignore_index: int = -100,
        **kwargs,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.deberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            # Flatten for the loss without overwriting `logits`, so the returned logits
            # keep their (batch_size, seq_len, num_labels) shape.
            flat_logits = logits.view(-1, self.num_labels).float()
            flat_labels = labels.view(-1).to(flat_logits.device)
            loss = fixed_cross_entropy(flat_logits, flat_labels, num_items_in_batch, ignore_index)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
        )


class ModernBertForTokenClassification(ModernBertPreTrainedModel):

    def __init__(self, config: ModernBertConfig):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.model = ModernBertModel(config)
        self.head = ModernBertPredictionHead(config)
        self.drop = torch.nn.Dropout(config.classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        batch_size: Optional[int] = None,
        seq_len: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        num_items_in_batch: Optional[torch.Tensor] = None,
        ignore_index: int = -100,
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        r"""
        sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
            perform global attention, while the rest perform local attention. This mask is used to avoid attending to
            far-away tokens in the local attention layers when not using Flash Attention.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
            Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
        cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
            Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
        max_seqlen (`int`, *optional*):
            Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
        batch_size (`int`, *optional*):
            Batch size of the input sequences. Used to pad the output tensors.
        seq_len (`int`, *optional*):
            Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        self._maybe_set_compile()

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            sliding_window_mask=sliding_window_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            indices=indices,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            batch_size=batch_size,
            seq_len=seq_len,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        last_hidden_state = outputs[0]

        last_hidden_state = self.head(last_hidden_state)
        last_hidden_state = self.drop(last_hidden_state)
        logits = self.classifier(last_hidden_state)

        loss = None
        if labels is not None:
            # Flatten for the loss without overwriting `logits`, so the returned logits
            # keep their (batch_size, seq_len, num_labels) shape.
            flat_logits = logits.view(-1, self.num_labels).float()
            flat_labels = labels.view(-1).to(flat_logits.device)
            loss = fixed_cross_entropy(flat_logits, flat_labels, num_items_in_batch, ignore_index)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class UnmaskingQwen3Attention(Qwen3Attention):
    """Multi-headed attention without causal mask for bidirectional attention"""

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position is needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward

        # No causal mask is ever built here. Keep the incoming (non-causal) mask only when
        # it can actually mask positions; otherwise drop it so attention is fully bidirectional.
        if attention_mask is not None and 0.0 in attention_mask:
            pass
        else:
            attention_mask = None

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


class UnmaskingQwen3DecoderLayer(Qwen3DecoderLayer):

    def __init__(self, config: Qwen3Config, layer_idx: int):
        # Skip Qwen3DecoderLayer.__init__ so the standard (causal) attention module is
        # never built; swap in the unmasked attention instead.
        super(Qwen3DecoderLayer, self).__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = UnmaskingQwen3Attention(config=config, layer_idx=layer_idx)
        self.mlp = Qwen3MLP(config)
        self.input_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)


class UnmaskingQwen3Model(Qwen3Model):

    def __init__(self, config: Qwen3Config):
        # Skip Qwen3Model.__init__ so the decoder layers can be replaced with unmasked ones.
        super(Qwen3PreTrainedModel, self).__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [UnmaskingQwen3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen3RotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool = False,
    ):
        # No mask at all means every token may attend to every other token.
        if attention_mask is None:
            return None

        dtype = input_tensor.dtype
        min_dtype = torch.finfo(dtype).min
        batch_size = input_tensor.shape[0]
        sequence_length = input_tensor.shape[1]

        if isinstance(attention_mask, torch.Tensor) and attention_mask.dim() == 2:
            # Expand the 2D padding mask to an additive 4D mask with no causal component:
            # 0.0 where attention is allowed, the dtype minimum where it is masked.
            expanded_attn_mask = attention_mask[:, None, None, :]
            expanded_attn_mask = expanded_attn_mask.to(dtype=dtype)
            expanded_attn_mask = (1.0 - expanded_attn_mask) * min_dtype
            return expanded_attn_mask

        # Already a 4D mask: pass it through unchanged.
        return attention_mask


class UnmaskingQwen3ForTokenClassification(Qwen3PreTrainedModel):

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = UnmaskingQwen3Model(config)
        if getattr(config, "classifier_dropout", None) is not None:
            classifier_dropout = config.classifier_dropout
        elif getattr(config, "hidden_dropout", None) is not None:
            classifier_dropout = config.hidden_dropout
        else:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)
        self.score = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @can_return_tuple
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> TokenClassifierOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in
            `[0, ..., config.num_labels - 1]`.
        """
        outputs: BaseModelOutputWithPast = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        sequence_output = outputs.last_hidden_state
        sequence_output = self.dropout(sequence_output)
        logits = self.score(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1).to(logits.device))

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
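

# Hedged usage sketch (not part of the original module): instantiates the bidirectional
# Qwen3 token classifier from a tiny random config, assuming the same transformers
# version this module was written against (one where Qwen3Model still routes mask
# construction through `_update_causal_mask`). The helper name, sizes, and label count
# are illustrative only.
def _example_unmasking_qwen3() -> torch.Tensor:
    config = Qwen3Config(
        vocab_size=100, hidden_size=32, intermediate_size=64, num_hidden_layers=2,
        num_attention_heads=2, num_key_value_heads=1, head_dim=16, num_labels=3,
    )
    model = UnmaskingQwen3ForTokenClassification(config)
    input_ids = torch.randint(0, 100, (2, 6))
    attention_mask = torch.ones(2, 6, dtype=torch.long)
    attention_mask[1, 4:] = 0             # pad the second sequence
    labels = torch.randint(0, 3, (2, 6))
    labels[1, 4:] = -100                  # padded positions are ignored by CrossEntropyLoss
    out = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
    return out.loss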


class UnmaskingQwen2Model(Qwen2Model):
    """
    UnmaskingQwen2Model is a modified version of Qwen2Model that removes the causal mask,
    allowing bidirectional attention similar to BERT-like models.
    """

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool = False,
    ):
        """
        Override the causal mask creation to build a non-causal (bidirectional) mask.
        This allows each token to attend to all tokens in the sequence.
        """
        if self.config._attn_implementation == "flash_attention_2":
            # Flash attention consumes the 2D padding mask directly; keep it only if it masks anything.
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        if self.config._attn_implementation == "flex_attention":
            # Pass the mask through unchanged; no block-causal mask is built.
            return attention_mask

        batch_size = input_tensor.shape[0]
        sequence_length = input_tensor.shape[1]
        dtype = input_tensor.dtype

        # Work out how many key positions the mask has to cover.
        if isinstance(past_key_values, (SlidingWindowCache, StaticCache)):
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_key_values.get_seq_length() + sequence_length + 1
                if past_key_values is not None
                else sequence_length
            )

        # Start from an all-zero additive mask: every query may attend to every key.
        non_causal_mask = torch.zeros(
            (batch_size, 1, sequence_length, target_length),
            dtype=dtype,
            device=input_tensor.device,
        )

        if attention_mask is not None:
            if attention_mask.dim() == 2:
                # Block only the padded key positions.
                expanded_mask = attention_mask[:, None, None, :].expand(
                    batch_size, 1, sequence_length, attention_mask.shape[-1]
                ).to(non_causal_mask.device)

                min_dtype = torch.finfo(dtype).min
                padding_mask = expanded_mask == 0
                non_causal_mask = non_causal_mask.masked_fill(padding_mask, min_dtype)
            elif attention_mask.dim() == 4:
                # A custom 4D mask is used as-is.
                non_causal_mask = attention_mask

        return non_causal_mask


class UnmaskingQwen2ForTokenClassification(Qwen2PreTrainedModel):
    """
    Qwen2 model with a token classification head on top, but with bidirectional attention.
    This is achieved by using the UnmaskingQwen2Model which removes the causal mask.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.model = UnmaskingQwen2Model(config)

        if getattr(config, "classifier_dropout", None) is not None:
            classifier_dropout = config.classifier_dropout
        elif getattr(config, "hidden_dropout", None) is not None:
            classifier_dropout = config.hidden_dropout
        else:
            classifier_dropout = 0.1

        self.dropout = nn.Dropout(classifier_dropout)
        self.score = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> TokenClassifierOutput:
        """
        Forward pass for token classification with bidirectional attention.

        Args:
            input_ids: Input token IDs
            attention_mask: Attention mask
            position_ids: Position IDs
            past_key_values: Past key values for efficient generation
            inputs_embeds: Pre-computed input embeddings
            labels: Token classification labels
            use_cache: Whether to use cache for efficient generation
            output_attentions: Whether to output attention weights
            output_hidden_states: Whether to output hidden states
            flash_attn_kwargs: Additional arguments for flash attention

        Returns:
            TokenClassifierOutput with loss, logits, and optional hidden states and attentions
        """
        outputs: BaseModelOutputWithPast = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            **flash_attn_kwargs,
        )

        sequence_output = outputs.last_hidden_state
        sequence_output = self.dropout(sequence_output)
        logits = self.score(sequence_output)

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.config)

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
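

# Hedged sketch (not part of the original module): inspects the mask produced by
# UnmaskingQwen2Model._update_causal_mask to show that it contains no causal triangle;
# only padded key positions are blocked. Assumes the transformers version this module
# targets; the helper name and config sizes are illustrative.
def _example_bidirectional_mask() -> torch.Tensor:
    config = Qwen2Config(
        vocab_size=100, hidden_size=32, intermediate_size=64, num_hidden_layers=2,
        num_attention_heads=2, num_key_value_heads=2,
    )
    model = UnmaskingQwen2Model(config)
    input_ids = torch.randint(0, 100, (1, 5))
    attention_mask = torch.ones(1, 5, dtype=torch.long)
    attention_mask[0, 4] = 0              # the last position is padding
    inputs_embeds = model.embed_tokens(input_ids)
    mask = model._update_causal_mask(
        attention_mask, inputs_embeds, cache_position=torch.arange(5), past_key_values=None
    )
    # mask has shape (1, 1, 5, 5); rows are queries, columns are keys. Every query can
    # see every non-padded key, including "future" ones, unlike a causal mask.
    assert (mask[0, 0, :, :4] == 0).all() and (mask[0, 0, :, 4] < 0).all()
    return mask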