# Unmasking Qwen Token Classification Models # Automatically generated file for model use with trust_remote_code=True import torch import torch.nn as nn from transformers import PreTrainedModel from transformers.modeling_outputs import TokenClassifierOutput from typing import Optional, Tuple, Union, List, Dict, Callable import torch import torch.nn as nn from transformers.modeling_outputs import TokenClassifierOutput from transformers.models.bert import ( BertConfig, BertModel, BertPreTrainedModel ) from transformers.models.roberta import ( RobertaConfig, RobertaModel, RobertaPreTrainedModel ) from transformers.models.deberta_v2 import ( DebertaV2Config, DebertaV2Model, DebertaV2PreTrainedModel ) from transformers.models.modernbert.modeling_modernbert import ( ModernBertConfig, ModernBertModel, ModernBertPreTrainedModel, ModernBertPredictionHead ) from transformers import Qwen2Config from transformers.modeling_outputs import TokenClassifierOutput, BaseModelOutputWithPast from transformers.cache_utils import Cache from transformers.modeling_flash_attention_utils import FlashAttentionKwargs from transformers.processing_utils import Unpack from transformers.models.qwen2.modeling_qwen2 import ( Qwen2PreTrainedModel, Qwen2Model, SlidingWindowCache, StaticCache ) from transformers.models.qwen3.modeling_qwen3 import ( Qwen3PreTrainedModel, Qwen3Config, Qwen3Model, Qwen3RMSNorm, Qwen3DecoderLayer, Qwen3Attention, BaseModelOutputWithPast, TokenClassifierOutput, Cache, FlashAttentionKwargs, Unpack, Qwen3RotaryEmbedding, Qwen3MLP, apply_rotary_pos_emb, can_return_tuple, eager_attention_forward ) def fixed_cross_entropy( source: torch.Tensor, target: torch.Tensor, num_items_in_batch: Optional[int] = None, ignore_index: int = -100, **kwargs, ) -> torch.Tensor: reduction = "sum" if num_items_in_batch is not None else "mean" loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction) if reduction == "sum": if not isinstance(num_items_in_batch, torch.Tensor): num_items_in_batch = torch.tensor(num_items_in_batch, device=loss.device, dtype=loss.dtype) elif num_items_in_batch.device != loss.device: num_items_in_batch = num_items_in_batch.to(loss.device) loss = loss / num_items_in_batch return loss class BertForTokenClassification(BertPreTrainedModel): def __init__(self, config: BertConfig): super().__init__(config) self.num_labels = config.num_labels self.bert = BertModel(config, add_pooling_layer=False) classifier_dropout = ( config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob ) self.dropout = nn.Dropout(classifier_dropout) self.classifier = nn.Linear(config.hidden_size, config.num_labels) # Initialize weights and apply final processing self.post_init() def forward( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, num_items_in_batch: Optional[torch.Tensor] = None, ignore_index: int = -100, **kwargs, ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs = self.bert( input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) sequence_output = outputs[0] sequence_output = self.dropout(sequence_output) logits = self.classifier(sequence_output) loss = None if labels is not None: logits = logits.view(-1, self.num_labels) labels = labels.view(-1).to(logits.device) logits = logits.float() loss = fixed_cross_entropy(logits, labels, num_items_in_batch, ignore_index) if not return_dict: output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output return TokenClassifierOutput( loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) class CRF(nn.Module): """條件隨機場(CRF)層,基於更穩定的實現""" def __init__(self, num_labels: int): super().__init__() self.num_labels = num_labels # 轉移矩陣和起始/結束轉移參數 self.start_transitions = nn.Parameter(torch.empty(num_labels)) self.end_transitions = nn.Parameter(torch.empty(num_labels)) self.transitions = nn.Parameter(torch.empty(num_labels, num_labels)) # 用均勻分布初始化參數 nn.init.uniform_(self.start_transitions, -0.1, 0.1) nn.init.uniform_(self.end_transitions, -0.1, 0.1) nn.init.uniform_(self.transitions, -0.1, 0.1) def _compute_log_denominator(self, features: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: """計算配分函數的對數(log of the partition function)""" seq_len, batch_size, _ = features.shape # 初始化得分為起始轉移得分 + 第一個時間步的特征 log_score = self.start_transitions + features[0] # [batch_size, num_labels] # 逐時間步計算得分 for i in range(1, seq_len): # 計算所有可能的轉移得分:前一時間步得分 + 轉移得分 + 當前時間步特征 # [batch_size, num_labels, 1] + [num_labels, num_labels] + [batch_size, 1, num_labels] # -> [batch_size, num_labels, num_labels] next_score = ( log_score.unsqueeze(2) + # [batch_size, num_labels, 1] self.transitions + # [num_labels, num_labels] features[i].unsqueeze(1) # [batch_size, 1, num_labels] ) # 對所有可能的前一個標籤取logsumexp next_score = torch.logsumexp(next_score, dim=1) # 根據mask更新得分 log_score = torch.where(mask[i].unsqueeze(1), next_score, log_score) # 加上到結束標籤的轉移得分 log_score += self.end_transitions # 對所有可能的最終標籤取logsumexp return torch.logsumexp(log_score, dim=1) def _compute_log_numerator(self, features: torch.Tensor, labels: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: """計算給定標籤序列的得分""" seq_len, batch_size, _ = features.shape # 初始化得分 score = self.start_transitions[labels[0]] + features[0, torch.arange(batch_size), labels[0]] # 逐時間步累加得分 for i in range(1, seq_len): # 計算轉移得分和發射得分 score += ( self.transitions[labels[i-1], labels[i]] + # 轉移得分 features[i, torch.arange(batch_size), labels[i]] # 發射得分 ) * mask[i] # 只對有效位置計算 # 計算序列長度(減去1是因為索引從0開始) seq_lens = mask.sum(dim=0) - 1 # 獲取每個序列的最後一個有效標籤 last_tags = labels[seq_lens.long(), torch.arange(batch_size)] # 加上到結束標籤的轉移得分 score += self.end_transitions[last_tags] return score def forward(self, emissions: torch.Tensor, tags: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: """ 計算CRF負對數似然損失 參數: emissions: (seq_len, batch_size, num_labels) 發射得分 tags: (seq_len, batch_size) 真實標籤 mask: (seq_len, batch_size) 用於處理變長序列的遮罩 返回: CRF負對數似然損失 """ # 計算分子(numerator)和分母(denominator)的對數 log_numerator = self._compute_log_numerator(emissions, tags, mask) log_denominator = self._compute_log_denominator(emissions, mask) # 損失是分母減分子 loss = torch.mean(log_denominator - log_numerator) return loss def _viterbi_decode(self, features: torch.Tensor, mask: torch.Tensor) -> List[List[int]]: """Viterbi算法解碼,找出最可能的標籤序列""" seq_len, batch_size, _ = features.shape # 初始化Viterbi變量 log_score = self.start_transitions + features[0] # [batch_size, num_labels] backpointers = torch.zeros((seq_len, batch_size, self.num_labels), dtype=torch.long, device=features.device) # 逐時間步填充 for i in range(1, seq_len): # 計算所有可能的轉移得分 next_score = log_score.unsqueeze(2) + self.transitions + features[i].unsqueeze(1) # 找出每個當前標籤的最佳前一個標籤 next_score, indices = next_score.max(dim=1) # 記錄回溯指針 backpointers[i] = indices # 根據mask更新得分 log_score = torch.where(mask[i].unsqueeze(1), next_score, log_score) # 加上到結束標籤的轉移得分 log_score += self.end_transitions # 找出每個序列的最後一個標籤 seq_lens = mask.sum(dim=0).long() - 1 # 序列長度 # 回溯獲取最佳路徑 best_paths = [] for seq_idx in range(batch_size): # 找出得分最高的最終標籤 best_label = torch.argmax(log_score[seq_idx]).item() best_path = [best_label] # 從後向前回溯 for i in range(seq_lens[seq_idx], 0, -1): best_label = backpointers[i, seq_idx, best_label].item() best_path.insert(0, best_label) best_paths.append(best_path) return best_paths def decode(self, emissions: torch.Tensor, mask: torch.Tensor) -> List[List[int]]: """使用Viterbi解碼找出最可能的標籤序列""" # 確保mask是bool類型 if mask.dtype != torch.bool: mask = mask.bool() with torch.no_grad(): return self._viterbi_decode(emissions, mask) class BertCRFForTokenClassification(BertPreTrainedModel): """BERT模型與CRF層結合用於token分類""" def __init__(self, config: BertConfig): super().__init__(config) self.num_labels = config.num_labels # BERT層 self.bert = BertModel(config, add_pooling_layer=False) # Dropout和分類器 classifier_dropout = ( config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob ) self.dropout = nn.Dropout(classifier_dropout) self.classifier = nn.Linear(config.hidden_size, config.num_labels) # CRF層 self.crf = CRF(config.num_labels) # 初始化權重 self.post_init() def forward( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ignore_index: int = -100, **kwargs, ) -> Union[Tuple[torch.Tensor], Dict[str, torch.Tensor]]: """ 使用CRF進行序列標注的前向傳播 參數: labels: 標籤序列,用於計算損失 ignore_index: 忽略的標籤值,通常為-100 """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict # BERT前向傳播 outputs = self.bert( input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) sequence_output = outputs[0] sequence_output = self.dropout(sequence_output) # 獲取發射分數 logits = self.classifier(sequence_output) # [batch_size, seq_len, num_labels] loss = None if labels is not None: # 準備CRF所需的輸入格式 # 交換維度:[batch_size, seq_len, num_labels] -> [seq_len, batch_size, num_labels] emissions = logits.transpose(0, 1) # 交換維度:[batch_size, seq_len] -> [seq_len, batch_size] if attention_mask is not None: attention_mask_t = attention_mask.transpose(0, 1).bool() else: attention_mask_t = torch.ones(emissions.shape[:2], dtype=torch.bool, device=emissions.device) # 處理ignore_index if ignore_index is not None: labels_mask = (labels != ignore_index) attention_mask_t = attention_mask_t & labels_mask.transpose(0, 1) # 創建一個不包含ignore_index的標籤tensor crf_labels = labels.clone() crf_labels[~labels_mask] = 0 # 將ignore的位置臨時設為0,避免其影響CRF計算 crf_labels_t = crf_labels.transpose(0, 1) else: crf_labels_t = labels.transpose(0, 1) # 計算CRF損失 loss = self.crf( emissions=emissions, tags=crf_labels_t, mask=attention_mask_t ) if not return_dict: output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output return TokenClassifierOutput( loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) def decode( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, **kwargs, ) -> List[List[int]]: """ 解碼最可能的標籤序列 """ # 不計算梯度 with torch.no_grad(): # BERT前向傳播 outputs = self.bert( input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, return_dict=True, **kwargs, ) sequence_output = outputs[0] sequence_output = self.dropout(sequence_output) # 獲取發射分數 logits = self.classifier(sequence_output) # [batch_size, seq_len, num_labels] # 交換維度:[batch_size, seq_len, num_labels] -> [seq_len, batch_size, num_labels] emissions = logits.transpose(0, 1) # 準備遮罩 if attention_mask is not None: mask = attention_mask.transpose(0, 1).bool() else: mask = torch.ones(emissions.shape[:2], dtype=torch.bool, device=emissions.device) # 使用Viterbi算法解碼 best_tags = self.crf.decode(emissions, mask) return best_tags class RobertaForTokenClassification(RobertaPreTrainedModel): def __init__(self, config: RobertaConfig): super().__init__(config) self.num_labels = config.num_labels self.roberta = RobertaModel(config, add_pooling_layer=False) classifier_dropout = ( config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob ) self.dropout = nn.Dropout(classifier_dropout) self.classifier = nn.Linear(config.hidden_size, config.num_labels) # Initialize weights and apply final processing self.post_init() def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, token_type_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, head_mask: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, num_items_in_batch: Optional[torch.Tensor] = None, ignore_index: int = -100, **kwargs, ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: r""" token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`: - 0 corresponds to a *sentence A* token, - 1 corresponds to a *sentence B* token. This parameter can only be used when the model is initialized with `type_vocab_size` parameter with value >= 2. All the value in this tensor should be always < type_vocab_size. [What are token type IDs?](../glossary#token-type-ids) labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs = self.roberta( input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) sequence_output = outputs[0] sequence_output = self.dropout(sequence_output) logits = self.classifier(sequence_output) loss = None if labels is not None: # Upcast to float if we need to compute the loss to avoid potential precision issues logits = logits.view(-1, self.num_labels) labels = labels.view(-1).to(logits.device) logits = logits.float() loss = fixed_cross_entropy(logits, labels, num_items_in_batch, ignore_index) if not return_dict: output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output return TokenClassifierOutput( loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) class DebertaV2ForTokenClassification(DebertaV2PreTrainedModel): def __init__(self, config: DebertaV2Config): super().__init__(config) self.num_labels = config.num_labels self.deberta = DebertaV2Model(config) self.dropout = nn.Dropout(config.hidden_dropout_prob) self.classifier = nn.Linear(config.hidden_size, config.num_labels) # Initialize weights and apply final processing self.post_init() def forward( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, num_items_in_batch: Optional[torch.Tensor] = None, ignore_index: int = -100, **kwargs, ) -> Union[Tuple, TokenClassifierOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs = self.deberta( input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) sequence_output = outputs[0] sequence_output = self.dropout(sequence_output) logits = self.classifier(sequence_output) loss = None if labels is not None: # Upcast to float if we need to compute the loss to avoid potential precision issues logits = logits.view(-1, self.num_labels) labels = labels.view(-1).to(logits.device) logits = logits.float() loss = fixed_cross_entropy(logits, labels, num_items_in_batch, ignore_index) if not return_dict: output = (logits,) + outputs[1:] return ((loss,) + output) if loss is not None else output return TokenClassifierOutput( loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions ) class ModernBertForTokenClassification(ModernBertPreTrainedModel): def __init__(self, config: ModernBertConfig): super().__init__(config) self.num_labels = config.num_labels self.model = ModernBertModel(config) self.head = ModernBertPredictionHead(config) self.drop = torch.nn.Dropout(config.classifier_dropout) self.classifier = nn.Linear(config.hidden_size, config.num_labels) # Initialize weights and apply final processing self.post_init() def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, sliding_window_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, indices: Optional[torch.Tensor] = None, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None, batch_size: Optional[int] = None, seq_len: Optional[int] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, num_items_in_batch: Optional[torch.Tensor] = None, ignore_index: int = -100, **kwargs, ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: r""" sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers perform global attention, while the rest perform local attention. This mask is used to avoid attending to far-away tokens in the local attention layers when not using Flash Attention. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*): Indices of the non-padding tokens in the input sequence. Used for unpadding the output. cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*): Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors. max_seqlen (`int`, *optional*): Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors. batch_size (`int`, *optional*): Batch size of the input sequences. Used to pad the output tensors. seq_len (`int`, *optional*): Sequence length of the input sequences including padding tokens. Used to pad the output tensors. """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict self._maybe_set_compile() outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, sliding_window_mask=sliding_window_mask, position_ids=position_ids, inputs_embeds=inputs_embeds, indices=indices, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, batch_size=batch_size, seq_len=seq_len, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) last_hidden_state = outputs[0] last_hidden_state = self.head(last_hidden_state) last_hidden_state = self.drop(last_hidden_state) logits = self.classifier(last_hidden_state) loss = None if labels is not None: # Upcast to float if we need to compute the loss to avoid potential precision issues logits = logits.view(-1, self.num_labels) labels = labels.view(-1).to(logits.device) logits = logits.float() loss = fixed_cross_entropy(logits, labels, num_items_in_batch, ignore_index) if not return_dict: output = (logits,) + outputs[1:] return ((loss,) + output) if loss is not None else output return TokenClassifierOutput( loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) class UnmaskingQwen3Attention(Qwen3Attention): """Multi-headed attention without causal mask for bidirectional attention""" def forward( self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) # Use eager attention as default attention_interface: Callable = eager_attention_forward # Remove causal mask by setting attention_mask to None or creating a non-causal mask # For bidirectional attention, we don't want any masking except padding if attention_mask is not None and 0.0 in attention_mask: # Keep only padding mask if it exists, remove causal part # This allows tokens to attend to future tokens pass else: # If there's no padding, we can set attention_mask to None for full attention attention_mask = None attn_output, attn_weights = attention_interface( self, query_states, key_states, value_states, attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, sliding_window=self.sliding_window, **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) return attn_output, attn_weights class UnmaskingQwen3DecoderLayer(Qwen3DecoderLayer): def __init__(self, config: Qwen3Config, layer_idx: int): super(Qwen3DecoderLayer, self).__init__() self.hidden_size = config.hidden_size self.self_attn = UnmaskingQwen3Attention(config=config, layer_idx=layer_idx) self.mlp = Qwen3MLP(config) self.input_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) class UnmaskingQwen3Model(Qwen3Model): def __init__(self, config: Qwen3Config): super(Qwen3PreTrainedModel, self).__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.layers = nn.ModuleList( [UnmaskingQwen3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = Qwen3RotaryEmbedding(config=config) self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() def _update_causal_mask( self, attention_mask: torch.Tensor, input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, output_attentions: bool = False, ): # Override the causal mask creation to create a non-causal mask # This allows bidirectional attention if attention_mask is None: # If no attention mask is provided, return None to allow full attention return None # If attention_mask is provided, it's likely for padding # Convert it to the right format but without the causal constraint dtype = input_tensor.dtype min_dtype = torch.finfo(dtype).min batch_size = input_tensor.shape[0] sequence_length = input_tensor.shape[1] if isinstance(attention_mask, torch.Tensor) and attention_mask.dim() == 2: # Convert 2D padding mask to 4D attention mask expanded_attn_mask = attention_mask[:, None, None, :] expanded_attn_mask = expanded_attn_mask.to(dtype=dtype) expanded_attn_mask = (1.0 - expanded_attn_mask) * min_dtype return expanded_attn_mask # If it's already 4D, return as is return attention_mask class UnmaskingQwen3ForTokenClassification(Qwen3PreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels self.model = UnmaskingQwen3Model(config) if getattr(config, "classifier_dropout", None) is not None: classifier_dropout = config.classifier_dropout elif getattr(config, "hidden_dropout", None) is not None: classifier_dropout = config.hidden_dropout else: classifier_dropout = 0.1 self.dropout = nn.Dropout(classifier_dropout) self.score = nn.Linear(config.hidden_size, config.num_labels) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.model.embed_tokens def set_input_embeddings(self, value): self.model.embed_tokens = value @can_return_tuple def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, ) -> TokenClassifierOutput: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ outputs: BaseModelOutputWithPast = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) sequence_output = outputs.last_hidden_state sequence_output = self.dropout(sequence_output) logits = self.score(sequence_output) loss = None if labels is not None: loss_fct = nn.CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) return TokenClassifierOutput( loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) class UnmaskingQwen2Model(Qwen2Model): """ UnmaskingQwen2Model is a modified version of Qwen2Model that removes the causal mask, allowing bidirectional attention similar to BERT-like models. """ def _update_causal_mask( self, attention_mask: torch.Tensor, input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, output_attentions: bool = False, ): """ Override the causal mask creation to create a non-causal (bidirectional) mask. This allows each token to attend to all tokens in the sequence. """ # For flash attention, just return None or the padding mask if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None # For flex attention, keep the same behavior but without causality if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): # We don't convert to causal mask here return attention_mask return attention_mask # For other attention implementations, create a non-causal mask batch_size = input_tensor.shape[0] sequence_length = input_tensor.shape[1] dtype = input_tensor.dtype # For SlidingWindowCache or StaticCache if isinstance(past_key_values, (SlidingWindowCache, StaticCache)): target_length = past_key_values.get_max_cache_shape() else: # For DynamicCache or no cache target_length = ( attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_key_values.get_seq_length() + sequence_length + 1 if past_key_values is not None else sequence_length ) # Create a non-causal mask (all zeros, allowing full attention) # Instead of using min_dtype to mask out future tokens, we use zeros to allow attention to all positions non_causal_mask = torch.zeros( (batch_size, 1, sequence_length, target_length), dtype=dtype, device=input_tensor.device, ) # If there's a padding attention mask, apply it if attention_mask is not None: if attention_mask.dim() == 2: # Convert 2D attention mask to 4D expanded_mask = attention_mask[:, None, None, :].expand( batch_size, 1, sequence_length, attention_mask.shape[-1] ).to(non_causal_mask.device) # Apply padding mask (0 for tokens to attend to, large negative for padded positions) min_dtype = torch.finfo(dtype).min padding_mask = expanded_mask == 0 non_causal_mask = non_causal_mask.masked_fill(padding_mask, min_dtype) elif attention_mask.dim() == 4: # If already 4D, use as is non_causal_mask = attention_mask return non_causal_mask class UnmaskingQwen2ForTokenClassification(Qwen2PreTrainedModel): """ Qwen2 model with a token classification head on top, but with bidirectional attention. This is achieved by using the UnmaskingQwen2Model which removes the causal mask. """ def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels # Use the UnmaskingQwen2Model instead of the standard Qwen2Model self.model = UnmaskingQwen2Model(config) if getattr(config, "classifier_dropout", None) is not None: classifier_dropout = config.classifier_dropout elif getattr(config, "hidden_dropout", None) is not None: classifier_dropout = config.hidden_dropout else: classifier_dropout = 0.1 self.dropout = nn.Dropout(classifier_dropout) self.score = nn.Linear(config.hidden_size, config.num_labels) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.model.embed_tokens def set_input_embeddings(self, value): self.model.embed_tokens = value def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> TokenClassifierOutput: """ Forward pass for token classification with bidirectional attention. Args: input_ids: Input token IDs attention_mask: Attention mask position_ids: Position IDs past_key_values: Past key values for efficient generation inputs_embeds: Pre-computed input embeddings labels: Token classification labels use_cache: Whether to use cache for efficient generation output_attentions: Whether to output attention weights output_hidden_states: Whether to output hidden states flash_attn_kwargs: Additional arguments for flash attention Returns: TokenClassifierOutput with loss, logits, and optional hidden states and attentions """ outputs: BaseModelOutputWithPast = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, **flash_attn_kwargs, ) sequence_output = outputs.last_hidden_state sequence_output = self.dropout(sequence_output) logits = self.score(sequence_output) loss = None if labels is not None: loss = self.loss_function(logits, labels, self.config) return TokenClassifierOutput( loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, )