|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Callable, Optional, Union |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
from torch import nn |
|
|
|
from transformers.activations import ACT2FN |
|
from transformers.cache_utils import Cache, DynamicCache |
|
from transformers.generation import GenerationMixin |
|
from transformers.integrations import use_kernel_forward_from_hub |
|
from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask |
|
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs |
|
from transformers.modeling_layers import ( |
|
GenericForQuestionAnswering, |
|
GenericForSequenceClassification, |
|
GenericForTokenClassification, |
|
GradientCheckpointingLayer, |
|
) |
|
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update |
|
from transformers.modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast |
|
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel |
|
from transformers.processing_utils import Unpack |
|
from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple |
|
from transformers.utils.generic import OutputRecorder, check_model_inputs |
|
from .configuration_klear import KlearConfig |
|
|
|
|
|
def rotate_half(x):
    """Rotate the last dimension of ``x`` by half its length.

    The last dimension is split into a leading part (the first ``d // 2`` entries) and a trailing part
    (the remaining entries); the result is ``concat(-trailing, leading)`` along that dimension.
    """
    split_at = x.shape[-1] // 2
    leading = x[..., :split_at]
    trailing = x[..., split_at:]
    return torch.cat((-trailing, leading), dim=-1)
|
|
|
|
|
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply rotary position embedding (RoPE) to the query and key tensors.

    Each tensor ``t`` in ``(q, k)`` is transformed as ``t * cos + rotate_half(t) * sin``, after a singleton
    axis has been inserted into ``cos`` and ``sin`` at ``unsqueeze_dim``.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and ignored; kept only so existing call sites keep working.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Axis at which ``cos`` and ``sin`` receive a new singleton dimension so they broadcast against
            ``q`` and ``k``. With ``cos``/``sin`` of shape ``[batch_size, seq_len, head_dim]``: use ``1`` when
            ``q``/``k`` are ``[batch_size, heads, seq_len, head_dim]``, and ``2`` when they are
            ``[batch_size, seq_len, heads, head_dim]``.

    Returns:
        `tuple(torch.Tensor)`: ``(q_embed, k_embed)``, the rotated query and key tensors.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    # Same formula for both tensors; the query is processed first, then the key.
    q_embed, k_embed = ((t * cos) + (rotate_half(t) * sin) for t in (q, k))
    return q_embed, k_embed
