# Unmasking Qwen Token Classification Models
# Automatically generated file; load the model with trust_remote_code=True
import torch
import torch.nn as nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import TokenClassifierOutput
from typing import Optional, Tuple, Union, List, Dict, Callable
from transformers.models.bert import (
BertConfig, BertModel, BertPreTrainedModel
)
from transformers.models.roberta import (
RobertaConfig, RobertaModel, RobertaPreTrainedModel
)
from transformers.models.deberta_v2 import (
DebertaV2Config, DebertaV2Model, DebertaV2PreTrainedModel
)
from transformers.models.modernbert.modeling_modernbert import (
ModernBertConfig, ModernBertModel, ModernBertPreTrainedModel, ModernBertPredictionHead
)
from transformers import Qwen2Config
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.cache_utils import Cache
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.processing_utils import Unpack
from transformers.models.qwen2.modeling_qwen2 import (
Qwen2PreTrainedModel,
Qwen2Model,
SlidingWindowCache,
StaticCache
)
from transformers.models.qwen3.modeling_qwen3 import (
    Qwen3PreTrainedModel,
    Qwen3Config,
    Qwen3Model,
    Qwen3RMSNorm,
    Qwen3DecoderLayer,
    Qwen3Attention,
    Qwen3RotaryEmbedding,
    Qwen3MLP,
    apply_rotary_pos_emb,
    can_return_tuple,
    eager_attention_forward,
)
def fixed_cross_entropy(
source: torch.Tensor,
target: torch.Tensor,
num_items_in_batch: Optional[int] = None,
ignore_index: int = -100,
**kwargs,
) -> torch.Tensor:
reduction = "sum" if num_items_in_batch is not None else "mean"
loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction)
if reduction == "sum":
if not isinstance(num_items_in_batch, torch.Tensor):
num_items_in_batch = torch.tensor(num_items_in_batch, device=loss.device, dtype=loss.dtype)
elif num_items_in_batch.device != loss.device:
num_items_in_batch = num_items_in_batch.to(loss.device)
loss = loss / num_items_in_batch
return loss
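# Illustrative sketch (not part of the model API) of why `fixed_cross_entropy` takes
# `num_items_in_batch`: when it is set, the summed loss is divided by the caller-supplied token
# count instead of being averaged per call, which keeps the loss scale consistent under gradient
# accumulation. Shapes and values below are assumptions.
#
#     logits = torch.randn(8, 5)                       # (num_tokens, num_labels)
#     labels = torch.randint(0, 5, (8,))               # (num_tokens,)
#     loss_mean = fixed_cross_entropy(logits, labels)                          # mean reduction
#     loss_norm = fixed_cross_entropy(logits, labels, num_items_in_batch=32)   # sum / 32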
class BertForTokenClassification(BertPreTrainedModel):
def __init__(self, config: BertConfig):
super().__init__(config)
self.num_labels = config.num_labels
self.bert = BertModel(config, add_pooling_layer=False)
classifier_dropout = (
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
)
self.dropout = nn.Dropout(classifier_dropout)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
# Initialize weights and apply final processing
self.post_init()
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
num_items_in_batch: Optional[torch.Tensor] = None,
ignore_index: int = -100,
**kwargs,
) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)
loss = None
if labels is not None:
logits = logits.view(-1, self.num_labels)
labels = labels.view(-1).to(logits.device)
logits = logits.float()
loss = fixed_cross_entropy(logits, labels, num_items_in_batch, ignore_index)
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
class CRF(nn.Module):
    """Conditional Random Field (CRF) layer, based on a more numerically stable implementation."""
    def __init__(self, num_labels: int):
        super().__init__()
        self.num_labels = num_labels
        # Transition matrix plus start/end transition parameters
        self.start_transitions = nn.Parameter(torch.empty(num_labels))
        self.end_transitions = nn.Parameter(torch.empty(num_labels))
        self.transitions = nn.Parameter(torch.empty(num_labels, num_labels))
        # Initialize the parameters from a uniform distribution
        nn.init.uniform_(self.start_transitions, -0.1, 0.1)
        nn.init.uniform_(self.end_transitions, -0.1, 0.1)
        nn.init.uniform_(self.transitions, -0.1, 0.1)
    def _compute_log_denominator(self, features: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Compute the log of the partition function."""
        seq_len, batch_size, _ = features.shape
        # Initialize the score as the start transition score plus the features of the first time step
        log_score = self.start_transitions + features[0]  # [batch_size, num_labels]
        # Accumulate the score one time step at a time
        for i in range(1, seq_len):
            # Score of every possible transition: previous score + transition score + current features
            # [batch_size, num_labels, 1] + [num_labels, num_labels] + [batch_size, 1, num_labels]
            # -> [batch_size, num_labels, num_labels]
            next_score = (
                log_score.unsqueeze(2) +       # [batch_size, num_labels, 1]
                self.transitions +             # [num_labels, num_labels]
                features[i].unsqueeze(1)       # [batch_size, 1, num_labels]
            )
            # logsumexp over all possible previous labels
            next_score = torch.logsumexp(next_score, dim=1)
            # Update the score only where the mask is active
            log_score = torch.where(mask[i].unsqueeze(1), next_score, log_score)
        # Add the transition scores to the end label
        log_score += self.end_transitions
        # logsumexp over all possible final labels
        return torch.logsumexp(log_score, dim=1)
    def _compute_log_numerator(self, features: torch.Tensor, labels: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Compute the score of the given label sequence."""
        seq_len, batch_size, _ = features.shape
        # Initialize the score
        score = self.start_transitions[labels[0]] + features[0, torch.arange(batch_size), labels[0]]
        # Accumulate the score one time step at a time
        for i in range(1, seq_len):
            # Transition score plus emission score
            score += (
                self.transitions[labels[i - 1], labels[i]] +          # transition score
                features[i, torch.arange(batch_size), labels[i]]      # emission score
            ) * mask[i]  # only count valid positions
        # Sequence lengths (minus 1 because indices are zero-based)
        seq_lens = mask.sum(dim=0) - 1
        # Last valid label of each sequence
        last_tags = labels[seq_lens.long(), torch.arange(batch_size)]
        # Add the transition score to the end label
        score += self.end_transitions[last_tags]
        return score
    def forward(self, emissions: torch.Tensor, tags: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """
        Compute the CRF negative log-likelihood loss.

        Args:
            emissions: (seq_len, batch_size, num_labels) emission scores
            tags: (seq_len, batch_size) gold labels
            mask: (seq_len, batch_size) mask for handling variable-length sequences

        Returns:
            The CRF negative log-likelihood loss.
        """
        # Log of the numerator and of the denominator (partition function)
        log_numerator = self._compute_log_numerator(emissions, tags, mask)
        log_denominator = self._compute_log_denominator(emissions, mask)
        # The loss is the log denominator minus the log numerator
        loss = torch.mean(log_denominator - log_numerator)
        return loss
    def _viterbi_decode(self, features: torch.Tensor, mask: torch.Tensor) -> List[List[int]]:
        """Viterbi decoding: find the most likely label sequence."""
        seq_len, batch_size, _ = features.shape
        # Initialize the Viterbi variables
        log_score = self.start_transitions + features[0]  # [batch_size, num_labels]
        backpointers = torch.zeros((seq_len, batch_size, self.num_labels), dtype=torch.long, device=features.device)
        # Fill in one time step at a time
        for i in range(1, seq_len):
            # Score of every possible transition
            next_score = log_score.unsqueeze(2) + self.transitions + features[i].unsqueeze(1)
            # Best previous label for each current label
            next_score, indices = next_score.max(dim=1)
            # Record the backpointers
            backpointers[i] = indices
            # Update the score only where the mask is active
            log_score = torch.where(mask[i].unsqueeze(1), next_score, log_score)
        # Add the transition scores to the end label
        log_score += self.end_transitions
        # Index of the last valid position of each sequence
        seq_lens = mask.sum(dim=0).long() - 1
        # Backtrack to recover the best paths
        best_paths = []
        for seq_idx in range(batch_size):
            # Final label with the highest score
            best_label = torch.argmax(log_score[seq_idx]).item()
            best_path = [best_label]
            # Follow the backpointers from the last position back to the start
            for i in range(seq_lens[seq_idx], 0, -1):
                best_label = backpointers[i, seq_idx, best_label].item()
                best_path.insert(0, best_label)
            best_paths.append(best_path)
        return best_paths
    def decode(self, emissions: torch.Tensor, mask: torch.Tensor) -> List[List[int]]:
        """Find the most likely label sequence using Viterbi decoding."""
        # Make sure the mask is boolean
        if mask.dtype != torch.bool:
            mask = mask.bool()
        with torch.no_grad():
            return self._viterbi_decode(emissions, mask)
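# Self-contained usage sketch for the CRF layer above (shapes are assumptions; emissions are
# time-major and the mask marks valid positions per time step):
#
#     crf = CRF(num_labels=5)
#     emissions = torch.randn(7, 2, 5)                  # (seq_len, batch_size, num_labels)
#     tags = torch.randint(0, 5, (7, 2))                # (seq_len, batch_size)
#     mask = torch.ones(7, 2, dtype=torch.bool)         # all positions valid
#     nll = crf(emissions, tags, mask)                  # mean negative log-likelihood
#     best_paths = crf.decode(emissions, mask)          # List[List[int]], one path per sequence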
class BertCRFForTokenClassification(BertPreTrainedModel):
    """BERT encoder with a CRF layer on top for token classification."""
    def __init__(self, config: BertConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        # BERT encoder
        self.bert = BertModel(config, add_pooling_layer=False)
        # Dropout and classifier
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        # CRF layer
        self.crf = CRF(config.num_labels)
        # Initialize weights and apply final processing
        self.post_init()
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
ignore_index: int = -100,
**kwargs,
) -> Union[Tuple[torch.Tensor], Dict[str, torch.Tensor]]:
"""
使用CRF進行序列標注的前向傳播
參數:
labels: 標籤序列,用於計算損失
ignore_index: 忽略的標籤值,通常為-100
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# BERT前向傳播
outputs = self.bert(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output)
        # Emission scores
logits = self.classifier(sequence_output) # [batch_size, seq_len, num_labels]
loss = None
if labels is not None:
            # Prepare the inputs in the format expected by the CRF
            # Swap dimensions: [batch_size, seq_len, num_labels] -> [seq_len, batch_size, num_labels]
            emissions = logits.transpose(0, 1)
            # Swap dimensions: [batch_size, seq_len] -> [seq_len, batch_size]
            if attention_mask is not None:
                attention_mask_t = attention_mask.transpose(0, 1).bool()
            else:
                attention_mask_t = torch.ones(emissions.shape[:2], dtype=torch.bool, device=emissions.device)
            # Handle ignore_index
            if ignore_index is not None:
                labels_mask = (labels != ignore_index)
                attention_mask_t = attention_mask_t & labels_mask.transpose(0, 1)
                # Build a label tensor that does not contain ignore_index
                crf_labels = labels.clone()
                crf_labels[~labels_mask] = 0  # temporarily set ignored positions to 0 so they do not affect the CRF
                crf_labels_t = crf_labels.transpose(0, 1)
            else:
                crf_labels_t = labels.transpose(0, 1)
            # Compute the CRF loss
loss = self.crf(
emissions=emissions,
tags=crf_labels_t,
mask=attention_mask_t
)
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def decode(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
**kwargs,
) -> List[List[int]]:
"""
解碼最可能的標籤序列
"""
# 不計算梯度
with torch.no_grad():
# BERT前向傳播
outputs = self.bert(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
return_dict=True,
**kwargs,
)
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output)
# 獲取發射分數
logits = self.classifier(sequence_output) # [batch_size, seq_len, num_labels]
# 交換維度:[batch_size, seq_len, num_labels] -> [seq_len, batch_size, num_labels]
emissions = logits.transpose(0, 1)
# 準備遮罩
if attention_mask is not None:
mask = attention_mask.transpose(0, 1).bool()
else:
mask = torch.ones(emissions.shape[:2], dtype=torch.bool, device=emissions.device)
# 使用Viterbi算法解碼
best_tags = self.crf.decode(emissions, mask)
return best_tags
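# Hedged usage sketch for BertCRFForTokenClassification (checkpoint name and sentence are
# placeholders; training uses labels padded with -100 just like the plain heads above):
#
#     from transformers import AutoTokenizer
#     config = BertConfig.from_pretrained("bert-base-cased", num_labels=9)
#     model = BertCRFForTokenClassification(config)     # randomly initialized weights (sketch only)
#     tok = AutoTokenizer.from_pretrained("bert-base-cased")
#     enc = tok("John lives in Berlin", return_tensors="pt")
#     paths = model.decode(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"])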
class RobertaForTokenClassification(RobertaPreTrainedModel):
def __init__(self, config: RobertaConfig):
super().__init__(config)
self.num_labels = config.num_labels
self.roberta = RobertaModel(config, add_pooling_layer=False)
classifier_dropout = (
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
)
self.dropout = nn.Dropout(classifier_dropout)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
# Initialize weights and apply final processing
self.post_init()
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
num_items_in_batch: Optional[torch.Tensor] = None,
ignore_index: int = -100,
**kwargs,
) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
r"""
token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`:
- 0 corresponds to a *sentence A* token,
- 1 corresponds to a *sentence B* token.
This parameter can only be used when the model is initialized with `type_vocab_size` parameter with value
>= 2. All the value in this tensor should be always < type_vocab_size.
[What are token type IDs?](../glossary#token-type-ids)
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.roberta(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)
loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.view(-1, self.num_labels)
labels = labels.view(-1).to(logits.device)
logits = logits.float()
loss = fixed_cross_entropy(logits, labels, num_items_in_batch, ignore_index)
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
class DebertaV2ForTokenClassification(DebertaV2PreTrainedModel):
def __init__(self, config: DebertaV2Config):
super().__init__(config)
self.num_labels = config.num_labels
self.deberta = DebertaV2Model(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
# Initialize weights and apply final processing
self.post_init()
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
num_items_in_batch: Optional[torch.Tensor] = None,
ignore_index: int = -100,
**kwargs,
) -> Union[Tuple, TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.deberta(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)
loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.view(-1, self.num_labels)
labels = labels.view(-1).to(logits.device)
logits = logits.float()
loss = fixed_cross_entropy(logits, labels, num_items_in_batch, ignore_index)
if not return_dict:
output = (logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
)
class ModernBertForTokenClassification(ModernBertPreTrainedModel):
def __init__(self, config: ModernBertConfig):
super().__init__(config)
self.num_labels = config.num_labels
self.model = ModernBertModel(config)
self.head = ModernBertPredictionHead(config)
self.drop = torch.nn.Dropout(config.classifier_dropout)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
# Initialize weights and apply final processing
self.post_init()
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
sliding_window_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
indices: Optional[torch.Tensor] = None,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[int] = None,
batch_size: Optional[int] = None,
seq_len: Optional[int] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
num_items_in_batch: Optional[torch.Tensor] = None,
ignore_index: int = -100,
**kwargs,
) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
r"""
sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
perform global attention, while the rest perform local attention. This mask is used to avoid attending to
far-away tokens in the local attention layers when not using Flash Attention.
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
max_seqlen (`int`, *optional*):
Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
batch_size (`int`, *optional*):
Batch size of the input sequences. Used to pad the output tensors.
seq_len (`int`, *optional*):
Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
self._maybe_set_compile()
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
sliding_window_mask=sliding_window_mask,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
indices=indices,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
batch_size=batch_size,
seq_len=seq_len,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
last_hidden_state = outputs[0]
last_hidden_state = self.head(last_hidden_state)
last_hidden_state = self.drop(last_hidden_state)
logits = self.classifier(last_hidden_state)
loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.view(-1, self.num_labels)
labels = labels.view(-1).to(logits.device)
logits = logits.float()
loss = fixed_cross_entropy(logits, labels, num_items_in_batch, ignore_index)
if not return_dict:
output = (logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
class UnmaskingQwen3Attention(Qwen3Attention):
"""Multi-headed attention without causal mask for bidirectional attention"""
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# Use eager attention as default
attention_interface: Callable = eager_attention_forward
# Remove causal mask by setting attention_mask to None or creating a non-causal mask
# For bidirectional attention, we don't want any masking except padding
if attention_mask is not None and 0.0 in attention_mask:
# Keep only padding mask if it exists, remove causal part
# This allows tokens to attend to future tokens
pass
else:
# If there's no padding, we can set attention_mask to None for full attention
attention_mask = None
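        # Note (descriptive, added for clarity): the 4D mask built by
        # UnmaskingQwen3Model._update_causal_mask holds 0.0 at every attended position, so
        # `0.0 in attention_mask` is True whenever a padding mask exists and that mask is kept
        # unchanged; no causal triangle is ever added, and with no padding the mask is dropped
        # so attention is fully bidirectional.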
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
sliding_window=self.sliding_window,
**kwargs,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
class UnmaskingQwen3DecoderLayer(Qwen3DecoderLayer):
def __init__(self, config: Qwen3Config, layer_idx: int):
super(Qwen3DecoderLayer, self).__init__()
self.hidden_size = config.hidden_size
self.self_attn = UnmaskingQwen3Attention(config=config, layer_idx=layer_idx)
self.mlp = Qwen3MLP(config)
self.input_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
class UnmaskingQwen3Model(Qwen3Model):
def __init__(self, config: Qwen3Config):
super(Qwen3PreTrainedModel, self).__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList(
[UnmaskingQwen3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = Qwen3RotaryEmbedding(config=config)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def _update_causal_mask(
self,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool = False,
):
# Override the causal mask creation to create a non-causal mask
# This allows bidirectional attention
if attention_mask is None:
# If no attention mask is provided, return None to allow full attention
return None
# If attention_mask is provided, it's likely for padding
# Convert it to the right format but without the causal constraint
dtype = input_tensor.dtype
min_dtype = torch.finfo(dtype).min
batch_size = input_tensor.shape[0]
sequence_length = input_tensor.shape[1]
if isinstance(attention_mask, torch.Tensor) and attention_mask.dim() == 2:
# Convert 2D padding mask to 4D attention mask
expanded_attn_mask = attention_mask[:, None, None, :]
expanded_attn_mask = expanded_attn_mask.to(dtype=dtype)
expanded_attn_mask = (1.0 - expanded_attn_mask) * min_dtype
return expanded_attn_mask
# If it's already 4D, return as is
return attention_mask
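# Illustrative example (schematic values) of the non-causal mask returned above: for a 2D
# padding mask [[1, 1, 0]] the 4D additive mask is [[[[0., 0., min_dtype]]]], broadcast over
# every query position, so all real tokens attend to each other in both directions and only
# the padded position is blocked.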
class UnmaskingQwen3ForTokenClassification(Qwen3PreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.model = UnmaskingQwen3Model(config)
if getattr(config, "classifier_dropout", None) is not None:
classifier_dropout = config.classifier_dropout
elif getattr(config, "hidden_dropout", None) is not None:
classifier_dropout = config.hidden_dropout
else:
classifier_dropout = 0.1
self.dropout = nn.Dropout(classifier_dropout)
self.score = nn.Linear(config.hidden_size, config.num_labels)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
@can_return_tuple
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) -> TokenClassifierOutput:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
outputs: BaseModelOutputWithPast = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
sequence_output = outputs.last_hidden_state
sequence_output = self.dropout(sequence_output)
logits = self.score(sequence_output)
loss = None
if labels is not None:
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
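# Hedged sketch for the Qwen3 variant (all config values below are placeholders, not the
# shipped checkpoint's hyperparameters):
#
#     cfg = Qwen3Config(hidden_size=64, num_hidden_layers=2, num_attention_heads=4,
#                       num_key_value_heads=2, head_dim=16, intermediate_size=128,
#                       vocab_size=1000, num_labels=5)
#     clf = UnmaskingQwen3ForTokenClassification(cfg)
#     out = clf(input_ids=torch.randint(0, 1000, (1, 6)))
#     out.logits.shape                                  # torch.Size([1, 6, 5])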
class UnmaskingQwen2Model(Qwen2Model):
"""
UnmaskingQwen2Model is a modified version of Qwen2Model that removes the causal mask,
allowing bidirectional attention similar to BERT-like models.
"""
def _update_causal_mask(
self,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool = False,
):
"""
Override the causal mask creation to create a non-causal (bidirectional) mask.
This allows each token to attend to all tokens in the sequence.
"""
# For flash attention, just return None or the padding mask
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
# For flex attention, keep the same behavior but without causality
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
# We don't convert to causal mask here
return attention_mask
return attention_mask
# For other attention implementations, create a non-causal mask
batch_size = input_tensor.shape[0]
sequence_length = input_tensor.shape[1]
dtype = input_tensor.dtype
# For SlidingWindowCache or StaticCache
if isinstance(past_key_values, (SlidingWindowCache, StaticCache)):
target_length = past_key_values.get_max_cache_shape()
else:
# For DynamicCache or no cache
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_key_values.get_seq_length() + sequence_length + 1
if past_key_values is not None
else sequence_length
)
# Create a non-causal mask (all zeros, allowing full attention)
# Instead of using min_dtype to mask out future tokens, we use zeros to allow attention to all positions
non_causal_mask = torch.zeros(
(batch_size, 1, sequence_length, target_length),
dtype=dtype,
device=input_tensor.device,
)
# If there's a padding attention mask, apply it
if attention_mask is not None:
if attention_mask.dim() == 2:
# Convert 2D attention mask to 4D
expanded_mask = attention_mask[:, None, None, :].expand(
batch_size, 1, sequence_length, attention_mask.shape[-1]
).to(non_causal_mask.device)
# Apply padding mask (0 for tokens to attend to, large negative for padded positions)
min_dtype = torch.finfo(dtype).min
padding_mask = expanded_mask == 0
non_causal_mask = non_causal_mask.masked_fill(padding_mask, min_dtype)
elif attention_mask.dim() == 4:
# If already 4D, use as is
non_causal_mask = attention_mask
return non_causal_mask
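# Minimal sketch (assumed hyperparameters) showing the bidirectional Qwen2 backbone on its own:
# with no padding, the mask built above is all zeros, so the attention backend sees no causal
# constraint.
#
#     cfg = Qwen2Config(hidden_size=64, num_hidden_layers=2, num_attention_heads=4,
#                       num_key_value_heads=2, intermediate_size=128, vocab_size=1000)
#     enc = UnmaskingQwen2Model(cfg)
#     hidden = enc(input_ids=torch.randint(0, 1000, (1, 6))).last_hidden_state  # (1, 6, 64)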
class UnmaskingQwen2ForTokenClassification(Qwen2PreTrainedModel):
"""
Qwen2 model with a token classification head on top, but with bidirectional attention.
This is achieved by using the UnmaskingQwen2Model which removes the causal mask.
"""
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
# Use the UnmaskingQwen2Model instead of the standard Qwen2Model
self.model = UnmaskingQwen2Model(config)
if getattr(config, "classifier_dropout", None) is not None:
classifier_dropout = config.classifier_dropout
elif getattr(config, "hidden_dropout", None) is not None:
classifier_dropout = config.hidden_dropout
else:
classifier_dropout = 0.1
self.dropout = nn.Dropout(classifier_dropout)
self.score = nn.Linear(config.hidden_size, config.num_labels)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> TokenClassifierOutput:
"""
Forward pass for token classification with bidirectional attention.
Args:
input_ids: Input token IDs
attention_mask: Attention mask
position_ids: Position IDs
past_key_values: Past key values for efficient generation
inputs_embeds: Pre-computed input embeddings
labels: Token classification labels
use_cache: Whether to use cache for efficient generation
output_attentions: Whether to output attention weights
output_hidden_states: Whether to output hidden states
flash_attn_kwargs: Additional arguments for flash attention
Returns:
TokenClassifierOutput with loss, logits, and optional hidden states and attentions
"""
outputs: BaseModelOutputWithPast = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
**flash_attn_kwargs,
)
sequence_output = outputs.last_hidden_state
sequence_output = self.dropout(sequence_output)
logits = self.score(sequence_output)
loss = None
if labels is not None:
loss = self.loss_function(logits, labels, self.config)
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
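# End-to-end usage sketch (hedged: the repo id below is an assumption, and the checkpoint's
# auto_map is assumed to route AutoModelForTokenClassification to
# UnmaskingQwen2ForTokenClassification):
#
#     from transformers import AutoTokenizer, AutoModelForTokenClassification
#     repo = "yangwang825/piiceetah-unmask-qwen2.5-0.5b"   # assumed repo id
#     tok = AutoTokenizer.from_pretrained(repo)
#     model = AutoModelForTokenClassification.from_pretrained(repo, trust_remote_code=True)
#     enc = tok("Jane Doe flew to Paris on 3 May.", return_tensors="pt")
#     logits = model(**enc).logits                         # (batch, seq_len, num_labels)
#     pred_labels = [model.config.id2label[i] for i in logits.argmax(-1)[0].tolist()]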