fix code
Browse files- configuration_minimax_m1.py +14 -14
- modeling_minimax_m1.py +57 -57
configuration_minimax_m1.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
"""
|
| 2 |
|
| 3 |
from transformers.configuration_utils import PretrainedConfig
|
| 4 |
from transformers.utils import logging
|
|
@@ -7,11 +7,11 @@ from transformers.utils import logging
|
|
| 7 |
logger = logging.get_logger(__name__)
|
| 8 |
|
| 9 |
|
| 10 |
-
class
|
| 11 |
r"""
|
| 12 |
-
This is the configuration class to store the configuration of a [`
|
| 13 |
-
|
| 14 |
-
with the defaults will yield a similar configuration to that of the
|
| 15 |
|
| 16 |
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 17 |
documentation from [`PretrainedConfig`] for more information.
|
|
@@ -19,8 +19,8 @@ class MiniMaxText01Config(PretrainedConfig):
|
|
| 19 |
|
| 20 |
Args:
|
| 21 |
vocab_size (`int`, *optional*, defaults to 32000):
|
| 22 |
-
Vocabulary size of the
|
| 23 |
-
`inputs_ids` passed when calling [`
|
| 24 |
hidden_size (`int`, *optional*, defaults to 4096):
|
| 25 |
Dimension of the hidden representations.
|
| 26 |
intermediate_size (`int`, *optional*, defaults to 14336):
|
|
@@ -39,7 +39,7 @@ class MiniMaxText01Config(PretrainedConfig):
|
|
| 39 |
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
| 40 |
The non-linear activation function (function or string) in the decoder.
|
| 41 |
max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
|
| 42 |
-
The maximum sequence length that this model might ever be used with.
|
| 43 |
allows sequence of up to 4096*32 tokens.
|
| 44 |
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 45 |
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
|
@@ -76,19 +76,19 @@ class MiniMaxText01Config(PretrainedConfig):
|
|
| 76 |
Amount of noise to add to the router.
|
| 77 |
|
| 78 |
```python
|
| 79 |
-
>>> from transformers import
|
| 80 |
|
| 81 |
-
>>> # Initializing a
|
| 82 |
-
>>> configuration =
|
| 83 |
|
| 84 |
-
>>> # Initializing a model from the
|
| 85 |
-
>>> model =
|
| 86 |
|
| 87 |
>>> # Accessing the model configuration
|
| 88 |
>>> configuration = model.config
|
| 89 |
```"""
|
| 90 |
|
| 91 |
-
model_type = "
|
| 92 |
keys_to_ignore_at_inference = ["past_key_values"]
|
| 93 |
|
| 94 |
def __init__(
|
|
|
|
| 1 |
+
""" MiniMaxM1 model configuration"""
|
| 2 |
|
| 3 |
from transformers.configuration_utils import PretrainedConfig
|
| 4 |
from transformers.utils import logging
|
|
|
|
| 7 |
logger = logging.get_logger(__name__)
|
| 8 |
|
| 9 |
|
| 10 |
+
class MiniMaxM1Config(PretrainedConfig):
|
| 11 |
r"""
|
| 12 |
+
This is the configuration class to store the configuration of a [`MiniMaxM1Model`]. It is used to instantiate an
|
| 13 |
+
MiniMaxM1 model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
| 14 |
+
with the defaults will yield a similar configuration to that of the MiniMaxM1.
|
| 15 |
|
| 16 |
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 17 |
documentation from [`PretrainedConfig`] for more information.
|
|
|
|
| 19 |
|
| 20 |
Args:
|
| 21 |
vocab_size (`int`, *optional*, defaults to 32000):
|
| 22 |
+
Vocabulary size of the MiniMaxM1 model. Defines the number of different tokens that can be represented by the
|
| 23 |
+
`inputs_ids` passed when calling [`MiniMaxM1Model`]
|
| 24 |
hidden_size (`int`, *optional*, defaults to 4096):
|
| 25 |
Dimension of the hidden representations.
|
| 26 |
intermediate_size (`int`, *optional*, defaults to 14336):
|
|
|
|
| 39 |
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
| 40 |
The non-linear activation function (function or string) in the decoder.
|
| 41 |
max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
|
| 42 |
+
The maximum sequence length that this model might ever be used with. MiniMaxM1's sliding window attention
|
| 43 |
allows sequence of up to 4096*32 tokens.
|
| 44 |
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 45 |
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
|
|
|
| 76 |
Amount of noise to add to the router.
|
| 77 |
|
| 78 |
```python
|
| 79 |
+
>>> from transformers import MiniMaxM1Model, MiniMaxM1Config
|
| 80 |
|
| 81 |
+
>>> # Initializing a MiniMaxM1 style configuration
|
| 82 |
+
>>> configuration = MiniMaxM1Config()
|
| 83 |
|
| 84 |
+
>>> # Initializing a model from the MiniMaxM1 style configuration
|
| 85 |
+
>>> model = MiniMaxM1Model(configuration)
|
| 86 |
|
| 87 |
>>> # Accessing the model configuration
|
| 88 |
>>> configuration = model.config
|
| 89 |
```"""
|
| 90 |
|
| 91 |
+
model_type = "MiniMaxM1"
|
| 92 |
keys_to_ignore_at_inference = ["past_key_values"]
|
| 93 |
|
| 94 |
def __init__(
|
modeling_minimax_m1.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
""" PyTorch
|
| 2 |
import inspect
|
| 3 |
import math
|
| 4 |
import warnings
|
|
@@ -31,7 +31,7 @@ from transformers.utils import (
|
|
| 31 |
replace_return_docstrings,
|
| 32 |
)
|
| 33 |
from transformers.utils.import_utils import is_torch_fx_available
|
| 34 |
-
from .configuration_minimax_m1 import
|
| 35 |
|
| 36 |
if is_flash_attn_2_available():
|
| 37 |
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
|
@@ -52,7 +52,7 @@ BLOCK = 256
|
|
| 52 |
|
| 53 |
logger = logging.get_logger(__name__)
|
| 54 |
|
| 55 |
-
_CONFIG_FOR_DOC = "
|
| 56 |
|
| 57 |
|
| 58 |
def get_activation_fn(activation):
|
|
@@ -207,8 +207,8 @@ class GLU(nn.Module):
|
|
| 207 |
return output
|
| 208 |
|
| 209 |
|
| 210 |
-
class
|
| 211 |
-
def __init__(self, config:
|
| 212 |
super().__init__()
|
| 213 |
bias = False
|
| 214 |
self.hidden_size = config.hidden_size
|
|
@@ -217,7 +217,7 @@ class MiniMaxText01LightningAttention(nn.Module):
|
|
| 217 |
|
| 218 |
self.out_proj = nn.Linear(self.head_dim * self.num_heads, self.hidden_size, bias=bias)
|
| 219 |
self.act = get_activation_fn(config.hidden_act)
|
| 220 |
-
self.norm =
|
| 221 |
|
| 222 |
self.qkv_proj = nn.Linear(self.hidden_size, 3 * self.head_dim * self.num_heads, bias=bias)
|
| 223 |
self.output_gate = nn.Linear(self.hidden_size, self.head_dim * self.num_heads, bias=bias)
|
|
@@ -338,11 +338,11 @@ class MiniMaxText01LightningAttention(nn.Module):
|
|
| 338 |
return output, attn_weights, kv
|
| 339 |
|
| 340 |
|
| 341 |
-
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->
|
| 342 |
-
class
|
| 343 |
def __init__(self, hidden_size, eps=1e-6):
|
| 344 |
"""
|
| 345 |
-
|
| 346 |
"""
|
| 347 |
super().__init__()
|
| 348 |
self.weight = nn.Parameter(torch.ones(hidden_size))
|
|
@@ -356,8 +356,8 @@ class MiniMaxText01RMSNorm(nn.Module):
|
|
| 356 |
return self.weight * hidden_states.to(input_dtype)
|
| 357 |
|
| 358 |
|
| 359 |
-
# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->
|
| 360 |
-
class
|
| 361 |
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
| 362 |
super().__init__()
|
| 363 |
|
|
@@ -447,14 +447,14 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
|
| 447 |
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
| 448 |
|
| 449 |
|
| 450 |
-
# Copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->
|
| 451 |
-
class
|
| 452 |
"""
|
| 453 |
Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
|
| 454 |
and "Generating Long Sequences with Sparse Transformers".
|
| 455 |
"""
|
| 456 |
|
| 457 |
-
def __init__(self, config:
|
| 458 |
super().__init__()
|
| 459 |
self.config = config
|
| 460 |
self.layer_idx = layer_idx
|
|
@@ -481,7 +481,7 @@ class MiniMaxText01Attention(nn.Module):
|
|
| 481 |
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
| 482 |
self.rotary_dim = getattr(config, 'rotary_dim', self.head_dim)
|
| 483 |
|
| 484 |
-
self.rotary_emb =
|
| 485 |
self.rotary_dim,
|
| 486 |
max_position_embeddings=self.max_position_embeddings,
|
| 487 |
base=self.rope_theta,
|
|
@@ -572,10 +572,10 @@ class MiniMaxText01Attention(nn.Module):
|
|
| 572 |
return attn_output, attn_weights, past_key_value
|
| 573 |
|
| 574 |
|
| 575 |
-
# Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->
|
| 576 |
-
class
|
| 577 |
"""
|
| 578 |
-
|
| 579 |
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
|
| 580 |
flash attention and deal with padding tokens in case the input contains any of them.
|
| 581 |
"""
|
|
@@ -836,7 +836,7 @@ class MiniMaxText01FlashAttention2(MiniMaxText01Attention):
|
|
| 836 |
)
|
| 837 |
|
| 838 |
|
| 839 |
-
class
|
| 840 |
def __init__(self, config):
|
| 841 |
super().__init__()
|
| 842 |
self.config = config
|
|
@@ -852,8 +852,8 @@ class MiniMaxText01MLP(nn.Module):
|
|
| 852 |
return down_proj
|
| 853 |
|
| 854 |
|
| 855 |
-
class
|
| 856 |
-
def __init__(self, config:
|
| 857 |
super().__init__()
|
| 858 |
self.ffn_dim = config.intermediate_size
|
| 859 |
self.hidden_dim = config.hidden_size
|
|
@@ -870,15 +870,15 @@ class MiniMaxText01BlockSparseTop2MLP(nn.Module):
|
|
| 870 |
return current_hidden_states
|
| 871 |
|
| 872 |
|
| 873 |
-
class
|
| 874 |
def __init__(self, *args, **kwargs):
|
| 875 |
logger.warning_once(
|
| 876 |
-
"
|
| 877 |
)
|
| 878 |
super().__init__(*args, **kwargs)
|
| 879 |
|
| 880 |
|
| 881 |
-
class
|
| 882 |
"""
|
| 883 |
This implementation is
|
| 884 |
strictly equivalent to standard MoE with full capacity (no
|
|
@@ -900,7 +900,7 @@ class MiniMaxText01SparseMoeBlock(nn.Module):
|
|
| 900 |
# gating
|
| 901 |
self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
|
| 902 |
|
| 903 |
-
self.experts = nn.ModuleList([
|
| 904 |
|
| 905 |
# Jitter parameters
|
| 906 |
self.jitter_noise = config.router_jitter_noise
|
|
@@ -946,8 +946,8 @@ class MiniMaxText01SparseMoeBlock(nn.Module):
|
|
| 946 |
return final_hidden_states, router_logits
|
| 947 |
|
| 948 |
|
| 949 |
-
class
|
| 950 |
-
def __init__(self, config:
|
| 951 |
super().__init__()
|
| 952 |
self.config = config
|
| 953 |
self.hidden_size = config.hidden_size
|
|
@@ -956,9 +956,9 @@ class MiniMaxText01DecoderLayer(nn.Module):
|
|
| 956 |
|
| 957 |
self.layer_idx = layer_idx
|
| 958 |
|
| 959 |
-
self.block_sparse_moe =
|
| 960 |
-
self.input_layernorm =
|
| 961 |
-
self.post_attention_layernorm =
|
| 962 |
|
| 963 |
self.postnorm = getattr(config, 'postnorm', False)
|
| 964 |
self.layernorm_attention_alpha = getattr(config, 'layernorm_linear_attention_alpha', 1) \
|
|
@@ -972,14 +972,14 @@ class MiniMaxText01DecoderLayer(nn.Module):
|
|
| 972 |
self.shared_moe = False
|
| 973 |
if shared_intermediate > 0:
|
| 974 |
self.shared_moe = True
|
| 975 |
-
self.shared_mlp =
|
| 976 |
self.coefficient = torch.nn.Linear(self.hidden_size, 1, bias=False)
|
| 977 |
|
| 978 |
def build_attn(self, config, layer_idx):
|
| 979 |
if config.attention_type == 0:
|
| 980 |
-
Attention_module =
|
| 981 |
else:
|
| 982 |
-
Attention_module =
|
| 983 |
|
| 984 |
return Attention_module(
|
| 985 |
config,
|
|
@@ -1081,7 +1081,7 @@ MIXTRAL_START_DOCSTRING = r"""
|
|
| 1081 |
and behavior.
|
| 1082 |
|
| 1083 |
Parameters:
|
| 1084 |
-
config ([`
|
| 1085 |
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
| 1086 |
load the weights associated with the model, only the configuration. Check out the
|
| 1087 |
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
|
@@ -1089,15 +1089,15 @@ MIXTRAL_START_DOCSTRING = r"""
|
|
| 1089 |
|
| 1090 |
|
| 1091 |
@add_start_docstrings(
|
| 1092 |
-
"The bare
|
| 1093 |
MIXTRAL_START_DOCSTRING,
|
| 1094 |
)
|
| 1095 |
-
# Copied from transformers.models.mistral.modeling_mistral.MistralPreTrainedModel with Mistral->
|
| 1096 |
-
class
|
| 1097 |
-
config_class =
|
| 1098 |
base_model_prefix = "model"
|
| 1099 |
supports_gradient_checkpointing = True
|
| 1100 |
-
_no_split_modules = ["
|
| 1101 |
_skip_keys_device_placement = "past_key_values"
|
| 1102 |
_supports_flash_attn_2 = True
|
| 1103 |
_supports_sdpa = True
|
|
@@ -1182,19 +1182,19 @@ MIXTRAL_INPUTS_DOCSTRING = r"""
|
|
| 1182 |
|
| 1183 |
|
| 1184 |
@add_start_docstrings(
|
| 1185 |
-
"The bare
|
| 1186 |
MIXTRAL_START_DOCSTRING,
|
| 1187 |
)
|
| 1188 |
-
# Copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->MIXTRAL,Mistral->
|
| 1189 |
-
class
|
| 1190 |
"""
|
| 1191 |
-
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`
|
| 1192 |
|
| 1193 |
Args:
|
| 1194 |
-
config:
|
| 1195 |
"""
|
| 1196 |
|
| 1197 |
-
def __init__(self, config:
|
| 1198 |
super().__init__(config)
|
| 1199 |
self.padding_idx = config.pad_token_id
|
| 1200 |
self.vocab_size = config.vocab_size
|
|
@@ -1212,10 +1212,10 @@ class MiniMaxText01Model(MiniMaxText01PreTrainedModel):
|
|
| 1212 |
else:
|
| 1213 |
_config._attn_implementation = config_copy._attn_implementation
|
| 1214 |
_config.attention_type = 1
|
| 1215 |
-
self.layers.append(
|
| 1216 |
|
| 1217 |
self._attn_implementation = config_copy._attn_implementation
|
| 1218 |
-
self.norm =
|
| 1219 |
|
| 1220 |
self.gradient_checkpointing = False
|
| 1221 |
self.slopes = self._build_slope_tensor(config.num_attention_heads)
|
|
@@ -1327,7 +1327,7 @@ class MiniMaxText01Model(MiniMaxText01PreTrainedModel):
|
|
| 1327 |
if is_padding_right:
|
| 1328 |
raise ValueError(
|
| 1329 |
"You are attempting to perform batched generation with padding_side='right'"
|
| 1330 |
-
" this may lead to unexpected behaviour for Flash Attention version of
|
| 1331 |
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
|
| 1332 |
)
|
| 1333 |
slope_rates = [self.slopes.to(default_device) for _ in range(len(self.layers))]
|
|
@@ -1401,12 +1401,12 @@ class MiniMaxText01Model(MiniMaxText01PreTrainedModel):
|
|
| 1401 |
)
|
| 1402 |
|
| 1403 |
|
| 1404 |
-
class
|
| 1405 |
_tied_weights_keys = ["lm_head.weight"]
|
| 1406 |
|
| 1407 |
def __init__(self, config):
|
| 1408 |
super().__init__(config)
|
| 1409 |
-
self.model =
|
| 1410 |
self.vocab_size = config.vocab_size
|
| 1411 |
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 1412 |
self.router_aux_loss_coef = config.router_aux_loss_coef
|
|
@@ -1462,9 +1462,9 @@ class MiniMaxText01ForCausalLM(MiniMaxText01PreTrainedModel):
|
|
| 1462 |
Example:
|
| 1463 |
|
| 1464 |
```python
|
| 1465 |
-
>>> from transformers import AutoTokenizer,
|
| 1466 |
|
| 1467 |
-
>>> model =
|
| 1468 |
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_WEIGHTS)
|
| 1469 |
|
| 1470 |
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
|
@@ -1579,9 +1579,9 @@ class MiniMaxText01ForCausalLM(MiniMaxText01PreTrainedModel):
|
|
| 1579 |
|
| 1580 |
@add_start_docstrings(
|
| 1581 |
"""
|
| 1582 |
-
The
|
| 1583 |
|
| 1584 |
-
[`
|
| 1585 |
(e.g. GPT-2) do.
|
| 1586 |
|
| 1587 |
Since it does classification on the last token, it requires to know the position of the last token. If a
|
|
@@ -1592,12 +1592,12 @@ class MiniMaxText01ForCausalLM(MiniMaxText01PreTrainedModel):
|
|
| 1592 |
""",
|
| 1593 |
MIXTRAL_START_DOCSTRING,
|
| 1594 |
)
|
| 1595 |
-
# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->
|
| 1596 |
-
class
|
| 1597 |
def __init__(self, config):
|
| 1598 |
super().__init__(config)
|
| 1599 |
self.num_labels = config.num_labels
|
| 1600 |
-
self.model =
|
| 1601 |
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
|
| 1602 |
|
| 1603 |
# Initialize weights and apply final processing
|
|
|
|
| 1 |
+
""" PyTorch MiniMaxM1 model."""
|
| 2 |
import inspect
|
| 3 |
import math
|
| 4 |
import warnings
|
|
|
|
| 31 |
replace_return_docstrings,
|
| 32 |
)
|
| 33 |
from transformers.utils.import_utils import is_torch_fx_available
|
| 34 |
+
from .configuration_minimax_m1 import MiniMaxM1Config
|
| 35 |
|
| 36 |
if is_flash_attn_2_available():
|
| 37 |
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
|
|
|
| 52 |
|
| 53 |
logger = logging.get_logger(__name__)
|
| 54 |
|
| 55 |
+
_CONFIG_FOR_DOC = "MiniMaxM1Config"
|
| 56 |
|
| 57 |
|
| 58 |
def get_activation_fn(activation):
|
|
|
|
| 207 |
return output
|
| 208 |
|
| 209 |
|
| 210 |
+
class MiniMaxM1LightningAttention(nn.Module):
|
| 211 |
+
def __init__(self, config: MiniMaxM1Config, layer_idx: Optional[int] = None):
|
| 212 |
super().__init__()
|
| 213 |
bias = False
|
| 214 |
self.hidden_size = config.hidden_size
|
|
|
|
| 217 |
|
| 218 |
self.out_proj = nn.Linear(self.head_dim * self.num_heads, self.hidden_size, bias=bias)
|
| 219 |
self.act = get_activation_fn(config.hidden_act)
|
| 220 |
+
self.norm = MiniMaxM1RMSNorm(self.head_dim * self.num_heads)
|
| 221 |
|
| 222 |
self.qkv_proj = nn.Linear(self.hidden_size, 3 * self.head_dim * self.num_heads, bias=bias)
|
| 223 |
self.output_gate = nn.Linear(self.hidden_size, self.head_dim * self.num_heads, bias=bias)
|
|
|
|
| 338 |
return output, attn_weights, kv
|
| 339 |
|
| 340 |
|
| 341 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MiniMaxM1
|
| 342 |
+
class MiniMaxM1RMSNorm(nn.Module):
|
| 343 |
def __init__(self, hidden_size, eps=1e-6):
|
| 344 |
"""
|
| 345 |
+
MiniMaxM1RMSNorm is equivalent to T5LayerNorm
|
| 346 |
"""
|
| 347 |
super().__init__()
|
| 348 |
self.weight = nn.Parameter(torch.ones(hidden_size))
|
|
|
|
| 356 |
return self.weight * hidden_states.to(input_dtype)
|
| 357 |
|
| 358 |
|
| 359 |
+
# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->MiniMaxM1
|
| 360 |
+
class MiniMaxM1RotaryEmbedding(nn.Module):
|
| 361 |
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
| 362 |
super().__init__()
|
| 363 |
|
|
|
|
| 447 |
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
| 448 |
|
| 449 |
|
| 450 |
+
# Copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->MiniMaxM1
|
| 451 |
+
class MiniMaxM1Attention(nn.Module):
|
| 452 |
"""
|
| 453 |
Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
|
| 454 |
and "Generating Long Sequences with Sparse Transformers".
|
| 455 |
"""
|
| 456 |
|
| 457 |
+
def __init__(self, config: MiniMaxM1Config, layer_idx: Optional[int] = None):
|
| 458 |
super().__init__()
|
| 459 |
self.config = config
|
| 460 |
self.layer_idx = layer_idx
|
|
|
|
| 481 |
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
| 482 |
self.rotary_dim = getattr(config, 'rotary_dim', self.head_dim)
|
| 483 |
|
| 484 |
+
self.rotary_emb = MiniMaxM1RotaryEmbedding(
|
| 485 |
self.rotary_dim,
|
| 486 |
max_position_embeddings=self.max_position_embeddings,
|
| 487 |
base=self.rope_theta,
|
|
|
|
| 572 |
return attn_output, attn_weights, past_key_value
|
| 573 |
|
| 574 |
|
| 575 |
+
# Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->MiniMaxM1
|
| 576 |
+
class MiniMaxM1FlashAttention2(MiniMaxM1Attention):
|
| 577 |
"""
|
| 578 |
+
MiniMaxM1 flash attention module. This module inherits from `MiniMaxM1Attention` as the weights of the module stays
|
| 579 |
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
|
| 580 |
flash attention and deal with padding tokens in case the input contains any of them.
|
| 581 |
"""
|
|
|
|
| 836 |
)
|
| 837 |
|
| 838 |
|
| 839 |
+
class MiniMaxM1MLP(nn.Module):
|
| 840 |
def __init__(self, config):
|
| 841 |
super().__init__()
|
| 842 |
self.config = config
|
|
|
|
| 852 |
return down_proj
|
| 853 |
|
| 854 |
|
| 855 |
+
class MiniMaxM1BlockSparseTop2MLP(nn.Module):
|
| 856 |
+
def __init__(self, config: MiniMaxM1Config):
|
| 857 |
super().__init__()
|
| 858 |
self.ffn_dim = config.intermediate_size
|
| 859 |
self.hidden_dim = config.hidden_size
|
|
|
|
| 870 |
return current_hidden_states
|
| 871 |
|
| 872 |
|
| 873 |
+
class MiniMaxM1BLockSparseTop2MLP(MiniMaxM1BlockSparseTop2MLP):
|
| 874 |
def __init__(self, *args, **kwargs):
|
| 875 |
logger.warning_once(
|
| 876 |
+
"MiniMaxM1BLockSparseTop2MLP is deprecated by MiniMaxM1BlockSparseTop2MLP and will be removed in v4.40."
|
| 877 |
)
|
| 878 |
super().__init__(*args, **kwargs)
|
| 879 |
|
| 880 |
|
| 881 |
+
class MiniMaxM1SparseMoeBlock(nn.Module):
|
| 882 |
"""
|
| 883 |
This implementation is
|
| 884 |
strictly equivalent to standard MoE with full capacity (no
|
|
|
|
| 900 |
# gating
|
| 901 |
self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
|
| 902 |
|
| 903 |
+
self.experts = nn.ModuleList([MiniMaxM1BlockSparseTop2MLP(config) for _ in range(self.num_experts)])
|
| 904 |
|
| 905 |
# Jitter parameters
|
| 906 |
self.jitter_noise = config.router_jitter_noise
|
|
|
|
| 946 |
return final_hidden_states, router_logits
|
| 947 |
|
| 948 |
|
| 949 |
+
class MiniMaxM1DecoderLayer(nn.Module):
|
| 950 |
+
def __init__(self, config: MiniMaxM1Config, layer_idx: int):
|
| 951 |
super().__init__()
|
| 952 |
self.config = config
|
| 953 |
self.hidden_size = config.hidden_size
|
|
|
|
| 956 |
|
| 957 |
self.layer_idx = layer_idx
|
| 958 |
|
| 959 |
+
self.block_sparse_moe = MiniMaxM1SparseMoeBlock(config)
|
| 960 |
+
self.input_layernorm = MiniMaxM1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 961 |
+
self.post_attention_layernorm = MiniMaxM1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 962 |
|
| 963 |
self.postnorm = getattr(config, 'postnorm', False)
|
| 964 |
self.layernorm_attention_alpha = getattr(config, 'layernorm_linear_attention_alpha', 1) \
|
|
|
|
| 972 |
self.shared_moe = False
|
| 973 |
if shared_intermediate > 0:
|
| 974 |
self.shared_moe = True
|
| 975 |
+
self.shared_mlp = MiniMaxM1MLP(config)
|
| 976 |
self.coefficient = torch.nn.Linear(self.hidden_size, 1, bias=False)
|
| 977 |
|
| 978 |
def build_attn(self, config, layer_idx):
|
| 979 |
if config.attention_type == 0:
|
| 980 |
+
Attention_module = MiniMaxM1LightningAttention
|
| 981 |
else:
|
| 982 |
+
Attention_module = MiniMaxM1FlashAttention2
|
| 983 |
|
| 984 |
return Attention_module(
|
| 985 |
config,
|
|
|
|
| 1081 |
and behavior.
|
| 1082 |
|
| 1083 |
Parameters:
|
| 1084 |
+
config ([`MiniMaxM1Config`]):
|
| 1085 |
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
| 1086 |
load the weights associated with the model, only the configuration. Check out the
|
| 1087 |
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
|
|
|
| 1089 |
|
| 1090 |
|
| 1091 |
@add_start_docstrings(
|
| 1092 |
+
"The bare MiniMaxM1 Model outputting raw hidden-states without any specific head on top.",
|
| 1093 |
MIXTRAL_START_DOCSTRING,
|
| 1094 |
)
|
| 1095 |
+
# Copied from transformers.models.mistral.modeling_mistral.MistralPreTrainedModel with Mistral->MiniMaxM1
|
| 1096 |
+
class MiniMaxM1PreTrainedModel(PreTrainedModel):
|
| 1097 |
+
config_class = MiniMaxM1Config
|
| 1098 |
base_model_prefix = "model"
|
| 1099 |
supports_gradient_checkpointing = True
|
| 1100 |
+
_no_split_modules = ["MiniMaxM1DecoderLayer"]
|
| 1101 |
_skip_keys_device_placement = "past_key_values"
|
| 1102 |
_supports_flash_attn_2 = True
|
| 1103 |
_supports_sdpa = True
|
|
|
|
| 1182 |
|
| 1183 |
|
| 1184 |
@add_start_docstrings(
|
| 1185 |
+
"The bare MiniMaxM1 Model outputting raw hidden-states without any specific head on top.",
|
| 1186 |
MIXTRAL_START_DOCSTRING,
|
| 1187 |
)
|
| 1188 |
+
# Copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->MIXTRAL,Mistral->MiniMaxM1
|
| 1189 |
+
class MiniMaxM1Model(MiniMaxM1PreTrainedModel):
|
| 1190 |
"""
|
| 1191 |
+
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MiniMaxM1DecoderLayer`]
|
| 1192 |
|
| 1193 |
Args:
|
| 1194 |
+
config: MiniMaxM1Config
|
| 1195 |
"""
|
| 1196 |
|
| 1197 |
+
def __init__(self, config: MiniMaxM1Config):
|
| 1198 |
super().__init__(config)
|
| 1199 |
self.padding_idx = config.pad_token_id
|
| 1200 |
self.vocab_size = config.vocab_size
|
|
|
|
| 1212 |
else:
|
| 1213 |
_config._attn_implementation = config_copy._attn_implementation
|
| 1214 |
_config.attention_type = 1
|
| 1215 |
+
self.layers.append(MiniMaxM1DecoderLayer(_config, i))
|
| 1216 |
|
| 1217 |
self._attn_implementation = config_copy._attn_implementation
|
| 1218 |
+
self.norm = MiniMaxM1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 1219 |
|
| 1220 |
self.gradient_checkpointing = False
|
| 1221 |
self.slopes = self._build_slope_tensor(config.num_attention_heads)
|
|
|
|
| 1327 |
if is_padding_right:
|
| 1328 |
raise ValueError(
|
| 1329 |
"You are attempting to perform batched generation with padding_side='right'"
|
| 1330 |
+
" this may lead to unexpected behaviour for Flash Attention version of MiniMaxM1. Make sure to "
|
| 1331 |
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
|
| 1332 |
)
|
| 1333 |
slope_rates = [self.slopes.to(default_device) for _ in range(len(self.layers))]
|
|
|
|
| 1401 |
)
|
| 1402 |
|
| 1403 |
|
| 1404 |
+
class MiniMaxM1ForCausalLM(MiniMaxM1PreTrainedModel):
|
| 1405 |
_tied_weights_keys = ["lm_head.weight"]
|
| 1406 |
|
| 1407 |
def __init__(self, config):
|
| 1408 |
super().__init__(config)
|
| 1409 |
+
self.model = MiniMaxM1Model(config)
|
| 1410 |
self.vocab_size = config.vocab_size
|
| 1411 |
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 1412 |
self.router_aux_loss_coef = config.router_aux_loss_coef
|
|
|
|
| 1462 |
Example:
|
| 1463 |
|
| 1464 |
```python
|
| 1465 |
+
>>> from transformers import AutoTokenizer, MiniMaxM1ForCausalLM
|
| 1466 |
|
| 1467 |
+
>>> model = MiniMaxM1ForCausalLM.from_pretrained(PATH_TO_WEIGHTS)
|
| 1468 |
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_WEIGHTS)
|
| 1469 |
|
| 1470 |
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
|
|
|
| 1579 |
|
| 1580 |
@add_start_docstrings(
|
| 1581 |
"""
|
| 1582 |
+
The MiniMaxM1 Model transformer with a sequence classification head on top (linear layer).
|
| 1583 |
|
| 1584 |
+
[`MiniMaxM1ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
|
| 1585 |
(e.g. GPT-2) do.
|
| 1586 |
|
| 1587 |
Since it does classification on the last token, it requires to know the position of the last token. If a
|
|
|
|
| 1592 |
""",
|
| 1593 |
MIXTRAL_START_DOCSTRING,
|
| 1594 |
)
|
| 1595 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->MiniMaxM1, LLAMA->MIXTRAL
|
| 1596 |
+
class MiniMaxM1ForSequenceClassification(MiniMaxM1PreTrainedModel):
|
| 1597 |
def __init__(self, config):
|
| 1598 |
super().__init__(config)
|
| 1599 |
self.num_labels = config.num_labels
|
| 1600 |
+
self.model = MiniMaxM1Model(config)
|
| 1601 |
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
|
| 1602 |
|
| 1603 |
# Initialize weights and apply final processing
|