|
|
|
|
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat every key/value head ``n_rep`` times along dimension 1.

    Equivalent to ``torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)``. It maps
    ``(batch, num_key_value_heads, seqlen, head_dim)`` to ``(batch, num_key_value_heads * n_rep, seqlen, head_dim)``.
    The input must be 4-D (the shape is unpacked unconditionally). When ``n_rep == 1``, the very same tensor
    object is returned.
    """
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis right after the head axis, broadcast it, then fold it into the head axis.
    expanded = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
    return expanded.reshape(bsz, kv_heads * n_rep, seq_len, dim)
|
|
|
|
|
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """Plain-PyTorch (non-fused) scaled dot-product attention.

    1. Key and value heads are expanded with `repeat_kv` by ``module.num_key_value_groups``.
    2. Scores are ``query @ key^T * scaling``.
    3. ``attention_mask`` is added to the scores, truncated to the key length.
    4. The scores are softmaxed in float32 and cast back to ``query.dtype``.
    5. Dropout is applied, active only while ``module.training`` is true.

    Extra keyword arguments (for example ``sliding_window``) are accepted and ignored, so this function can be
    called with the same arguments as the other attention backends.

    Returns:
        ``(attn_output, attn_weights)``. ``attn_output`` is ``attn_weights @ value`` with dims 1 and 2
        swapped, made contiguous.
    """
    n_groups = module.num_key_value_groups
    keys = repeat_kv(key, n_groups)
    values = repeat_kv(value, n_groups)

    scores = torch.matmul(query, keys.transpose(2, 3)) * scaling
    if attention_mask is not None:
        # The additive mask may cover more key positions than present; only the first len(keys) columns apply.
        scores = scores + attention_mask[:, :, :, : keys.shape[-2]]

    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)
    attn_output = torch.matmul(probs, values).transpose(1, 2).contiguous()
    return attn_output, probs
|
|
|
|
|
class KlearAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper, with grouped key/value heads.

    Queries and keys are RMS-normalized per head (``q_norm`` / ``k_norm``) before the rotary embedding is
    applied. The attention kernel depends on ``config._attn_implementation``:

    - ``"eager"`` uses `eager_attention_forward`.
    - Any other value uses the matching entry of ``ALL_ATTENTION_FUNCTIONS``.
    """

    def __init__(self, config: KlearConfig, layer_idx: int):
        super().__init__()
        self.config = config
        # Index of this layer; passed to `past_key_value.update` to select this layer's cache slot.
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        # Number of query heads that share one key/value head; read by `repeat_kv` in the eager path.
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )
        self.q_norm = KlearRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = KlearRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        # Forwarded to the attention backend; None when the config defines no `sliding_window`.
        self.sliding_window = getattr(config, "sliding_window", None)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run self-attention over ``hidden_states``.

        Args:
            hidden_states: input activations. These are presumably ``(batch, seq_len, hidden_size)``, since the
                per-head ``view`` followed by ``transpose(1, 2)`` below assumes two leading dimensions.
            position_embeddings: ``(cos, sin)`` rotary tables applied to queries and keys.
            attention_mask: mask passed unchanged to the attention backend.
            past_key_value: optional cache. When given, the new keys/values are passed to ``update`` together
                with ``sin``, ``cos`` and ``cache_position``, and the keys/values it returns are attended over.
            cache_position: forwarded to the cache update.
            **kwargs: forwarded to the attention backend.

        Returns:
            ``(attn_output, attn_weights)``. ``attn_output`` has the input's leading shape and last dimension
            ``hidden_size`` (after ``o_proj``). ``attn_weights`` is whatever the selected backend returns;
            whether it is ``None`` for a given backend is not visible here.
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # Split into heads, normalize q/k per head, then move heads to dim 1.
        query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            **kwargs,
        )

        # Merge heads back into the hidden dimension and project out.
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
|
|
|
|
|
class KlearMLP(nn.Module):
    """Gated feed-forward block computing ``down_proj(act_fn(gate_proj(x)) * up_proj(x))``.

    All three projections are bias-free. ``intermediate_size`` overrides ``config.intermediate_size`` when
    it is not None, and the activation is looked up from ``config.hidden_act``.
    """

    def __init__(self, config, intermediate_size=None):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        if intermediate_size is None:
            intermediate_size = config.intermediate_size
        self.intermediate_size = intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
