# coding=utf-8
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Optional, Tuple

import torch
from torch import nn

try:
    import torch_npu
    from torch_npu.contrib import transfer_to_npu

    if "910" in torch.npu.get_device_name():
        NPU_ATTN_INFR = True
        print("[INFO] torch_npu detected. Using NPU fused infer attention.")
    else:
        NPU_ATTN_INFR = False
except ImportError:
    NPU_ATTN_INFR = False

from transformers.cache_utils import Cache
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.processing_utils import Unpack
from transformers.utils import logging
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaDecoderLayer,
    LlamaForCausalLM,
    LlamaForSequenceClassification,
    LlamaMLP,
    LlamaModel,
    apply_rotary_pos_emb,
    eager_attention_forward,
)

from .configuration_openpangu_dense import PanguEmbeddedConfig

logger = logging.get_logger(__name__)


class PanguEmbeddedMLP(LlamaMLP):
    def __init__(self, config):
        super().__init__(config)
        # Re-create the projections without bias, overriding the LlamaMLP defaults.
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)


class PanguEmbeddedAttention(LlamaAttention):
    def __init__(self, config: PanguEmbeddedConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        # Projections use the model-wide `bias` flag from PanguEmbeddedConfig.
        self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.bias)
        self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.bias)
        self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.bias)
        self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        if not self.training and NPU_ATTN_INFR:
            # Inference on Ascend 910 NPUs: use the fused infer attention kernel.
            q_len = input_shape[1]
            if attention_mask is not None:
                # Invert the incoming mask into the boolean layout expected by the fused kernel.
                attention_mask = ~attention_mask.bool()
            elif q_len > 1:
                # No mask supplied during prefill: build an upper-triangular causal mask.
                attention_mask = (
                    torch.triu(torch.ones([q_len, q_len]), diagonal=1)
                    .bool()
                    .unsqueeze(0)
                    .unsqueeze(0)
                    .to(query_states.device)
                )
            attn_output, _ = torch_npu.npu_fused_infer_attention_score(
                query_states,
                key_states,
                value_states,
                num_heads=self.num_heads,
                num_key_value_heads=self.num_key_value_heads,
                input_layout="BNSD",
                atten_mask=attention_mask,
                scale=self.scaling,
            )
            attn_output = attn_output.transpose(1, 2)
            attn_weights = None
        else:
            attn_output, attn_weights = attention_interface(
                self,
                query_states,
                key_states,
                value_states,
                attention_mask,
                dropout=0.0 if not self.training else self.attention_dropout,
                scaling=self.scaling,
                **kwargs,
            )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


class PanguEmbeddedDecoderLayer(LlamaDecoderLayer):
    pass


class PanguEmbeddedModel(LlamaModel):
    pass


class PanguEmbeddedForCausalLM(LlamaForCausalLM):
    pass
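

# ---------------------------------------------------------------------------
# Minimal smoke test (illustrative sketch only, not part of the released model
# code). It exercises PanguEmbeddedAttention on CPU via the eager path, using a
# types.SimpleNamespace stand-in because the constructor arguments of
# PanguEmbeddedConfig are defined elsewhere; the field names below simply
# mirror the attributes this module (and LlamaAttention) read from the config.
if __name__ == "__main__":
    import types

    cfg = types.SimpleNamespace(
        hidden_size=64,
        num_attention_heads=4,
        num_key_value_heads=2,
        attention_dropout=0.0,
        attention_bias=False,  # read by LlamaAttention.__init__
        bias=False,            # read by PanguEmbeddedAttention.__init__
        _attn_implementation="eager",
    )
    attn = PanguEmbeddedAttention(cfg, layer_idx=0).eval()

    batch, seq_len = 1, 8
    head_dim = cfg.hidden_size // cfg.num_attention_heads
    hidden_states = torch.randn(batch, seq_len, cfg.hidden_size)
    # Identity rotary embedding (cos=1, sin=0) keeps the sketch self-contained.
    position_embeddings = (
        torch.ones(batch, seq_len, head_dim),
        torch.zeros(batch, seq_len, head_dim),
    )
    with torch.no_grad():
        out, _ = attn(hidden_states, position_embeddings, attention_mask=None)
    print(out.shape)  # expected: torch.Size([1, 8, 64])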