|
|
|
|
|
class KlearSparseMoeBlock(nn.Module):
    """Mixture-of-experts feed-forward block with sigmoid routing and a gated shared expert.

    Routing:
    - Router logits are computed in float32 and passed through a sigmoid.
    - Experts are *selected* (top-k) on the sigmoid scores plus the ``expert_bias`` buffer.
    - The weights applied to expert outputs are the *unbiased* sigmoid scores of the selected experts.
      These are renormalized to sum to 1 per token when ``norm_topk_prob`` is set.

    Output: the routed mixture and the shared expert's output are blended per token with a two-way softmax
    produced by ``self.coefficient``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_experts = config.num_experts
        self.top_k = config.num_experts_per_tok
        self.norm_topk_prob = config.norm_topk_prob

        # Router: one logit per expert, no bias.
        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
        self.experts = nn.ModuleList(
            KlearMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(config.num_experts)
        )
        # Always-on shared expert, n_shared_experts times as wide as a routed expert.
        self.shared_experts = KlearMLP(
            config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts
        )
        # Per-token mixing logits: index 0 weights the routed mixture, index 1 the shared expert.
        self.coefficient = nn.Linear(config.hidden_size, 2)
        # Selection-only bias (a buffer, not a Parameter). How it is updated, if at all, is not visible here.
        self.register_buffer("expert_bias", torch.zeros(self.num_experts, dtype=torch.float32))

    def forward(self, hidden_states):
        """Route each token to its top-k experts and blend with the shared expert.

        Args:
            hidden_states: tensor of shape ``(batch, seq_len, hidden_dim)``.

        Returns:
            ``(output, router_logits)``. ``output`` has the input's shape. ``router_logits`` are the float32,
            pre-sigmoid router logits of shape ``(batch * seq_len, num_experts)``.
        """
        inputs = hidden_states
        batch_size, seq_len, hidden_dim = inputs.shape
        tokens = inputs.view(-1, hidden_dim)

        router_logits = nn.functional.linear(tokens.to(torch.float32), self.gate.weight.to(torch.float32))
        scores = F.sigmoid(router_logits)

        # The bias only influences which experts are picked; the mixing weights use the raw scores.
        _, selected_experts = torch.topk(scores + self.expert_bias.unsqueeze(0), self.top_k, dim=-1)
        topk_weights = scores.gather(-1, selected_experts)
        if self.norm_topk_prob:
            topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
        topk_weights = topk_weights.to(tokens.dtype)

        routed = torch.zeros((batch_size * seq_len, hidden_dim), dtype=tokens.dtype, device=tokens.device)

        # expert_mask[e, slot, tok] == 1 iff token `tok` chose expert `e` in top-k position `slot`.
        expert_mask = nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
        # Only visit experts that received at least one token.
        active_experts = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
        for expert_idx in active_experts:
            slot_idx, token_idx = torch.where(expert_mask[expert_idx].squeeze(0))
            expert_out = self.experts[expert_idx](tokens[None, token_idx].reshape(-1, hidden_dim))
            expert_out = expert_out * topk_weights[token_idx, slot_idx, None]
            routed.index_add_(0, token_idx, expert_out.to(tokens.dtype))

        routed = routed.reshape(batch_size, seq_len, hidden_dim)
        mix = self.coefficient(inputs).softmax(dim=-1)
        output = routed * mix[..., :1] + self.shared_experts(inputs) * mix[..., 1:]
        return output, router_logits
|
|
|
|
|
@use_kernel_forward_from_hub("RMSNorm")
class KlearRMSNorm(nn.Module):
    """Root-mean-square layer norm, equivalent to T5LayerNorm: no mean subtraction and no bias.

    The normalization is computed in float32 and cast back to the input dtype before the learned
    per-channel ``weight`` (initialized to ones) is applied.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        x = hidden_states.to(torch.float32)
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.variance_epsilon)
        return self.weight * (x * inv_rms).to(orig_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
|
|
|
|
class KlearDecoderLayer(GradientCheckpointingLayer):
    """One decoder layer: pre-norm self-attention, then a pre-norm feed-forward block, each with a residual.

    The feed-forward block is a `KlearSparseMoeBlock` when all of the following hold:
    - ``layer_idx`` is not in ``config.mlp_only_layers``;
    - ``config.num_experts > 0``;
    - ``layer_idx + 1`` is a multiple of ``config.decoder_sparse_step``.

    Otherwise it is a dense `KlearMLP`.
    """

    def __init__(self, config: KlearConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = KlearAttention(config, layer_idx)

        # Evaluation order matters: the modulo is only reached when num_experts > 0.
        use_moe = (
            layer_idx not in config.mlp_only_layers
            and config.num_experts > 0
            and (layer_idx + 1) % config.decoder_sparse_step == 0
        )
        if use_moe:
            self.mlp = KlearSparseMoeBlock(config)
        else:
            self.mlp = KlearMLP(config, intermediate_size=config.intermediate_size)

        self.input_layernorm = KlearRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = KlearRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        output_router_logits: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Apply the attention and feed-forward sub-blocks.

        Args:
            hidden_states (`torch.FloatTensor`): layer input, presumably `(batch, seq_len, hidden_size)`.
            attention_mask (*optional*): mask prepared by the caller; forwarded unchanged to `self_attn`.
            position_ids (`torch.LongTensor`, *optional*): forwarded to `self_attn` as a keyword argument.
            past_key_value (*optional*): cache object forwarded to `self_attn`.
            output_attentions (`bool`, *optional*): if truthy, append the attention weights to the result.
            output_router_logits (`bool`, *optional*): if truthy, append the router logits to the result
                (`None` when this layer uses a dense MLP).
            use_cache (`bool`, *optional*): forwarded to `self_attn` as a keyword argument.
            cache_position (`torch.LongTensor`, *optional*): forwarded to `self_attn`.
            position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): rotary
                `(cos, sin)` tables forwarded to `self_attn`.
            kwargs: additional keyword arguments forwarded to `self_attn`.

        Returns:
            A tuple that starts with the new hidden states. It then contains the attention weights if
            `output_attentions` is truthy, and finally the router logits if `output_router_logits` is truthy.
        """
        # --- self-attention sub-block ---
        attn_residual = hidden_states
        attn_out, self_attn_weights = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = attn_residual + attn_out

        # --- feed-forward sub-block ---
        ffn_residual = hidden_states
        ffn_out = self.mlp(self.post_attention_layernorm(hidden_states))
        router_logits = None
        # The MoE block returns (output, router_logits); the dense MLP returns a bare tensor.
        if isinstance(ffn_out, tuple):
            ffn_out, router_logits = ffn_out
        hidden_states = ffn_residual + ffn_out

        extras = []
        if output_attentions:
            extras.append(self_attn_weights)
        if output_router_logits:
            extras.append(router_logits)
        return (hidden_states, *extras)
|
|
|
|
|
class KlearRotaryEmbedding(nn.Module):
    """Produces rotary position embedding ``(cos, sin)`` tables for the requested positions.

    The frequency initializer is taken from ``ROPE_INIT_FUNCTIONS``. The rope type is chosen as follows:
    - If ``config.rope_scaling`` is a dict, use its ``"rope_type"`` key, falling back to the legacy ``"type"`` key.
    - Otherwise, use ``"default"``.
    """

    def __init__(self, config: KlearConfig, device=None):
        super().__init__()

        rope_scaling = getattr(config, "rope_scaling", None)
        if isinstance(rope_scaling, dict):
            self.rope_type = rope_scaling.get("rope_type", rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        # Non-persistent: recomputed at construction rather than stored in checkpoints.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update
    def forward(self, x, position_ids):
        """Return ``(cos, sin)`` for ``position_ids``, cast to ``x.dtype``.

        ``position_ids`` is indexed as ``(batch, seq_len)``. Both outputs have shape
        ``(batch, seq_len, 2 * len(inv_freq))`` and are scaled by ``attention_scaling``. ``x`` is used only
        for its device and dtype.
        """
        batch = position_ids.shape[0]
        freq_col = self.inv_freq[None, :, None].float().expand(batch, -1, 1).to(x.device)
        pos_row = position_ids[:, None, :].float()

        # Autocast is disabled so the angle computation stays in float32; "mps" (and any non-string
        # device type) is mapped to "cpu" for the autocast context.
        dev_type = x.device.type
        if not isinstance(dev_type, str) or dev_type == "mps":
            dev_type = "cpu"
        with torch.autocast(device_type=dev_type, enabled=False):
            # (batch, d/2, 1) @ (batch, 1, seq) -> (batch, d/2, seq) -> (batch, seq, d/2)
            angles = (freq_col.float() @ pos_row.float()).transpose(1, 2)
            emb = torch.cat((angles, angles), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
|
|
|
|
@auto_docstring
class KlearPreTrainedModel(PreTrainedModel):
    # Config class and the attribute name under which derived models store the base model.
    config: KlearConfig
    base_model_prefix = "model"

    # Module classes / input keys that receive special handling from the `PreTrainedModel` machinery.
    _no_split_modules = ["KlearDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]

    # Capability flags.
    supports_gradient_checkpointing = True
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_attention_backend = True
    _can_compile_fullgraph = False

    # Intermediate outputs that can be recorded from submodules. Router logits are element 1 of the
    # `(output, router_logits)` tuple returned by `KlearSparseMoeBlock.forward`.
    _can_record_outputs = {
        "router_logits": OutputRecorder(KlearSparseMoeBlock, index=1),
        "hidden_states": KlearDecoderLayer,
        "attentions": KlearAttention,
    }
|
|
|
|
|
@auto_docstring
class KlearModel(KlearPreTrainedModel):
    # Decoder-only backbone: token embeddings -> stack of `KlearDecoderLayer`s -> final RMSNorm.
    # Rotary (cos, sin) tables are computed once per forward pass and shared by every layer.

    def __init__(self, config: KlearConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [KlearDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = KlearRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = KlearRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Base-class post-construction hook (see `PreTrainedModel.post_init`).
        self.post_init()

    @check_model_inputs
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> MoeModelOutputWithPast:
        # Exactly one of `input_ids` / `inputs_embeds` must be provided.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        # Unset flags fall back to the config defaults.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Positions of the new tokens continue after the number of tokens already in the cache.
        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # `getattr` with a None default matches `KlearAttention`, which also treats a config without a
        # `sliding_window` attribute as having no sliding window (plain attribute access would raise here).
        sliding_window = getattr(self.config, "sliding_window", None)
        mask_function = create_causal_mask if sliding_window is None else create_sliding_window_causal_mask
        causal_mask = mask_function(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        hidden_states = inputs_embeds
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_router_logits = () if output_router_logits else None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                output_router_logits=output_router_logits,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )
            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

            # Dense (non-MoE) layers report `None` router logits; skip them so the tuple holds only tensors.
            if output_router_logits and layer_outputs[-1] is not None:
                all_router_logits += (layer_outputs[-1],)

        hidden_states = self.norm(hidden_states)

        # Include the final (normalized) output too: embeddings + one entry per layer = num_hidden_layers + 1.
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        # NOTE(review): `check_model_inputs` may replace these collected outputs with hook-recorded ones
        # (see `_can_record_outputs`); confirm which takes precedence in the installed transformers version.
        return MoeModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            router_logits=all_router_logits,
        )
|
|
|
|
|
@auto_docstring
class KlearMoeForCausalLM(KlearPreTrainedModel, GenerationMixin):
    # Causal-LM head (bias-free `lm_head` over the vocabulary) on top of `KlearModel`.
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.model = KlearModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # MoE settings kept as attributes. Neither auxiliary-loss coefficient is used by `forward` below.
        self.router_aux_loss_coef = config.router_aux_loss_coef
        self.num_experts = config.num_experts
        self.num_experts_per_tok = config.num_experts_per_tok
        self.moe_aux_loss_coeff = getattr(config, "moe_aux_loss_coeff", 1.0)

        # Base-class post-construction hook (see `PreTrainedModel.post_init`).
        self.post_init()

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[list[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> MoeCausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, KlearMoeForCausalLM

        >>> model = KlearMoeForCausalLM.from_pretrained("Klear-kwaii/Klear-MoE")
        >>> tokenizer = AutoTokenizer.from_pretrained("Klear-kwaii/Klear-MoE")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        # Unset flags fall back to the config defaults.
        config = self.config
        if output_attentions is None:
            output_attentions = config.output_attentions
        if output_router_logits is None:
            output_router_logits = config.output_router_logits
        if output_hidden_states is None:
            output_hidden_states = config.output_hidden_states

        outputs: MoeModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_router_logits=output_router_logits,
            cache_position=cache_position,
            **kwargs,
        )

        # An int keeps only the last `logits_to_keep` positions (0 keeps all); a tensor is used as an index.
        if isinstance(logits_to_keep, int):
            keep = slice(-logits_to_keep, None)
        else:
            keep = logits_to_keep
        logits = self.lm_head(outputs.last_hidden_state[:, keep, :])

        loss = None if labels is None else self.loss_function(logits, labels, self.vocab_size, **kwargs)

        # No auxiliary (load-balancing) loss is computed here, so `aux_loss` is always None, even when router
        # logits are requested. Router logits are still passed through for callers that want them.
        return MoeCausalLMOutputWithPast(
            loss=loss,
            aux_loss=None,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            router_logits=outputs.router_logits,
        )
|
|
|
|
|
class KlearForSequenceClassification(GenericForSequenceClassification, KlearPreTrainedModel):
    """Klear sequence-classification model.

    All behavior comes from `GenericForSequenceClassification`; `KlearPreTrainedModel` supplies the Klear
    class-level settings.
    """
|
|
|
|
|
class KlearForTokenClassification(GenericForTokenClassification, KlearPreTrainedModel):
    """Klear token-classification model.

    All behavior comes from `GenericForTokenClassification`; `KlearPreTrainedModel` supplies the Klear
    class-level settings.
    """
|
|
|
|
|
class KlearForQuestionAnswering(GenericForQuestionAnswering, KlearPreTrainedModel):
    """Klear extractive question-answering model; behavior comes from `GenericForQuestionAnswering`."""

    # Overrides the "model" prefix set on `KlearPreTrainedModel`. Presumably this is for checkpoint-name
    # compatibility; confirm before changing.
    base_model_prefix = "transformer"
|
|
|
|
|
# Public API of this module (order preserved).
__all__ = [
    "KlearMoeForCausalLM",
    "KlearForQuestionAnswering",
    "KlearModel",
    "KlearPreTrainedModel",
    "KlearForSequenceClassification",
    "KlearForTokenClassification",
]
|
|