Zhiding committed on
Commit b74958e · 1 Parent(s): b85eda5
all_results.json ADDED
@@ -0,0 +1,8 @@
1
+ {
2
+ "epoch": 1.0,
3
+ "train_loss": 0.06401226358036113,
4
+ "train_runtime": 6205.6143,
5
+ "train_samples": 9129380,
6
+ "train_samples_per_second": 1471.148,
7
+ "train_steps_per_second": 1.874
8
+ }
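
As a quick sanity check, the reported throughput is just samples divided by wall-clock runtime; a minimal sketch using the numbers above:

```python
# Consistency check on the metrics in all_results.json (values copied from above).
train_samples = 9_129_380
train_runtime = 6205.6143            # seconds
reported_throughput = 1471.148       # train_samples_per_second

derived_throughput = train_samples / train_runtime
print(f"derived: {derived_throughput:.3f} samples/s")   # ~1471.148
assert abs(derived_throughput - reported_throughput) < 0.01
```
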
config.json ADDED
@@ -0,0 +1,213 @@
1
+ {
2
+ "_commit_hash": null,
3
+ "_name_or_path": "./work_dirs/commercial_eagle_128gpus_bs1024_stage1_ptv1_siglip_llama3_2_3B",
4
+ "architectures": [
5
+ "Eagle2ChatModel"
6
+ ],
7
+ "auto_map": {
8
+ "AutoConfig": "configuration_eagle_chat.Eagle2ChatConfig",
9
+ "AutoModel": "modeling_eagle_chat.Eagle2ChatModel",
10
+ "AutoModelForCausalLM": "modeling_eagle_chat.Eagle2ChatModel"
11
+ },
12
+ "downsample_ratio": 0.5,
13
+ "dynamic_image_size": true,
14
+ "force_image_size": 448,
15
+ "llm_config": {
16
+ "_name_or_path": "./pretrained/Llama-3_2-3B-Instruct",
17
+ "add_cross_attention": false,
18
+ "architectures": [
19
+ "LlamaForCausalLM"
20
+ ],
21
+ "attention_bias": false,
22
+ "attention_dropout": 0.0,
23
+ "attn_implementation": "flash_attention_2",
24
+ "auto_map": {
25
+ "AutoConfig": "configuration_llama.LlamaConfig",
26
+ "AutoModel": "modeling_llama.LlamaModel",
27
+ "AutoModelForCausalLM": "modeling_llama.LlamaForCausalLM"
28
+ },
29
+ "bad_words_ids": null,
30
+ "begin_suppress_tokens": null,
31
+ "bos_token_id": 128000,
32
+ "chunk_size_feed_forward": 0,
33
+ "cross_attention_hidden_size": null,
34
+ "decoder_start_token_id": null,
35
+ "diversity_penalty": 0.0,
36
+ "do_sample": false,
37
+ "early_stopping": false,
38
+ "encoder_no_repeat_ngram_size": 0,
39
+ "eos_token_id": [
40
+ 128001,
41
+ 128008,
42
+ 128009
43
+ ],
44
+ "exponential_decay_length_penalty": null,
45
+ "finetuning_task": null,
46
+ "forced_bos_token_id": null,
47
+ "forced_eos_token_id": null,
48
+ "head_dim": 128,
49
+ "hidden_act": "silu",
50
+ "hidden_size": 3072,
51
+ "id2label": {
52
+ "0": "LABEL_0",
53
+ "1": "LABEL_1"
54
+ },
55
+ "initializer_range": 0.02,
56
+ "intermediate_size": 8192,
57
+ "is_decoder": false,
58
+ "is_encoder_decoder": false,
59
+ "label2id": {
60
+ "LABEL_0": 0,
61
+ "LABEL_1": 1
62
+ },
63
+ "length_penalty": 1.0,
64
+ "max_length": 20,
65
+ "max_position_embeddings": 131072,
66
+ "min_length": 0,
67
+ "mlp_bias": false,
68
+ "model_type": "llama",
69
+ "my_rope_scaling": {
70
+ "factor": 32.0,
71
+ "high_freq_factor": 4.0,
72
+ "low_freq_factor": 1.0,
73
+ "original_max_position_embeddings": 8192,
74
+ "rope_type": "llama3"
75
+ },
76
+ "no_repeat_ngram_size": 0,
77
+ "num_attention_heads": 24,
78
+ "num_beam_groups": 1,
79
+ "num_beams": 1,
80
+ "num_hidden_layers": 28,
81
+ "num_key_value_heads": 8,
82
+ "num_return_sequences": 1,
83
+ "output_attentions": false,
84
+ "output_hidden_states": false,
85
+ "output_scores": false,
86
+ "pad_token_id": null,
87
+ "prefix": null,
88
+ "pretraining_tp": 1,
89
+ "problem_type": null,
90
+ "pruned_heads": {},
91
+ "remove_invalid_values": false,
92
+ "repetition_penalty": 1.0,
93
+ "return_dict": true,
94
+ "return_dict_in_generate": false,
95
+ "rms_norm_eps": 1e-05,
96
+ "rope_scaling": {
97
+ "factor": 32.0,
98
+ "high_freq_factor": 4.0,
99
+ "low_freq_factor": 1.0,
100
+ "original_max_position_embeddings": 8192,
101
+ "rope_type": "llama3",
102
+ "type": "llama3"
103
+ },
104
+ "rope_theta": 500000.0,
105
+ "sep_token_id": null,
106
+ "suppress_tokens": null,
107
+ "task_specific_params": null,
108
+ "temperature": 1.0,
109
+ "tf_legacy_loss": false,
110
+ "tie_encoder_decoder": false,
111
+ "tie_word_embeddings": true,
112
+ "tokenizer_class": null,
113
+ "top_k": 50,
114
+ "top_p": 1.0,
115
+ "torch_dtype": "bfloat16",
116
+ "torchscript": false,
117
+ "transformers_version": "4.37.2",
118
+ "typical_p": 1.0,
119
+ "use_bfloat16": false,
120
+ "use_cache": false,
121
+ "vocab_size": 128267
122
+ },
123
+ "loss_version": "efficient_v2_cp_head",
124
+ "max_dynamic_patch": 12,
125
+ "min_dynamic_patch": 1,
126
+ "mlp_checkpoint": false,
127
+ "model_type": "eagle_chat",
128
+ "pad2square": false,
129
+ "pre_feature_reduction": false,
130
+ "ps_version": "v2",
131
+ "select_layer": -1,
132
+ "template": "llama3-chat",
133
+ "torch_dtype": "bfloat16",
134
+ "transformers_version": null,
135
+ "use_backbone_lora": 0,
136
+ "use_llm_lora": 0,
137
+ "use_thumbnail": true,
138
+ "vision_config": {
139
+ "_name_or_path": "",
140
+ "add_cross_attention": false,
141
+ "architectures": null,
142
+ "attention_dropout": 0.0,
143
+ "bad_words_ids": null,
144
+ "begin_suppress_tokens": null,
145
+ "bos_token_id": null,
146
+ "chunk_size_feed_forward": 0,
147
+ "cross_attention_hidden_size": null,
148
+ "decoder_start_token_id": null,
149
+ "diversity_penalty": 0.0,
150
+ "do_sample": false,
151
+ "drop_path_rate": 0.1,
152
+ "early_stopping": false,
153
+ "encoder_no_repeat_ngram_size": 0,
154
+ "eos_token_id": null,
155
+ "exponential_decay_length_penalty": null,
156
+ "finetuning_task": null,
157
+ "forced_bos_token_id": null,
158
+ "forced_eos_token_id": null,
159
+ "hidden_act": "gelu_pytorch_tanh",
160
+ "hidden_size": 1152,
161
+ "id2label": {
162
+ "0": "LABEL_0",
163
+ "1": "LABEL_1"
164
+ },
165
+ "image_size": 448,
166
+ "intermediate_size": 4304,
167
+ "is_decoder": false,
168
+ "is_encoder_decoder": false,
169
+ "label2id": {
170
+ "LABEL_0": 0,
171
+ "LABEL_1": 1
172
+ },
173
+ "layer_norm_eps": 1e-06,
174
+ "length_penalty": 1.0,
175
+ "max_length": 20,
176
+ "min_length": 0,
177
+ "model_type": "siglip_vision_model",
178
+ "no_repeat_ngram_size": 0,
179
+ "num_attention_heads": 16,
180
+ "num_beam_groups": 1,
181
+ "num_beams": 1,
182
+ "num_channels": 3,
183
+ "num_hidden_layers": 27,
184
+ "num_return_sequences": 1,
185
+ "output_attentions": false,
186
+ "output_hidden_states": false,
187
+ "output_scores": false,
188
+ "pad_token_id": null,
189
+ "patch_size": 14,
190
+ "prefix": null,
191
+ "problem_type": null,
192
+ "pruned_heads": {},
193
+ "remove_invalid_values": false,
194
+ "repetition_penalty": 1.0,
195
+ "return_dict": true,
196
+ "return_dict_in_generate": false,
197
+ "sep_token_id": null,
198
+ "suppress_tokens": null,
199
+ "task_specific_params": null,
200
+ "temperature": 1.0,
201
+ "tf_legacy_loss": false,
202
+ "tie_encoder_decoder": false,
203
+ "tie_word_embeddings": true,
204
+ "tokenizer_class": null,
205
+ "top_k": 50,
206
+ "top_p": 1.0,
207
+ "torch_dtype": null,
208
+ "torchscript": false,
209
+ "transformers_version": "4.37.2",
210
+ "typical_p": 1.0,
211
+ "use_bfloat16": false
212
+ }
213
+ }
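
The `auto_map` entries above mean this checkpoint is intended to be loaded with `trust_remote_code=True`, so that `configuration_eagle_chat.py` and `modeling_eagle_chat.py` from this repository are used. A minimal loading sketch; the checkpoint path is a placeholder:

```python
# Hypothetical loading sketch; replace the placeholder path with this repository's id or a local clone.
import torch
from transformers import AutoConfig, AutoModel

checkpoint = "./Eagle2-checkpoint"  # placeholder

config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)
print(config.model_type)             # "eagle_chat"
print(config.llm_config.model_type)  # "llama"

model = AutoModel.from_pretrained(
    checkpoint,
    torch_dtype=torch.bfloat16,      # matches "torch_dtype": "bfloat16" above
    trust_remote_code=True,
)
```
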
configuration_eagle_chat.py ADDED
@@ -0,0 +1,105 @@
1
+ # --------------------------------------------------------
2
+ # Eagle2
3
+ # Copyright (c) 2025 NVIDIA
4
+ # Licensed under The Apache License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+
7
+ import copy
8
+
9
+ from transformers import AutoConfig
10
+ from transformers.configuration_utils import PretrainedConfig
11
+ from transformers.utils import logging
12
+ from .configuration_siglip import SiglipVisionConfig
13
+ from .configuration_qwen2 import Qwen2Config
14
+ from .configuration_llama import LlamaConfig
15
+ from .configuration_multi_backbone_channel_concatentation_model import MultiBackboneChannelConcatenationVisionModelConfig
16
+ logger = logging.get_logger(__name__)
17
+
18
+
19
+ class Eagle2ChatConfig(PretrainedConfig):
20
+ model_type = 'eagle_chat'
21
+ is_composition = True
22
+
23
+ def __init__(
24
+ self,
25
+ vision_config=None,
26
+ llm_config=None,
27
+ use_backbone_lora=0,
28
+ use_llm_lora=0,
29
+ select_layer=-1,
30
+ force_image_size=None,
31
+ downsample_ratio=0.5,
32
+ template=None,
33
+ dynamic_image_size=False,
34
+ use_thumbnail=False,
35
+ min_dynamic_patch=1,
36
+ max_dynamic_patch=6,
37
+ mlp_checkpoint=True,
38
+ pre_feature_reduction=False,
39
+ keep_aspect_ratio=False,
40
+ **kwargs):
41
+ super().__init__(**kwargs)
42
+
43
+ if vision_config is None:
44
+ vision_config = {}
45
+ logger.info('vision_config is None. Initializing Vision Encoders with default values.')
46
+
47
+ if llm_config is None:
48
+ llm_config = {}
49
+ logger.info('llm_config is None. Initializing the LLM config with default values')
50
+
51
+ if vision_config['model_type'] == 'siglip_vision_model':
52
+ self.vision_config = SiglipVisionConfig(**vision_config)
53
+ elif vision_config['model_type'].startswith("MOB"):
54
+ self.vision_config = MultiBackboneChannelConcatenationVisionModelConfig(**vision_config)
55
+ else:
56
+ raise ValueError('Unsupported model_type: {}'.format(vision_config['model_type']))
57
+
58
+ if llm_config['architectures'][0] == 'LlamaForCausalLM':
59
+ self.llm_config = LlamaConfig(**llm_config)
60
+ elif llm_config['architectures'][0] == 'Qwen2ForCausalLM':
61
+ self.llm_config = Qwen2Config(**llm_config)
62
+ else:
63
+ raise ValueError('Unsupported architecture: {}'.format(llm_config['architectures'][0]))
64
+ self.use_backbone_lora = use_backbone_lora
65
+ self.use_llm_lora = use_llm_lora
66
+ self.select_layer = select_layer
67
+ self.force_image_size = force_image_size
68
+ self.downsample_ratio = downsample_ratio
69
+ self.template = template
70
+ self.dynamic_image_size = dynamic_image_size
71
+ self.use_thumbnail = use_thumbnail
72
+ self.min_dynamic_patch = min_dynamic_patch
73
+ self.max_dynamic_patch = max_dynamic_patch
74
+ self.mlp_checkpoint = mlp_checkpoint
75
+ self.pre_feature_reduction = pre_feature_reduction
76
+ self.keep_aspect_ratio = keep_aspect_ratio
77
+ logger.info(f'keep_aspect_ratio: {self.keep_aspect_ratio}')
78
+ logger.info(f'vision_select_layer: {self.select_layer}')
79
+ logger.info(f'min_dynamic_patch: {self.min_dynamic_patch}')
80
+ logger.info(f'max_dynamic_patch: {self.max_dynamic_patch}')
81
+
82
+ def to_dict(self):
83
+ """
84
+ Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
85
+
86
+ Returns:
87
+ `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
88
+ """
89
+ output = copy.deepcopy(self.__dict__)
90
+ output['vision_config'] = self.vision_config.to_dict()
91
+ output['llm_config'] = self.llm_config.to_dict()
92
+ output['model_type'] = self.__class__.model_type
93
+ output['use_backbone_lora'] = self.use_backbone_lora
94
+ output['use_llm_lora'] = self.use_llm_lora
95
+ output['select_layer'] = self.select_layer
96
+ output['force_image_size'] = self.force_image_size
97
+ output['downsample_ratio'] = self.downsample_ratio
98
+ output['template'] = self.template
99
+ output['dynamic_image_size'] = self.dynamic_image_size
100
+ output['use_thumbnail'] = self.use_thumbnail
101
+ output['min_dynamic_patch'] = self.min_dynamic_patch
102
+ output['max_dynamic_patch'] = self.max_dynamic_patch
103
+ output['keep_aspect_ratio'] = self.keep_aspect_ratio
104
+
105
+ return output
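
A minimal construction sketch for `Eagle2ChatConfig`, assuming these configuration files are importable as a package (the package name `eagle2` is hypothetical; in practice the class is resolved through `auto_map` with `trust_remote_code=True`). The nested dicts below are deliberately small placeholders rather than the full `config.json` contents:

```python
from eagle2.configuration_eagle_chat import Eagle2ChatConfig  # hypothetical package layout

config = Eagle2ChatConfig(
    # model_type / architectures are required: __init__ branches on them to pick the sub-config class
    vision_config={"model_type": "siglip_vision_model", "image_size": 448, "patch_size": 14},
    llm_config={"architectures": ["LlamaForCausalLM"], "model_type": "llama", "hidden_size": 3072},
    force_image_size=448,
    dynamic_image_size=True,
    use_thumbnail=True,
    max_dynamic_patch=12,
    template="llama3-chat",
)
# to_dict() re-expands the nested configs, mirroring the layout of config.json above.
print(config.to_dict()["llm_config"]["architectures"])  # ['LlamaForCausalLM']
```
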
configuration_llama.py ADDED
@@ -0,0 +1,285 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """ LLaMA model configuration"""
21
+
22
+ from transformers.configuration_utils import PretrainedConfig
23
+ from transformers.utils import logging
24
+ from typing import Optional
25
+ import inspect
26
+ import copy
27
+ logger = logging.get_logger(__name__)
28
+
29
+ LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
30
+
31
+
32
+ class LlamaConfig(PretrainedConfig):
33
+ r"""
34
+ This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA
35
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
36
+ defaults will yield a similar configuration to that of the LLaMA-7B.
37
+
38
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
39
+ documentation from [`PretrainedConfig`] for more information.
40
+
41
+
42
+ Args:
43
+ vocab_size (`int`, *optional*, defaults to 32000):
44
+ Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the
45
+ `inputs_ids` passed when calling [`LlamaModel`]
46
+ hidden_size (`int`, *optional*, defaults to 4096):
47
+ Dimension of the hidden representations.
48
+ intermediate_size (`int`, *optional*, defaults to 11008):
49
+ Dimension of the MLP representations.
50
+ num_hidden_layers (`int`, *optional*, defaults to 32):
51
+ Number of hidden layers in the Transformer decoder.
52
+ num_attention_heads (`int`, *optional*, defaults to 32):
53
+ Number of attention heads for each attention layer in the Transformer decoder.
54
+ num_key_value_heads (`int`, *optional*):
55
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
56
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
57
+ `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
58
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
59
+ by meanpooling all the original heads within that group. For more details checkout [this
60
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
61
+ `num_attention_heads`.
62
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
63
+ The non-linear activation function (function or string) in the decoder.
64
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
65
+ The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens,
66
+ Llama 2 up to 4096, CodeLlama up to 16384.
67
+ initializer_range (`float`, *optional*, defaults to 0.02):
68
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
69
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
70
+ The epsilon used by the rms normalization layers.
71
+ use_cache (`bool`, *optional*, defaults to `True`):
72
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
73
+ relevant if `config.is_decoder=True`.
74
+ pad_token_id (`int`, *optional*):
75
+ Padding token id.
76
+ bos_token_id (`int`, *optional*, defaults to 1):
77
+ Beginning of stream token id.
78
+ eos_token_id (`int`, *optional*, defaults to 2):
79
+ End of stream token id.
80
+ pretraining_tp (`int`, *optional*, defaults to 1):
81
+ Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
82
+ document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
83
+ necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
84
+ issue](https://github.com/pytorch/pytorch/issues/76232).
85
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
86
+ Whether to tie weight embeddings
87
+ rope_theta (`float`, *optional*, defaults to 10000.0):
88
+ The base period of the RoPE embeddings.
89
+ rope_scaling (`Dict`, *optional*):
90
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
91
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
92
+ accordingly.
93
+ Expected contents:
94
+ `rope_type` (`str`):
95
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
96
+ 'llama3'], with 'default' being the original RoPE implementation.
97
+ `factor` (`float`, *optional*):
98
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
99
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
100
+ original maximum pre-trained length.
101
+ `original_max_position_embeddings` (`int`, *optional*):
102
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
103
+ pretraining.
104
+ `attention_factor` (`float`, *optional*):
105
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
106
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
107
+ `factor` field to infer the suggested value.
108
+ `beta_fast` (`float`, *optional*):
109
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
110
+ ramp function. If unspecified, it defaults to 32.
111
+ `beta_slow` (`float`, *optional*):
112
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
113
+ ramp function. If unspecified, it defaults to 1.
114
+ `short_factor` (`List[float]`, *optional*):
115
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
116
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
117
+ size divided by the number of attention heads divided by 2
118
+ `long_factor` (`List[float]`, *optional*):
119
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (>
120
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
121
+ size divided by the number of attention heads divided by 2
122
+ `low_freq_factor` (`float`, *optional*):
123
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
124
+ `high_freq_factor` (`float`, *optional*):
125
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
126
+ attention_bias (`bool`, *optional*, defaults to `False`):
127
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
128
+ attention_dropout (`float`, *optional*, defaults to 0.0):
129
+ The dropout ratio for the attention probabilities.
130
+
131
+ ```python
132
+ >>> from transformers import LlamaModel, LlamaConfig
133
+
134
+ >>> # Initializing a LLaMA llama-7b style configuration
135
+ >>> configuration = LlamaConfig()
136
+
137
+ >>> # Initializing a model from the llama-7b style configuration
138
+ >>> model = LlamaModel(configuration)
139
+
140
+ >>> # Accessing the model configuration
141
+ >>> configuration = model.config
142
+ ```"""
143
+
144
+ model_type = "llama"
145
+ keys_to_ignore_at_inference = ["past_key_values"]
146
+
147
+ def __init__(
148
+ self,
149
+ vocab_size=32000,
150
+ hidden_size=4096,
151
+ intermediate_size=11008,
152
+ num_hidden_layers=32,
153
+ num_attention_heads=32,
154
+ num_key_value_heads=None,
155
+ hidden_act="silu",
156
+ max_position_embeddings=2048,
157
+ initializer_range=0.02,
158
+ rms_norm_eps=1e-6,
159
+ use_cache=True,
160
+ pad_token_id=None,
161
+ bos_token_id=1,
162
+ eos_token_id=2,
163
+ pretraining_tp=1,
164
+ tie_word_embeddings=False,
165
+ rope_theta=10000.0,
166
+ rope_scaling=None,
167
+ my_rope_scaling={
168
+ "factor": 8.0,
169
+ "low_freq_factor": 1.0,
170
+ "high_freq_factor": 4.0,
171
+ "original_max_position_embeddings": 8192,
172
+ "rope_type": "llama3"
173
+ },
174
+ attention_bias=False,
175
+ attention_dropout=0.0,
176
+ attn_implementation='flash_attention_2',
177
+ **kwargs,
178
+ ):
179
+
180
+ self.vocab_size = vocab_size
181
+ self.max_position_embeddings = max_position_embeddings
182
+ self.hidden_size = hidden_size
183
+ self.intermediate_size = intermediate_size
184
+ self.num_hidden_layers = num_hidden_layers
185
+ self.num_attention_heads = num_attention_heads
186
+
187
+ self.attn_implementation = attn_implementation
188
+ if self.attn_implementation is None:
189
+ self.attn_implementation = "flash_attention_2"
190
+
191
+ # for backward compatibility
192
+ if num_key_value_heads is None:
193
+ num_key_value_heads = num_attention_heads
194
+
195
+ self.num_key_value_heads = num_key_value_heads
196
+ self.hidden_act = hidden_act
197
+ self.initializer_range = initializer_range
198
+ self.rms_norm_eps = rms_norm_eps
199
+ self.pretraining_tp = pretraining_tp
200
+ self.use_cache = use_cache
201
+ self.rope_theta = rope_theta
202
+ self.rope_scaling = copy.deepcopy(my_rope_scaling)
203
+ self.my_rope_scaling = my_rope_scaling
204
+ #self._my_rope_scalingvalidation()
205
+ self.attention_bias = attention_bias
206
+ self.attention_dropout = attention_dropout
207
+ if self.rope_scaling is not None and "type" not in self.rope_scaling:
208
+ self.rope_scaling["type"] = self.rope_scaling["rope_type"]
209
+ print('rope_scaling', self.my_rope_scaling, self.rope_scaling)
210
+
211
+ super().__init__(
212
+ pad_token_id=pad_token_id,
213
+ bos_token_id=bos_token_id,
214
+ eos_token_id=eos_token_id,
215
+ tie_word_embeddings=tie_word_embeddings,
216
+ **kwargs,
217
+ )
218
+ print('init done')
219
+
220
+
221
+ def _check_received_keys(
222
+ rope_type: str,
223
+ received_keys: set,
224
+ required_keys: set,
225
+ optional_keys: Optional[set] = None,
226
+ ignore_keys: Optional[set] = None,
227
+ ):
228
+ """Compare the received keys in `config.rope_scaling` against the expected and optional keys"""
229
+ # BC: "rope_type" was originally "type" -- let's check for "rope_type" when "type" is present
230
+ if "type" in received_keys:
231
+ received_keys -= {"type"}
232
+ required_keys.add("rope_type")
233
+
234
+ # Some models need to store model-specific keys, and we don't want to throw warning at them
235
+ if ignore_keys is not None:
236
+ received_keys -= ignore_keys
237
+
238
+ missing_keys = required_keys - received_keys
239
+ if missing_keys:
240
+ raise KeyError(f"Missing required keys in `rope_scaling` for 'rope_type'='{rope_type}': {missing_keys}")
241
+
242
+ if optional_keys is not None:
243
+ unused_keys = received_keys - required_keys - optional_keys
244
+ else:
245
+ unused_keys = received_keys - required_keys
246
+ if unused_keys:
247
+ logger.warning(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}")
248
+
249
+
250
+ def _validate_llama3_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
251
+ rope_scaling = config.rope_scaling
252
+ # from IPython import embed; embed()
253
+ print('rope_scaling2', rope_scaling)
254
+ rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
255
+ required_keys = {"rope_type", "factor", "original_max_position_embeddings", "low_freq_factor", "high_freq_factor"}
256
+ received_keys = set(rope_scaling.keys())
257
+ _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
258
+
259
+ factor = rope_scaling["factor"]
260
+ if factor is None or not isinstance(factor, float) or factor < 1.0:
261
+ logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
262
+
263
+ low_freq_factor = rope_scaling["low_freq_factor"]
264
+ high_freq_factor = rope_scaling["high_freq_factor"]
265
+ if low_freq_factor is None or not isinstance(low_freq_factor, float):
266
+ logger.warning(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}")
267
+ if high_freq_factor is None or not isinstance(high_freq_factor, float):
268
+ logger.warning(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}")
269
+ if high_freq_factor <= low_freq_factor:
270
+ logger.warning(
271
+ "`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor="
272
+ f"{high_freq_factor} and low_freq_factor={low_freq_factor}"
273
+ )
274
+
275
+ original_max_position_embeddings = rope_scaling["original_max_position_embeddings"]
276
+ if original_max_position_embeddings is None or not isinstance(original_max_position_embeddings, int):
277
+ logger.warning(
278
+ "`rope_scaling`'s original_max_position_embeddings field must be an integer, got "
279
+ f"{original_max_position_embeddings}"
280
+ )
281
+ if original_max_position_embeddings >= config.max_position_embeddings:
282
+ logger.warning(
283
+ "`rope_scaling`'s original_max_position_embeddings field must be less than max_position_embeddings, got "
284
+ f"{original_max_position_embeddings} and max_position_embeddings={config.max_position_embeddings}"
285
+ )
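
This custom `LlamaConfig` deep-copies `my_rope_scaling` into `rope_scaling` and adds a backward-compatibility `"type"` key, which is what the `llm_config` block in `config.json` above reflects. A short sketch using those same values (this file has no relative imports, so it can be imported directly from a local copy):

```python
from configuration_llama import LlamaConfig, _validate_llama3_parameters  # local copy of the file above

llm_config = LlamaConfig(
    vocab_size=128267,
    hidden_size=3072,
    intermediate_size=8192,
    num_hidden_layers=28,
    num_attention_heads=24,
    num_key_value_heads=8,
    max_position_embeddings=131072,
    rope_theta=500000.0,
    tie_word_embeddings=True,
    my_rope_scaling={
        "factor": 32.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_max_position_embeddings": 8192,
        "rope_type": "llama3",
    },
)
print(llm_config.rope_scaling["type"])   # "llama3", added for backward compatibility
_validate_llama3_parameters(llm_config)  # logs warnings only if the llama3 fields look inconsistent
```
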
configuration_multi_backbone_channel_concatentation_model.py ADDED
@@ -0,0 +1,86 @@
1
+ # --------------------------------------------------------
2
+ # Eagle2
3
+ # Copyright (c) 2025 NVIDIA
4
+ # Licensed under The Apache License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+
7
+ import os
8
+ from typing import Union
9
+
10
+ from transformers.configuration_utils import PretrainedConfig
11
+ from transformers.utils import logging
12
+ from .configuration_siglip import SiglipVisionConfig
13
+ logger = logging.get_logger(__name__)
14
+
15
+
16
+ class MultiBackboneChannelConcatenationVisionModelConfig(PretrainedConfig):
17
+ r"""
18
+ This is the configuration class to store the configuration of a [`MultiBackboneChannelConcatenationVisionModelConfig`]. It is used to
19
+ instantiate a vision encoder according to the specified arguments, defining the model architecture.
20
+
21
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
22
+ documentation from [`PretrainedConfig`] for more information.
23
+
24
+ Args:
25
+ vision_path (str): Path to the vision model or its configuration.
26
+ mm_vision_select_layer (int, optional): The layer to select from the vision model
27
+ for multi-modal processing. Defaults to -2.
28
+ grid_size (int, optional): The size of the grid for vision processing. Defaults to 32.
29
+ **kwargs: Additional keyword arguments to be passed to the parent PretrainedConfig.
30
+
31
+ """
32
+
33
+ model_type = 'MOB'
34
+
35
+ def __init__(
36
+ self,
37
+ vision_path,
38
+ mm_vision_select_layer=-2,
39
+ grid_size=32,
40
+ input_image_size=1024,
41
+ hidden_size='lazy_calculation',
42
+ image_size=1024,
43
+ freeze_backbones=None,
44
+ moe_version_type=None,
45
+ delay_load=False,
46
+ convnext_img_size=1024,
47
+ vision_tower_siglip_path=None,
48
+ vision_tower_convnext_path='convnext_xxlarge.clip_laion2b_soup',
49
+ normalize_type='siglip',
50
+ **kwargs,
51
+ ):
52
+ super().__init__(**kwargs)
53
+
54
+ self.normalize_type = normalize_type
55
+ self.vision_path = vision_path
56
+ self.mm_vision_select_layer = mm_vision_select_layer
57
+ self.grid_size = grid_size
58
+ self.input_image_size = input_image_size
59
+ self.image_size = image_size
60
+ self.hidden_size = hidden_size
61
+ self.freeze_backbones = freeze_backbones
62
+ self.moe_version_type = moe_version_type
63
+ self.delay_load = delay_load
64
+ self.convnext_img_size = convnext_img_size
65
+ # other args, to make it compatible with eagle-next
66
+ self.vision_tower_siglip_path = vision_tower_siglip_path
67
+ self.vision_tower_convnext_path = vision_tower_convnext_path
68
+ self.vision_tower = self.vision_path[4:] # remove `MOB:` prefix
69
+
70
+ # asserts
71
+ assert image_size == input_image_size, f"input_image_size ({input_image_size}) != image_size ({image_size})"
72
+
73
+ @classmethod
74
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> 'PretrainedConfig':
75
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
76
+
77
+ if 'vision_config' in config_dict:
78
+ config_dict = config_dict['vision_config']
79
+
80
+ if 'model_type' in config_dict and hasattr(cls, 'model_type') and config_dict['model_type'] != cls.model_type:
81
+ logger.warning(
82
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
83
+ f'{cls.model_type}. This is not supported for all configurations of models and can yield errors.'
84
+ )
85
+
86
+ return cls.from_dict(config_dict, **kwargs)
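
A small usage sketch for the "MOB" (multi-backbone) vision config, again assuming the hypothetical `eagle2` package layout; the `vision_path` value is illustrative, and the only requirement shown in the code above is the `MOB:` prefix that gets stripped off:

```python
from eagle2.configuration_multi_backbone_channel_concatentation_model import (
    MultiBackboneChannelConcatenationVisionModelConfig,
)

mob_cfg = MultiBackboneChannelConcatenationVisionModelConfig(
    vision_path="MOB:siglip;convnext",  # illustrative value; only the "MOB:" prefix is assumed by the class
    input_image_size=1024,
    image_size=1024,                    # must equal input_image_size (asserted in __init__)
)
print(mob_cfg.vision_tower)             # "siglip;convnext", i.e. vision_path with the "MOB:" prefix removed
```
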
configuration_qwen2.py ADDED
@@ -0,0 +1,149 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Qwen2 model configuration"""
16
+
17
+ from transformers.configuration_utils import PretrainedConfig
18
+ from transformers.utils import logging
19
+
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+ QWEN2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
24
+ "Qwen/Qwen2-7B-beta": "https://huggingface.co/Qwen/Qwen2-7B-beta/resolve/main/config.json",
25
+ }
26
+
27
+
28
+ class Qwen2Config(PretrainedConfig):
29
+ r"""
30
+ This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a
31
+ Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
32
+ with the defaults will yield a similar configuration to that of
33
+ Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta).
34
+
35
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
36
+ documentation from [`PretrainedConfig`] for more information.
37
+
38
+
39
+ Args:
40
+ vocab_size (`int`, *optional*, defaults to 151936):
41
+ Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the
42
+ `inputs_ids` passed when calling [`Qwen2Model`]
43
+ hidden_size (`int`, *optional*, defaults to 4096):
44
+ Dimension of the hidden representations.
45
+ intermediate_size (`int`, *optional*, defaults to 22016):
46
+ Dimension of the MLP representations.
47
+ num_hidden_layers (`int`, *optional*, defaults to 32):
48
+ Number of hidden layers in the Transformer encoder.
49
+ num_attention_heads (`int`, *optional*, defaults to 32):
50
+ Number of attention heads for each attention layer in the Transformer encoder.
51
+ num_key_value_heads (`int`, *optional*, defaults to 32):
52
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
53
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
54
+ `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
55
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
56
+ by meanpooling all the original heads within that group. For more details checkout [this
57
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
58
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
59
+ The non-linear activation function (function or string) in the decoder.
60
+ max_position_embeddings (`int`, *optional*, defaults to 32768):
61
+ The maximum sequence length that this model might ever be used with.
62
+ initializer_range (`float`, *optional*, defaults to 0.02):
63
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
64
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
65
+ The epsilon used by the rms normalization layers.
66
+ use_cache (`bool`, *optional*, defaults to `True`):
67
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
68
+ relevant if `config.is_decoder=True`.
69
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
70
+ Whether the model's input and output word embeddings should be tied.
71
+ rope_theta (`float`, *optional*, defaults to 10000.0):
72
+ The base period of the RoPE embeddings.
73
+ use_sliding_window (`bool`, *optional*, defaults to `False`):
74
+ Whether to use sliding window attention.
75
+ sliding_window (`int`, *optional*, defaults to 4096):
76
+ Sliding window attention (SWA) window size. If not specified, will default to `4096`.
77
+ max_window_layers (`int`, *optional*, defaults to 28):
78
+ The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
79
+ attention_dropout (`float`, *optional*, defaults to 0.0):
80
+ The dropout ratio for the attention probabilities.
81
+
82
+ ```python
83
+ >>> from transformers import Qwen2Model, Qwen2Config
84
+
85
+ >>> # Initializing a Qwen2 style configuration
86
+ >>> configuration = Qwen2Config()
87
+
88
+ >>> # Initializing a model from the Qwen2-7B style configuration
89
+ >>> model = Qwen2Model(configuration)
90
+
91
+ >>> # Accessing the model configuration
92
+ >>> configuration = model.config
93
+ ```"""
94
+
95
+ model_type = "qwen2"
96
+ keys_to_ignore_at_inference = ["past_key_values"]
97
+
98
+ def __init__(
99
+ self,
100
+ vocab_size=151936,
101
+ hidden_size=4096,
102
+ intermediate_size=22016,
103
+ num_hidden_layers=32,
104
+ num_attention_heads=32,
105
+ num_key_value_heads=32,
106
+ hidden_act="silu",
107
+ max_position_embeddings=32768,
108
+ initializer_range=0.02,
109
+ rms_norm_eps=1e-6,
110
+ use_cache=True,
111
+ tie_word_embeddings=False,
112
+ rope_theta=10000.0,
113
+ use_sliding_window=False,
114
+ sliding_window=4096,
115
+ max_window_layers=28,
116
+ attention_dropout=0.0,
117
+ attn_implementation='flash_attention_2',
118
+ **kwargs,
119
+ ):
120
+ self.vocab_size = vocab_size
121
+ self.max_position_embeddings = max_position_embeddings
122
+ self.hidden_size = hidden_size
123
+ self.intermediate_size = intermediate_size
124
+ self.num_hidden_layers = num_hidden_layers
125
+ self.num_attention_heads = num_attention_heads
126
+ self.use_sliding_window = use_sliding_window
127
+ self.sliding_window = sliding_window
128
+ self.max_window_layers = max_window_layers
129
+
130
+ self.attn_implementation = attn_implementation
131
+ if self.attn_implementation is None:
132
+ self.attn_implementation = "flash_attention_2"
133
+
134
+ # for backward compatibility
135
+ if num_key_value_heads is None:
136
+ num_key_value_heads = num_attention_heads
137
+
138
+ self.num_key_value_heads = num_key_value_heads
139
+ self.hidden_act = hidden_act
140
+ self.initializer_range = initializer_range
141
+ self.rms_norm_eps = rms_norm_eps
142
+ self.use_cache = use_cache
143
+ self.rope_theta = rope_theta
144
+ self.attention_dropout = attention_dropout
145
+
146
+ super().__init__(
147
+ tie_word_embeddings=tie_word_embeddings,
148
+ **kwargs,
149
+ )
configuration_siglip.py ADDED
@@ -0,0 +1,302 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Siglip model configuration"""
16
+
17
+ import os
18
+ from typing import Union
19
+
20
+ from transformers.configuration_utils import PretrainedConfig
21
+ from transformers.utils import logging
22
+
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+ SIGLIP_PRETRAINED_CONFIG_ARCHIVE_MAP = {
27
+ "google/siglip-base-patch16-224": "https://huggingface.co/google/siglip-base-patch16-224/resolve/main/config.json",
28
+ }
29
+
30
+
31
+ class SiglipTextConfig(PretrainedConfig):
32
+ r"""
33
+ This is the configuration class to store the configuration of a [`SiglipTextModel`]. It is used to instantiate a
34
+ Siglip text encoder according to the specified arguments, defining the model architecture. Instantiating a
35
+ configuration with the defaults will yield a similar configuration to that of the text encoder of the Siglip
36
+ [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
37
+
38
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
39
+ documentation from [`PretrainedConfig`] for more information.
40
+
41
+ Args:
42
+ vocab_size (`int`, *optional*, defaults to 32000):
43
+ Vocabulary size of the Siglip text model. Defines the number of different tokens that can be represented by
44
+ the `inputs_ids` passed when calling [`SiglipModel`].
45
+ hidden_size (`int`, *optional*, defaults to 768):
46
+ Dimensionality of the encoder layers and the pooler layer.
47
+ intermediate_size (`int`, *optional*, defaults to 3072):
48
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
49
+ num_hidden_layers (`int`, *optional*, defaults to 12):
50
+ Number of hidden layers in the Transformer encoder.
51
+ num_attention_heads (`int`, *optional*, defaults to 12):
52
+ Number of attention heads for each attention layer in the Transformer encoder.
53
+ max_position_embeddings (`int`, *optional*, defaults to 64):
54
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
55
+ just in case (e.g., 512 or 1024 or 2048).
56
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
57
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
58
+ `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
59
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
60
+ The epsilon used by the layer normalization layers.
61
+ attention_dropout (`float`, *optional*, defaults to 0.0):
62
+ The dropout ratio for the attention probabilities.
63
+ pad_token_id (`int`, *optional*, defaults to 1):
64
+ The id of the padding token in the vocabulary.
65
+ bos_token_id (`int`, *optional*, defaults to 49406):
66
+ The id of the beginning-of-sequence token in the vocabulary.
67
+ eos_token_id (`int`, *optional*, defaults to 49407):
68
+ The id of the end-of-sequence token in the vocabulary.
69
+
70
+ Example:
71
+
72
+ ```python
73
+ >>> from transformers import SiglipTextConfig, SiglipTextModel
74
+
75
+ >>> # Initializing a SiglipTextConfig with google/siglip-base-patch16-224 style configuration
76
+ >>> configuration = SiglipTextConfig()
77
+
78
+ >>> # Initializing a SiglipTextModel (with random weights) from the google/siglip-base-patch16-224 style configuration
79
+ >>> model = SiglipTextModel(configuration)
80
+
81
+ >>> # Accessing the model configuration
82
+ >>> configuration = model.config
83
+ ```"""
84
+
85
+ model_type = "siglip_text_model"
86
+
87
+ def __init__(
88
+ self,
89
+ vocab_size=32000,
90
+ hidden_size=768,
91
+ intermediate_size=3072,
92
+ num_hidden_layers=12,
93
+ num_attention_heads=12,
94
+ max_position_embeddings=64,
95
+ hidden_act="gelu_pytorch_tanh",
96
+ layer_norm_eps=1e-6,
97
+ attention_dropout=0.0,
98
+ # This differs from `CLIPTokenizer`'s default and from openai/siglip
99
+ # See https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538
100
+ pad_token_id=1,
101
+ bos_token_id=49406,
102
+ eos_token_id=49407,
103
+ **kwargs,
104
+ ):
105
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
106
+
107
+ self.vocab_size = vocab_size
108
+ self.hidden_size = hidden_size
109
+ self.intermediate_size = intermediate_size
110
+ self.num_hidden_layers = num_hidden_layers
111
+ self.num_attention_heads = num_attention_heads
112
+ self.max_position_embeddings = max_position_embeddings
113
+ self.layer_norm_eps = layer_norm_eps
114
+ self.hidden_act = hidden_act
115
+ self.attention_dropout = attention_dropout
116
+
117
+ @classmethod
118
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
119
+ cls._set_token_in_kwargs(kwargs)
120
+
121
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
122
+
123
+ # get the text config dict if we are loading from SiglipConfig
124
+ if config_dict.get("model_type") == "siglip":
125
+ config_dict = config_dict["text_config"]
126
+
127
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
128
+ logger.warning(
129
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
130
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
131
+ )
132
+
133
+ return cls.from_dict(config_dict, **kwargs)
134
+
135
+
136
+ class SiglipVisionConfig(PretrainedConfig):
137
+ r"""
138
+ This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a
139
+ Siglip vision encoder according to the specified arguments, defining the model architecture. Instantiating a
140
+ configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip
141
+ [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
142
+
143
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
144
+ documentation from [`PretrainedConfig`] for more information.
145
+
146
+ Args:
147
+ hidden_size (`int`, *optional*, defaults to 768):
148
+ Dimensionality of the encoder layers and the pooler layer.
149
+ intermediate_size (`int`, *optional*, defaults to 3072):
150
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
151
+ num_hidden_layers (`int`, *optional*, defaults to 12):
152
+ Number of hidden layers in the Transformer encoder.
153
+ num_attention_heads (`int`, *optional*, defaults to 12):
154
+ Number of attention heads for each attention layer in the Transformer encoder.
155
+ num_channels (`int`, *optional*, defaults to 3):
156
+ Number of channels in the input images.
157
+ image_size (`int`, *optional*, defaults to 224):
158
+ The size (resolution) of each image.
159
+ patch_size (`int`, *optional*, defaults to 16):
160
+ The size (resolution) of each patch.
161
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
162
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
163
+ `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
164
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
165
+ The epsilon used by the layer normalization layers.
166
+ attention_dropout (`float`, *optional*, defaults to 0.0):
167
+ The dropout ratio for the attention probabilities.
168
+
169
+ Example:
170
+
171
+ ```python
172
+ >>> from transformers import SiglipVisionConfig, SiglipVisionModel
173
+
174
+ >>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224 style configuration
175
+ >>> configuration = SiglipVisionConfig()
176
+
177
+ >>> # Initializing a SiglipVisionModel (with random weights) from the google/siglip-base-patch16-224 style configuration
178
+ >>> model = SiglipVisionModel(configuration)
179
+
180
+ >>> # Accessing the model configuration
181
+ >>> configuration = model.config
182
+ ```"""
183
+
184
+ model_type = "siglip_vision_model"
185
+
186
+ def __init__(
187
+ self,
188
+ hidden_size=768,
189
+ intermediate_size=3072,
190
+ num_hidden_layers=12,
191
+ num_attention_heads=12,
192
+ num_channels=3,
193
+ image_size=224,
194
+ patch_size=16,
195
+ hidden_act="gelu_pytorch_tanh",
196
+ layer_norm_eps=1e-6,
197
+ attention_dropout=0.0,
198
+ **kwargs,
199
+ ):
200
+ super().__init__(**kwargs)
201
+
202
+ self.hidden_size = hidden_size
203
+ self.intermediate_size = intermediate_size
204
+ self.num_hidden_layers = num_hidden_layers
205
+ self.num_attention_heads = num_attention_heads
206
+ self.num_channels = num_channels
207
+ self.patch_size = patch_size
208
+ self.image_size = image_size
209
+ self.attention_dropout = attention_dropout
210
+ self.layer_norm_eps = layer_norm_eps
211
+ self.hidden_act = hidden_act
212
+
213
+ @classmethod
214
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
215
+ cls._set_token_in_kwargs(kwargs)
216
+
217
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
218
+
219
+ # get the vision config dict if we are loading from SiglipConfig
220
+ if config_dict.get("model_type") == "siglip":
221
+ config_dict = config_dict["vision_config"]
222
+
223
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
224
+ logger.warning(
225
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
226
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
227
+ )
228
+
229
+ return cls.from_dict(config_dict, **kwargs)
230
+
231
+
232
+ class SiglipConfig(PretrainedConfig):
233
+ r"""
234
+ [`SiglipConfig`] is the configuration class to store the configuration of a [`SiglipModel`]. It is used to
235
+ instantiate a Siglip model according to the specified arguments, defining the text model and vision model configs.
236
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the Siglip
237
+ [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
238
+
239
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
240
+ documentation from [`PretrainedConfig`] for more information.
241
+
242
+ Args:
243
+ text_config (`dict`, *optional*):
244
+ Dictionary of configuration options used to initialize [`SiglipTextConfig`].
245
+ vision_config (`dict`, *optional*):
246
+ Dictionary of configuration options used to initialize [`SiglipVisionConfig`].
247
+ kwargs (*optional*):
248
+ Dictionary of keyword arguments.
249
+
250
+ Example:
251
+
252
+ ```python
253
+ >>> from transformers import SiglipConfig, SiglipModel
254
+
255
+ >>> # Initializing a SiglipConfig with google/siglip-base-patch16-224 style configuration
256
+ >>> configuration = SiglipConfig()
257
+
258
+ >>> # Initializing a SiglipModel (with random weights) from the google/siglip-base-patch16-224 style configuration
259
+ >>> model = SiglipModel(configuration)
260
+
261
+ >>> # Accessing the model configuration
262
+ >>> configuration = model.config
263
+
264
+ >>> # We can also initialize a SiglipConfig from a SiglipTextConfig and a SiglipVisionConfig
265
+ >>> from transformers import SiglipTextConfig, SiglipVisionConfig
266
+
267
+ >>> # Initializing a SiglipText and SiglipVision configuration
268
+ >>> config_text = SiglipTextConfig()
269
+ >>> config_vision = SiglipVisionConfig()
270
+
271
+ >>> config = SiglipConfig.from_text_vision_configs(config_text, config_vision)
272
+ ```"""
273
+
274
+ model_type = "siglip"
275
+
276
+ def __init__(self, text_config=None, vision_config=None, **kwargs):
277
+ super().__init__(**kwargs)
278
+
279
+ if text_config is None:
280
+ text_config = {}
281
+ logger.info("`text_config` is `None`. Initializing the `SiglipTextConfig` with default values.")
282
+
283
+ if vision_config is None:
284
+ vision_config = {}
285
+ logger.info("`vision_config` is `None`. initializing the `SiglipVisionConfig` with default values.")
286
+
287
+ self.text_config = SiglipTextConfig(**text_config)
288
+ self.vision_config = SiglipVisionConfig(**vision_config)
289
+
290
+ self.initializer_factor = 1.0
291
+
292
+ @classmethod
293
+ def from_text_vision_configs(cls, text_config: SiglipTextConfig, vision_config: SiglipVisionConfig, **kwargs):
294
+ r"""
295
+ Instantiate a [`SiglipConfig`] (or a derived class) from siglip text model configuration and siglip vision
296
+ model configuration.
297
+
298
+ Returns:
299
+ [`SiglipConfig`]: An instance of a configuration object
300
+ """
301
+
302
+ return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
conversation.py ADDED
@@ -0,0 +1,434 @@
1
+ """
2
+ Conversation prompt templates.
3
+
4
+ We kindly request that you import fastchat instead of copying this file if you wish to use it.
5
+ If you have changes in mind, please contribute back so the community can benefit collectively and continue to maintain these valuable templates.
6
+ """
7
+
8
+ import dataclasses
9
+ from enum import IntEnum, auto
10
+ from typing import Any, Dict, List, Tuple, Union
11
+
12
+
13
+ class SeparatorStyle(IntEnum):
14
+ """Separator styles."""
15
+
16
+ ADD_COLON_SINGLE = auto()
17
+ ADD_COLON_TWO = auto()
18
+ ADD_COLON_SPACE_SINGLE = auto()
19
+ NO_COLON_SINGLE = auto()
20
+ NO_COLON_TWO = auto()
21
+ ADD_NEW_LINE_SINGLE = auto()
22
+ LLAMA2 = auto()
23
+ CHATGLM = auto()
24
+ CHATML = auto()
25
+ CHATINTERN = auto()
26
+ DOLLY = auto()
27
+ RWKV = auto()
28
+ PHOENIX = auto()
29
+ ROBIN = auto()
30
+ FALCON_CHAT = auto()
31
+ CHATGLM3 = auto()
32
+ INTERNVL_ZH = auto()
33
+ MPT = auto()
34
+ LLAMA3 = auto()
35
+
36
+
37
+ @dataclasses.dataclass
38
+ class Conversation:
39
+ """A class that manages prompt templates and keeps all conversation history."""
40
+
41
+ # The name of this template
42
+ name: str
43
+ # The template of the system prompt
44
+ system_template: str = '{system_message}'
45
+ # The system message
46
+ system_message: str = ''
47
+ # The names of two roles
48
+ roles: Tuple[str] = ('USER', 'ASSISTANT')
49
+ # All messages. Each item is (role, message).
50
+ messages: List[List[str]] = ()
51
+ # The number of few shot examples
52
+ offset: int = 0
53
+ # The separator style and configurations
54
+ sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE
55
+ sep: str = '\n'
56
+ sep2: str = None
57
+ # Stop criteria (the default one is EOS token)
58
+ stop_str: Union[str, List[str]] = None
59
+ # Stops generation if meeting any token in this list
60
+ stop_token_ids: List[int] = None
61
+
62
+ def get_prompt(self) -> str:
63
+ """Get the prompt for generation."""
64
+ system_prompt = self.system_template.format(system_message=self.system_message)
65
+ if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
66
+ ret = system_prompt + self.sep
67
+ for role, message in self.messages:
68
+ if message:
69
+ ret += role + ': ' + message + self.sep
70
+ else:
71
+ ret += role + ':'
72
+ return ret
73
+ elif self.sep_style == SeparatorStyle.ADD_COLON_TWO:
74
+ seps = [self.sep, self.sep2]
75
+ ret = system_prompt + seps[0]
76
+ for i, (role, message) in enumerate(self.messages):
77
+ if message:
78
+ ret += role + ': ' + message + seps[i % 2]
79
+ else:
80
+ ret += role + ':'
81
+ return ret
82
+ elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
83
+ ret = system_prompt + self.sep
84
+ for role, message in self.messages:
85
+ if message:
86
+ ret += role + ': ' + message + self.sep
87
+ else:
88
+ ret += role + ': '  # must end with a space
89
+ return ret
90
+ elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
91
+ ret = '' if system_prompt == '' else system_prompt + self.sep
92
+ for role, message in self.messages:
93
+ if message:
94
+ ret += role + '\n' + message + self.sep
95
+ else:
96
+ ret += role + '\n'
97
+ return ret
98
+ elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
99
+ ret = system_prompt
100
+ for role, message in self.messages:
101
+ if message:
102
+ ret += role + message + self.sep
103
+ else:
104
+ ret += role
105
+ return ret
106
+ elif self.sep_style == SeparatorStyle.NO_COLON_TWO:
107
+ seps = [self.sep, self.sep2]
108
+ ret = system_prompt
109
+ for i, (role, message) in enumerate(self.messages):
110
+ if message:
111
+ ret += role + message + seps[i % 2]
112
+ else:
113
+ ret += role
114
+ return ret
115
+ elif self.sep_style == SeparatorStyle.RWKV:
116
+ ret = system_prompt
117
+ for i, (role, message) in enumerate(self.messages):
118
+ if message:
119
+ ret += (
120
+ role
121
+ + ': '
122
+ + message.replace('\r\n', '\n').replace('\n\n', '\n')
123
+ )
124
+ ret += '\n\n'
125
+ else:
126
+ ret += role + ':'
127
+ return ret
128
+ elif self.sep_style == SeparatorStyle.LLAMA2:
129
+ seps = [self.sep, self.sep2]
130
+ if self.system_message:
131
+ ret = system_prompt
132
+ else:
133
+ ret = '[INST] '
134
+ for i, (role, message) in enumerate(self.messages):
135
+ tag = self.roles[i % 2]
136
+ if message:
137
+ if i == 0:
138
+ ret += message + ' '
139
+ else:
140
+ ret += tag + ' ' + message + seps[i % 2]
141
+ else:
142
+ ret += tag
143
+ return ret
144
+ elif self.sep_style == SeparatorStyle.CHATGLM:
145
+ # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
146
+ # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
147
+ round_add_n = 1 if self.name == 'chatglm2' else 0
148
+ if system_prompt:
149
+ ret = system_prompt + self.sep
150
+ else:
151
+ ret = ''
152
+
153
+ for i, (role, message) in enumerate(self.messages):
154
+ if i % 2 == 0:
155
+ ret += f'[Round {i//2 + round_add_n}]{self.sep}'
156
+
157
+ if message:
158
+ ret += f'{role}:{message}{self.sep}'
159
+ else:
160
+ ret += f'{role}:'
161
+ return ret
162
+ elif self.sep_style == SeparatorStyle.CHATML:
163
+ ret = '' if system_prompt == '' else system_prompt + self.sep + '\n'
164
+ for role, message in self.messages:
165
+ if message:
166
+ ret += role + '\n' + message + self.sep + '\n'
167
+ else:
168
+ ret += role + '\n'
169
+ return ret
170
+ elif self.sep_style == SeparatorStyle.CHATGLM3:
171
+ ret = ''
172
+ if self.system_message:
173
+ ret += system_prompt
174
+ for role, message in self.messages:
175
+ if message:
176
+ ret += role + '\n' + ' ' + message
177
+ else:
178
+ ret += role
179
+ return ret
180
+ elif self.sep_style == SeparatorStyle.CHATINTERN:
181
+ # source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771
182
+ seps = [self.sep, self.sep2]
183
+ ret = system_prompt
184
+ for i, (role, message) in enumerate(self.messages):
185
+ # if i % 2 == 0:
186
+ # ret += "<s>"
187
+ if message:
188
+ ret += role + ':' + message + seps[i % 2] + '\n'
189
+ else:
190
+ ret += role + ':'
191
+ return ret
192
+ elif self.sep_style == SeparatorStyle.DOLLY:
193
+ seps = [self.sep, self.sep2]
194
+ ret = system_prompt
195
+ for i, (role, message) in enumerate(self.messages):
196
+ if message:
197
+ ret += role + ':\n' + message + seps[i % 2]
198
+ if i % 2 == 1:
199
+ ret += '\n\n'
200
+ else:
201
+ ret += role + ':\n'
202
+ return ret
203
+ elif self.sep_style == SeparatorStyle.PHOENIX:
204
+ ret = system_prompt
205
+ for role, message in self.messages:
206
+ if message:
207
+ ret += role + ': ' + '<s>' + message + '</s>'
208
+ else:
209
+ ret += role + ': ' + '<s>'
210
+ return ret
211
+ elif self.sep_style == SeparatorStyle.ROBIN:
212
+ ret = system_prompt + self.sep
213
+ for role, message in self.messages:
214
+ if message:
215
+ ret += role + ':\n' + message + self.sep
216
+ else:
217
+ ret += role + ':\n'
218
+ return ret
219
+ elif self.sep_style == SeparatorStyle.FALCON_CHAT:
220
+ ret = ''
221
+ if self.system_message:
222
+ ret += system_prompt + self.sep
223
+ for role, message in self.messages:
224
+ if message:
225
+ ret += role + ': ' + message + self.sep
226
+ else:
227
+ ret += role + ':'
228
+
229
+ return ret
230
+ elif self.sep_style == SeparatorStyle.INTERNVL_ZH:
231
+ seps = [self.sep, self.sep2]
232
+ ret = self.system_message + seps[0]
233
+ for i, (role, message) in enumerate(self.messages):
234
+ if message:
235
+ ret += role + ': ' + message + seps[i % 2]
236
+ else:
237
+ ret += role + ':'
238
+ return ret
239
+ elif self.sep_style == SeparatorStyle.MPT:
240
+ ret = system_prompt + self.sep
241
+ for role, message in self.messages:
242
+ if message:
243
+ if type(message) is tuple:
244
+ message, _, _ = message
245
+ ret += role + message + self.sep
246
+ else:
247
+ ret += role
248
+ return ret
249
+ elif self.sep_style == SeparatorStyle.LLAMA3:
250
+ ret = system_prompt + self.sep
251
+ for role, message in self.messages:
252
+ if message:
253
+ if type(message) is tuple:
254
+ message, _, _ = message
255
+ ret += role + message + self.sep
256
+ else:
257
+ ret += role
258
+ return ret
259
+ else:
260
+ raise ValueError(f'Invalid style: {self.sep_style}')
261
+
262
+ def set_system_message(self, system_message: str):
263
+ """Set the system message."""
264
+ self.system_message = system_message
265
+
266
+ def append_message(self, role: str, message: str):
267
+ """Append a new message."""
268
+ self.messages.append([role, message])
269
+
270
+ def update_last_message(self, message: str):
271
+ """Update the last output.
272
+
273
+ The last message is typically set to be None when constructing the prompt,
274
+ so we need to update it in-place after getting the response from a model.
275
+ """
276
+ self.messages[-1][1] = message
277
+
278
+ def to_gradio_chatbot(self):
279
+ """Convert the conversation to gradio chatbot format."""
280
+ ret = []
281
+ for i, (role, msg) in enumerate(self.messages[self.offset :]):
282
+ if i % 2 == 0:
283
+ ret.append([msg, None])
284
+ else:
285
+ ret[-1][-1] = msg
286
+ return ret
287
+
288
+ def to_openai_api_messages(self):
289
+ """Convert the conversation to OpenAI chat completion format."""
290
+ ret = [{'role': 'system', 'content': self.system_message}]
291
+
292
+ for i, (_, msg) in enumerate(self.messages[self.offset :]):
293
+ if i % 2 == 0:
294
+ ret.append({'role': 'user', 'content': msg})
295
+ else:
296
+ if msg is not None:
297
+ ret.append({'role': 'assistant', 'content': msg})
298
+ return ret
299
+
300
+ def copy(self):
301
+ return Conversation(
302
+ name=self.name,
303
+ system_template=self.system_template,
304
+ system_message=self.system_message,
305
+ roles=self.roles,
306
+ messages=[[x, y] for x, y in self.messages],
307
+ offset=self.offset,
308
+ sep_style=self.sep_style,
309
+ sep=self.sep,
310
+ sep2=self.sep2,
311
+ stop_str=self.stop_str,
312
+ stop_token_ids=self.stop_token_ids,
313
+ )
314
+
315
+ def dict(self):
316
+ return {
317
+ 'template_name': self.name,
318
+ 'system_message': self.system_message,
319
+ 'roles': self.roles,
320
+ 'messages': self.messages,
321
+ 'offset': self.offset,
322
+ }
323
+
324
+
325
+ # A global registry for all conversation templates
326
+ conv_templates: Dict[str, Conversation] = {}
327
+
328
+
329
+ def register_conv_template(template: Conversation, override: bool = False):
330
+ """Register a new conversation template."""
331
+ if not override:
332
+ assert (
333
+ template.name not in conv_templates
334
+ ), f'{template.name} has been registered.'
335
+
336
+ conv_templates[template.name] = template
337
+
338
+
339
+ def get_conv_template(name: str) -> Conversation:
340
+ """Get a conversation template."""
341
+ return conv_templates[name].copy()
342
+
343
+
344
+ # Note that for inference, using the Hermes-2 and internlm2-chat templates is equivalent.
345
+ register_conv_template(
346
+ Conversation(
347
+ name='Hermes-2',
348
+ system_template='<|im_start|>system\n{system_message}',
349
+ # note: The new system prompt was not used here to avoid changes in benchmark performance.
350
+ # system_message='我是书生·万象,英文名是InternVL,是由上海人工智能实验室及多家合作单位联合开发的多模态大语言模型。人工智能实验室致力于原始技术创新,开源开放,共享共创,推动科技进步和产业发展。',
351
+ system_message='你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。',
352
+ roles=('<|im_start|>user\n', '<|im_start|>assistant\n'),
353
+ sep_style=SeparatorStyle.MPT,
354
+ sep='<|im_end|>',
355
+ stop_token_ids=[
356
+ 2,
357
+ 6,
358
+ 7,
359
+ 8,
360
+ ],
361
+ stop_str='<|endoftext|>',
362
+ )
363
+ )
364
+
365
+
366
+ register_conv_template(
367
+ Conversation(
368
+ name='internlm2-chat',
369
+ system_template='<|im_start|>system\n{system_message}',
370
+ # note: The new system prompt was not used here to avoid changes in benchmark performance.
371
+ # system_message='我是书生·万象,英文名是InternVL,是由上海人工智能实验室及多家合作单位联合开发的多模态大语言模型。人工智能实验室致力于原始技术创新,开源开放,共享共创,推动科技进步和产业发展。',
372
+ system_message='你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。',
373
+ roles=('<|im_start|>user\n', '<|im_start|>assistant\n'),
374
+ sep_style=SeparatorStyle.MPT,
375
+ sep='<|im_end|>',
376
+ stop_token_ids=[
377
+ 2,
378
+ 92543,
379
+ 92542
380
+ ]
381
+ )
382
+ )
383
+
384
+
385
+ register_conv_template(
386
+ Conversation(
387
+ name='phi3-chat',
388
+ system_template='<|system|>\n{system_message}',
389
+ # note: The new system prompt was not used here to avoid changes in benchmark performance.
390
+ # system_message='我是书生·万象,英文名是InternVL,是由上海人工智能实验室及多家合作单位联合开发的多模态大语言模型。人工智能实验室致力于原始技术创新,开源开放,共享共创,推动科技进步和产业发展。',
391
+ system_message='你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。',
392
+ roles=('<|user|>\n', '<|assistant|>\n'),
393
+ sep_style=SeparatorStyle.MPT,
394
+ sep='<|end|>',
395
+ stop_token_ids=[
396
+ 2,
397
+ 32000,
398
+ 32007
399
+ ]
400
+ )
401
+ )
402
+ register_conv_template(
403
+ Conversation(
404
+ name='llama3-chat',
405
+ system_template='<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_message}',
406
+ system_message='You are an AI assistant whose name is Eagle-Next.',
407
+ roles=('<|start_header_id|>user<|end_header_id|>\n\n', '<|start_header_id|>assistant<|end_header_id|>\n\n'),
408
+ sep_style=SeparatorStyle.LLAMA3,
409
+ sep='<|eot_id|>',
410
+ stop_token_ids=[
411
+ 128259,
412
+ 128001
413
+ ]
414
+ )
415
+ )
416
+
417
+ # Qwen-chat default template
418
+ # source: https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/qwen_generation_utils.py#L130
419
+ register_conv_template(
420
+ Conversation(
421
+ name='qwen2-chat',
422
+ system_template='<|im_start|>system\n{system_message}',
423
+ system_message='You are a helpful assistant.',
424
+ roles=('<|im_start|>user', '<|im_start|>assistant'),
425
+ sep_style=SeparatorStyle.CHATML,
426
+ sep='<|im_end|>',
427
+ stop_token_ids=[
428
+ 151643,
429
+ 151644,
430
+ 151645,
431
+ ], # "<|endoftext|>", "<|im_start|>", "<|im_end|>"
432
+ stop_str='<|endoftext|>',
433
+ )
434
+ )
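# Illustrative sketch (not part of the committed file): how a registered template is
# typically consumed. get_conv_template() returns a copy, so appending messages does not
# mutate the registered prototype; the template and role names below come from this file.
conv = get_conv_template('llama3-chat')
conv.append_message(conv.roles[0], 'Describe this image.')
conv.append_message(conv.roles[1], None)  # None leaves the assistant turn open
prompt = conv.get_prompt()
# With SeparatorStyle.LLAMA3, `prompt` is the system block and the user turn, each closed
# by '<|eot_id|>', followed by an open assistant header for the model to complete.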
convnext.py ADDED
@@ -0,0 +1,572 @@
1
+ """ ConvNeXt
2
+
3
+ Papers:
4
+ * `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf
5
+ @Article{liu2022convnet,
6
+ author = {Zhuang Liu and Hanzi Mao and Chao-Yuan Wu and Christoph Feichtenhofer and Trevor Darrell and Saining Xie},
7
+ title = {A ConvNet for the 2020s},
8
+ journal = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
9
+ year = {2022},
10
+ }
11
+
12
+ * `ConvNeXt-V2 - Co-designing and Scaling ConvNets with Masked Autoencoders` - https://arxiv.org/abs/2301.00808
13
+ @article{Woo2023ConvNeXtV2,
14
+ title={ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders},
15
+ author={Sanghyun Woo, Shoubhik Debnath, Ronghang Hu, Xinlei Chen, Zhuang Liu, In So Kweon and Saining Xie},
16
+ year={2023},
17
+ journal={arXiv preprint arXiv:2301.00808},
18
+ }
19
+
20
+ Original code and weights from:
21
+ * https://github.com/facebookresearch/ConvNeXt, original copyright below
22
+ * https://github.com/facebookresearch/ConvNeXt-V2, original copyright below
23
+
24
+ Model defs atto, femto, pico, nano and _ols / _hnf variants are timm originals.
25
+
26
+ Modifications and additions for timm hacked together by / Copyright 2022, Ross Wightman
27
+ """
28
+ # ConvNeXt
29
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
30
+ # All rights reserved.
31
+ # This source code is licensed under the MIT license
32
+
33
+ # ConvNeXt-V2
34
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
35
+ # All rights reserved.
36
+ # This source code is licensed under the license found in the
37
+ # LICENSE file in the root directory of this source tree (Attribution-NonCommercial 4.0 International (CC BY-NC 4.0))
38
+ # No code was used directly from ConvNeXt-V2, however the weights are CC BY-NC 4.0 so beware if using commercially.
39
+
40
+ from collections import OrderedDict
41
+ from functools import partial
42
+ from typing import Callable, Optional, Tuple, Union
43
+
44
+ import torch
45
+ import torch.nn as nn
46
+
47
+ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
48
+ from timm.layers import trunc_normal_, AvgPool2dSame, DropPath, Mlp, GlobalResponseNormMlp, \
49
+ LayerNorm2d, LayerNorm, create_conv2d, get_act_layer, make_divisible, to_ntuple
50
+ from timm.layers import NormMlpClassifierHead, ClassifierHead
51
+ from timm.models._builder import build_model_with_cfg
52
+ from timm.models._manipulate import named_apply, checkpoint_seq
53
+ from timm.models._registry import generate_default_cfgs, register_model, register_model_deprecations
54
+
55
+ __all__ = ['ConvNeXt'] # model_registry will add each entrypoint fn to this
56
+
57
+
58
+ class Downsample(nn.Module):
59
+
60
+ def __init__(self, in_chs, out_chs, stride=1, dilation=1):
61
+ super().__init__()
62
+ avg_stride = stride if dilation == 1 else 1
63
+ if stride > 1 or dilation > 1:
64
+ avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
65
+ self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
66
+ else:
67
+ self.pool = nn.Identity()
68
+
69
+ if in_chs != out_chs:
70
+ self.conv = create_conv2d(in_chs, out_chs, 1, stride=1)
71
+ else:
72
+ self.conv = nn.Identity()
73
+
74
+ def forward(self, x):
75
+ x = self.pool(x)
76
+ x = self.conv(x)
77
+ return x
78
+
79
+
80
+ class ConvNeXtBlock(nn.Module):
81
+ """ ConvNeXt Block
82
+ There are two equivalent implementations:
83
+ (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
84
+ (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
85
+
86
+ Unlike the official impl, this one allows choice of 1 or 2, 1x1 conv can be faster with appropriate
87
+ choice of LayerNorm impl, however as model size increases the tradeoffs appear to change and nn.Linear
88
+ is a better choice. This was observed with PyTorch 1.10 on 3090 GPU, it could change over time & w/ different HW.
89
+ """
90
+
91
+ def __init__(
92
+ self,
93
+ in_chs: int,
94
+ out_chs: Optional[int] = None,
95
+ kernel_size: int = 7,
96
+ stride: int = 1,
97
+ dilation: Union[int, Tuple[int, int]] = (1, 1),
98
+ mlp_ratio: float = 4,
99
+ conv_mlp: bool = False,
100
+ conv_bias: bool = True,
101
+ use_grn: bool = False,
102
+ ls_init_value: Optional[float] = 1e-6,
103
+ act_layer: Union[str, Callable] = 'gelu',
104
+ norm_layer: Optional[Callable] = None,
105
+ drop_path: float = 0.,
106
+ ):
107
+ """
108
+
109
+ Args:
110
+ in_chs: Block input channels.
111
+ out_chs: Block output channels (same as in_chs if None).
112
+ kernel_size: Depthwise convolution kernel size.
113
+ stride: Stride of depthwise convolution.
114
+ dilation: Tuple specifying input and output dilation of block.
115
+ mlp_ratio: MLP expansion ratio.
116
+ conv_mlp: Use 1x1 convolutions for MLP and a NCHW compatible norm layer if True.
117
+ conv_bias: Apply bias for all convolution (linear) layers.
118
+ use_grn: Use GlobalResponseNorm in MLP (from ConvNeXt-V2)
119
+ ls_init_value: Layer-scale init values, layer-scale applied if not None.
120
+ act_layer: Activation layer.
121
+ norm_layer: Normalization layer (defaults to LN if not specified).
122
+ drop_path: Stochastic depth probability.
123
+ """
124
+ super().__init__()
125
+ out_chs = out_chs or in_chs
126
+ dilation = to_ntuple(2)(dilation)
127
+ act_layer = get_act_layer(act_layer)
128
+ if not norm_layer:
129
+ norm_layer = LayerNorm2d if conv_mlp else LayerNorm
130
+ mlp_layer = partial(GlobalResponseNormMlp if use_grn else Mlp, use_conv=conv_mlp)
131
+ self.use_conv_mlp = conv_mlp
132
+ self.conv_dw = create_conv2d(
133
+ in_chs,
134
+ out_chs,
135
+ kernel_size=kernel_size,
136
+ stride=stride,
137
+ dilation=dilation[0],
138
+ depthwise=True,
139
+ bias=conv_bias,
140
+ )
141
+ self.norm = norm_layer(out_chs)
142
+ self.mlp = mlp_layer(out_chs, int(mlp_ratio * out_chs), act_layer=act_layer)
143
+ self.weight = nn.Parameter(ls_init_value * torch.ones(out_chs)) if ls_init_value is not None else None
144
+ if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
145
+ self.shortcut = Downsample(in_chs, out_chs, stride=stride, dilation=dilation[0])
146
+ else:
147
+ self.shortcut = nn.Identity()
148
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
149
+
150
+ def forward(self, x):
151
+ shortcut = x
152
+ x = self.conv_dw(x)
153
+ if self.use_conv_mlp:
154
+ x = self.norm(x)
155
+ x = self.mlp(x)
156
+ else:
157
+ x = x.permute(0, 2, 3, 1)
158
+ x = self.norm(x)
159
+ x = self.mlp(x)
160
+ x = x.permute(0, 3, 1, 2)
161
+ if self.weight is not None:
162
+ x = x.mul(self.weight.reshape(1, -1, 1, 1))
163
+
164
+ x = self.drop_path(x) + self.shortcut(shortcut)
165
+ return x
166
+
167
+
168
+ class ConvNeXtStage(nn.Module):
169
+
170
+ def __init__(
171
+ self,
172
+ in_chs,
173
+ out_chs,
174
+ kernel_size=7,
175
+ stride=2,
176
+ depth=2,
177
+ dilation=(1, 1),
178
+ drop_path_rates=None,
179
+ ls_init_value=1.0,
180
+ conv_mlp=False,
181
+ conv_bias=True,
182
+ use_grn=False,
183
+ act_layer='gelu',
184
+ norm_layer=None,
185
+ norm_layer_cl=None
186
+ ):
187
+ super().__init__()
188
+ self.grad_checkpointing = False
189
+
190
+ if in_chs != out_chs or stride > 1 or dilation[0] != dilation[1]:
191
+ ds_ks = 2 if stride > 1 or dilation[0] != dilation[1] else 1
192
+ pad = 'same' if dilation[1] > 1 else 0 # same padding needed if dilation used
193
+ self.downsample = nn.Sequential(
194
+ norm_layer(in_chs),
195
+ create_conv2d(
196
+ in_chs,
197
+ out_chs,
198
+ kernel_size=ds_ks,
199
+ stride=stride,
200
+ dilation=dilation[0],
201
+ padding=pad,
202
+ bias=conv_bias,
203
+ ),
204
+ )
205
+ in_chs = out_chs
206
+ else:
207
+ self.downsample = nn.Identity()
208
+
209
+ drop_path_rates = drop_path_rates or [0.] * depth
210
+ stage_blocks = []
211
+ for i in range(depth):
212
+ stage_blocks.append(ConvNeXtBlock(
213
+ in_chs=in_chs,
214
+ out_chs=out_chs,
215
+ kernel_size=kernel_size,
216
+ dilation=dilation[1],
217
+ drop_path=drop_path_rates[i],
218
+ ls_init_value=ls_init_value,
219
+ conv_mlp=conv_mlp,
220
+ conv_bias=conv_bias,
221
+ use_grn=use_grn,
222
+ act_layer=act_layer,
223
+ norm_layer=norm_layer if conv_mlp else norm_layer_cl,
224
+ ))
225
+ in_chs = out_chs
226
+ self.blocks = nn.Sequential(*stage_blocks)
227
+
228
+ def forward(self, x):
229
+ x = self.downsample(x)
230
+ if self.grad_checkpointing and not torch.jit.is_scripting():
231
+ x = checkpoint_seq(self.blocks, x)
232
+ else:
233
+ x = self.blocks(x)
234
+ return x
235
+
236
+
237
+ class ConvNeXt(nn.Module):
238
+ r""" ConvNeXt
239
+ A PyTorch impl of : `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf
240
+ """
241
+
242
+ def __init__(
243
+ self,
244
+ in_chans: int = 3,
245
+ num_classes: int = 1000,
246
+ global_pool: str = 'avg',
247
+ output_stride: int = 32,
248
+ depths: Tuple[int, ...] = (3, 3, 9, 3),
249
+ dims: Tuple[int, ...] = (96, 192, 384, 768),
250
+ kernel_sizes: Union[int, Tuple[int, ...]] = 7,
251
+ ls_init_value: Optional[float] = 1e-6,
252
+ stem_type: str = 'patch',
253
+ patch_size: int = 4,
254
+ head_init_scale: float = 1.,
255
+ head_norm_first: bool = False,
256
+ head_hidden_size: Optional[int] = None,
257
+ conv_mlp: bool = False,
258
+ conv_bias: bool = True,
259
+ use_grn: bool = False,
260
+ act_layer: Union[str, Callable] = 'gelu',
261
+ norm_layer: Optional[Union[str, Callable]] = None,
262
+ norm_eps: Optional[float] = None,
263
+ drop_rate: float = 0.,
264
+ drop_path_rate: float = 0.,
265
+ ):
266
+ """
267
+ Args:
268
+ in_chans: Number of input image channels.
269
+ num_classes: Number of classes for classification head.
270
+ global_pool: Global pooling type.
271
+ output_stride: Output stride of network, one of (8, 16, 32).
272
+ depths: Number of blocks at each stage.
273
+ dims: Feature dimension at each stage.
274
+ kernel_sizes: Depthwise convolution kernel-sizes for each stage.
275
+ ls_init_value: Init value for Layer Scale, disabled if None.
276
+ stem_type: Type of stem.
277
+ patch_size: Stem patch size for patch stem.
278
+ head_init_scale: Init scaling value for classifier weights and biases.
279
+ head_norm_first: Apply normalization before global pool + head.
280
+ head_hidden_size: Size of MLP hidden layer in head if not None and head_norm_first == False.
281
+ conv_mlp: Use 1x1 conv in MLP, improves speed for small networks w/ chan last.
282
+ conv_bias: Use bias layers w/ all convolutions.
283
+ use_grn: Use Global Response Norm (ConvNeXt-V2) in MLP.
284
+ act_layer: Activation layer type.
285
+ norm_layer: Normalization layer type.
286
+ drop_rate: Head pre-classifier dropout rate.
287
+ drop_path_rate: Stochastic depth drop rate.
288
+ """
289
+ super().__init__()
290
+ assert output_stride in (8, 16, 32)
291
+ kernel_sizes = to_ntuple(4)(kernel_sizes)
292
+ if norm_layer is None:
293
+ norm_layer = LayerNorm2d
294
+ norm_layer_cl = norm_layer if conv_mlp else LayerNorm
295
+ if norm_eps is not None:
296
+ norm_layer = partial(norm_layer, eps=norm_eps)
297
+ norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
298
+ else:
299
+ assert conv_mlp,\
300
+ 'If a norm_layer is specified, conv MLP must be used so all norms expect rank-4, channels-first input'
301
+ norm_layer_cl = norm_layer
302
+ if norm_eps is not None:
303
+ norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
304
+
305
+ self.num_classes = num_classes
306
+ self.drop_rate = drop_rate
307
+ self.feature_info = []
308
+
309
+ assert stem_type in ('patch', 'overlap', 'overlap_tiered')
310
+ if stem_type == 'patch':
311
+ # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
312
+ self.stem = nn.Sequential(
313
+ nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size, bias=conv_bias),
314
+ norm_layer(dims[0]),
315
+ )
316
+ stem_stride = patch_size
317
+ else:
318
+ mid_chs = make_divisible(dims[0] // 2) if 'tiered' in stem_type else dims[0]
319
+ self.stem = nn.Sequential(
320
+ nn.Conv2d(in_chans, mid_chs, kernel_size=3, stride=2, padding=1, bias=conv_bias),
321
+ nn.Conv2d(mid_chs, dims[0], kernel_size=3, stride=2, padding=1, bias=conv_bias),
322
+ norm_layer(dims[0]),
323
+ )
324
+ stem_stride = 4
325
+
326
+ self.stages = nn.Sequential()
327
+ dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
328
+ stages = []
329
+ prev_chs = dims[0]
330
+ curr_stride = stem_stride
331
+ dilation = 1
332
+ # 4 feature resolution stages, each consisting of multiple residual blocks
333
+ for i in range(4):
334
+ stride = 2 if curr_stride == 2 or i > 0 else 1
335
+ if curr_stride >= output_stride and stride > 1:
336
+ dilation *= stride
337
+ stride = 1
338
+ curr_stride *= stride
339
+ first_dilation = 1 if dilation in (1, 2) else 2
340
+ out_chs = dims[i]
341
+ stages.append(ConvNeXtStage(
342
+ prev_chs,
343
+ out_chs,
344
+ kernel_size=kernel_sizes[i],
345
+ stride=stride,
346
+ dilation=(first_dilation, dilation),
347
+ depth=depths[i],
348
+ drop_path_rates=dp_rates[i],
349
+ ls_init_value=ls_init_value,
350
+ conv_mlp=conv_mlp,
351
+ conv_bias=conv_bias,
352
+ use_grn=use_grn,
353
+ act_layer=act_layer,
354
+ norm_layer=norm_layer,
355
+ norm_layer_cl=norm_layer_cl,
356
+ ))
357
+ prev_chs = out_chs
358
+ # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2
359
+ self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')]
360
+ self.stages = nn.Sequential(*stages)
361
+ self.num_features = prev_chs
362
+
363
+ # if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets
364
+ # otherwise pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights)
365
+ if head_norm_first:
366
+ assert not head_hidden_size
367
+ self.norm_pre = norm_layer(self.num_features)
368
+ self.head = ClassifierHead(
369
+ self.num_features,
370
+ num_classes,
371
+ pool_type=global_pool,
372
+ drop_rate=self.drop_rate,
373
+ )
374
+ else:
375
+ self.norm_pre = nn.Identity()
376
+ self.head = NormMlpClassifierHead(
377
+ self.num_features,
378
+ num_classes,
379
+ hidden_size=head_hidden_size,
380
+ pool_type=global_pool,
381
+ drop_rate=self.drop_rate,
382
+ norm_layer=norm_layer,
383
+ act_layer='gelu',
384
+ )
385
+ named_apply(partial(_init_weights, head_init_scale=head_init_scale), self)
386
+
387
+ @torch.jit.ignore
388
+ def group_matcher(self, coarse=False):
389
+ return dict(
390
+ stem=r'^stem',
391
+ blocks=r'^stages\.(\d+)' if coarse else [
392
+ (r'^stages\.(\d+)\.downsample', (0,)), # blocks
393
+ (r'^stages\.(\d+)\.blocks\.(\d+)', None),
394
+ (r'^norm_pre', (99999,))
395
+ ]
396
+ )
397
+
398
+ @torch.jit.ignore
399
+ def set_grad_checkpointing(self, enable=True):
400
+ for s in self.stages:
401
+ s.grad_checkpointing = enable
402
+
403
+ @torch.jit.ignore
404
+ def get_classifier(self):
405
+ return self.head.fc
406
+
407
+ def reset_classifier(self, num_classes=0, global_pool=None):
408
+ self.head.reset(num_classes, global_pool)
409
+
410
+ def forward_features(self, x):
411
+ x = self.stem(x)
412
+ x = self.stages(x)
413
+ x = self.norm_pre(x)
414
+ return x
415
+
416
+ def forward_head(self, x, pre_logits: bool = False):
417
+ return self.head(x, pre_logits=True) if pre_logits else self.head(x)
418
+
419
+ def forward(self, x):
420
+ x = self.forward_features(x)
421
+ x = self.forward_head(x)
422
+ return x
423
+
424
+
425
+ def _init_weights(module, name=None, head_init_scale=1.0):
426
+ if isinstance(module, nn.Conv2d):
427
+ trunc_normal_(module.weight, std=.02)
428
+ if module.bias is not None:
429
+ nn.init.zeros_(module.bias)
430
+ elif isinstance(module, nn.Linear):
431
+ trunc_normal_(module.weight, std=.02)
432
+ nn.init.zeros_(module.bias)
433
+ if name and 'head.' in name:
434
+ module.weight.data.mul_(head_init_scale)
435
+ module.bias.data.mul_(head_init_scale)
436
+
437
+
438
+ def checkpoint_filter_fn(state_dict, model):
439
+ """ Remap FB checkpoints -> timm """
440
+ if 'head.norm.weight' in state_dict or 'norm_pre.weight' in state_dict:
441
+ out_dict = {k.replace('gamma', 'weight'): v for k, v in state_dict.items()}
443
+ return out_dict # non-FB checkpoint
444
+ if 'model' in state_dict:
445
+ state_dict = state_dict['model']
446
+
447
+ out_dict = {}
448
+ if 'visual.trunk.stem.0.weight' in state_dict:
449
+ out_dict = {k.replace('visual.trunk.', '').replace('gamma', 'weight'): v for k, v in state_dict.items() if
450
+ k.startswith('visual.trunk.')}
451
+
452
+ if 'visual.head.proj.weight' in state_dict:
453
+ out_dict['head.fc.weight'] = state_dict['visual.head.proj.weight']
454
+ out_dict['head.fc.bias'] = torch.zeros(state_dict['visual.head.proj.weight'].shape[0])
455
+ elif 'visual.head.mlp.fc1.weight' in state_dict:
456
+ out_dict['head.pre_logits.fc.weight'] = state_dict['visual.head.mlp.fc1.weight']
457
+ out_dict['head.pre_logits.fc.bias'] = state_dict['visual.head.mlp.fc1.bias']
458
+ out_dict['head.fc.weight'] = state_dict['visual.head.mlp.fc2.weight']
459
+ out_dict['head.fc.bias'] = torch.zeros(state_dict['visual.head.mlp.fc2.weight'].shape[0])
460
+ return out_dict
461
+
462
+ import re
463
+ for k, v in state_dict.items():
464
+ k = k.replace('downsample_layers.0.', 'stem.')
465
+ k = re.sub(r'stages.([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', k)
466
+ k = re.sub(r'downsample_layers.([0-9]+).([0-9]+)', r'stages.\1.downsample.\2', k)
467
+ k = k.replace('dwconv', 'conv_dw')
468
+ k = k.replace('pwconv', 'mlp.fc')
469
+ if 'grn' in k:
470
+ k = k.replace('grn.beta', 'mlp.grn.bias')
471
+ k = k.replace('grn.gamma', 'mlp.grn.weight')
472
+ v = v.reshape(v.shape[-1])
473
+ k = k.replace('head.', 'head.fc.')
474
+ if k.startswith('norm.'):
475
+ k = k.replace('norm', 'head.norm')
476
+ if v.ndim == 2 and 'head' not in k:
477
+ model_shape = model.state_dict()[k].shape
478
+ v = v.reshape(model_shape)
479
+ k=k.replace('gamma','weight')
480
+ out_dict[k] = v
481
+
482
+ return out_dict
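# Illustrative examples (not part of the committed file) of the FB -> timm key remapping
# performed above:
#   'downsample_layers.0.0.weight' -> 'stem.0.weight'
#   'downsample_layers.1.0.weight' -> 'stages.1.downsample.0.weight'
#   'stages.0.1.dwconv.weight'     -> 'stages.0.blocks.1.conv_dw.weight'
#   'stages.0.1.gamma'             -> 'stages.0.blocks.1.weight'   (layer-scale parameter)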
483
+
484
+
485
+ def _create_convnext(variant, pretrained=False, **kwargs):
486
+ if kwargs.get('pretrained_cfg', '') == 'fcmae':
487
+ # NOTE fcmae pretrained weights have no classifier or final norm-layer (`head.norm`)
488
+ # This is workaround loading with num_classes=0 w/o removing norm-layer.
489
+ kwargs.setdefault('pretrained_strict', False)
490
+
491
+ model = build_model_with_cfg(
492
+ ConvNeXt, variant, pretrained,
493
+ pretrained_filter_fn=checkpoint_filter_fn,
494
+ feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True),
495
+ **kwargs)
496
+ return model
497
+
498
+
499
+ def _cfg(url='', **kwargs):
500
+ return {
501
+ 'url': url,
502
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
503
+ 'crop_pct': 0.875, 'interpolation': 'bicubic',
504
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
505
+ 'first_conv': 'stem.0', 'classifier': 'head.fc',
506
+ **kwargs
507
+ }
508
+
509
+
510
+ def _cfgv2(url='', **kwargs):
511
+ return {
512
+ 'url': url,
513
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
514
+ 'crop_pct': 0.875, 'interpolation': 'bicubic',
515
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
516
+ 'first_conv': 'stem.0', 'classifier': 'head.fc',
517
+ 'license': 'cc-by-nc-4.0', 'paper_ids': 'arXiv:2301.00808',
518
+ 'paper_name': 'ConvNeXt-V2: Co-designing and Scaling ConvNets with Masked Autoencoders',
519
+ 'origin_url': 'https://github.com/facebookresearch/ConvNeXt-V2',
520
+ **kwargs
521
+ }
522
+
523
+
524
+ default_cfgs = generate_default_cfgs({
525
+ 'convnext_xxlarge.clip_laion2b_soup_ft_in1k': _cfg(
526
+ hf_hub_id='timm/',
527
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
528
+ input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0),
529
+
530
+ 'convnext_xxlarge.clip_laion2b_soup_ft_in12k': _cfg(
531
+ hf_hub_id='timm/',
532
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821,
533
+ input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0),
534
+ 'convnext_xxlarge.clip_laion2b_soup': _cfg(
535
+ hf_hub_id='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup',
536
+ hf_hub_filename='open_clip_pytorch_model.bin',
537
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
538
+ input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=1024),
539
+ 'convnext_xxlarge.clip_laion2b_rewind': _cfg(
540
+ hf_hub_id='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-rewind',
541
+ hf_hub_filename='open_clip_pytorch_model.bin',
542
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
543
+ input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=1024),
544
+ })
545
+
546
+
547
+
548
+ @register_model
549
+ def convnext_xxlarge(pretrained=False, **kwargs) -> ConvNeXt:
550
+ model_args = dict(depths=[3, 4, 30, 3], dims=[384, 768, 1536, 3072], norm_eps=kwargs.pop('norm_eps', 1e-5))
551
+ model = _create_convnext('convnext_xxlarge', pretrained=pretrained, **dict(model_args, **kwargs))
552
+ return model
553
+
554
+
555
+
556
+ # register_model_deprecations(__name__, {
557
+ # 'convnext_tiny_in22ft1k': 'convnext_tiny.fb_in22k_ft_in1k',
558
+ # 'convnext_small_in22ft1k': 'convnext_small.fb_in22k_ft_in1k',
559
+ # 'convnext_base_in22ft1k': 'convnext_base.fb_in22k_ft_in1k',
560
+ # 'convnext_large_in22ft1k': 'convnext_large.fb_in22k_ft_in1k',
561
+ # 'convnext_xlarge_in22ft1k': 'convnext_xlarge.fb_in22k_ft_in1k',
562
+ # 'convnext_tiny_384_in22ft1k': 'convnext_tiny.fb_in22k_ft_in1k_384',
563
+ # 'convnext_small_384_in22ft1k': 'convnext_small.fb_in22k_ft_in1k_384',
564
+ # 'convnext_base_384_in22ft1k': 'convnext_base.fb_in22k_ft_in1k_384',
565
+ # 'convnext_large_384_in22ft1k': 'convnext_large.fb_in22k_ft_in1k_384',
566
+ # 'convnext_xlarge_384_in22ft1k': 'convnext_xlarge.fb_in22k_ft_in1k_384',
567
+ # 'convnext_tiny_in22k': 'convnext_tiny.fb_in22k',
568
+ # 'convnext_small_in22k': 'convnext_small.fb_in22k',
569
+ # 'convnext_base_in22k': 'convnext_base.fb_in22k',
570
+ # 'convnext_large_in22k': 'convnext_large.fb_in22k',
571
+ # 'convnext_xlarge_in22k': 'convnext_xlarge.fb_in22k',
572
+ # })
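# Illustrative sketch (not part of the committed file), assuming timm is installed:
# build the registered backbone without downloading weights and inspect the final
# feature map consumed by the vision tower in convnext_encoder.py.
import torch
model = convnext_xxlarge(pretrained=False, num_classes=0)
feats = model.forward_features(torch.randn(1, 3, 256, 256))
print(feats.shape)  # torch.Size([1, 3072, 8, 8]) -- output stride 32 for a 256x256 input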
convnext_encoder.py ADDED
@@ -0,0 +1,301 @@
1
+ import torch, os
2
+ import torch.nn as nn
3
+ from timm import create_model
4
+ from transformers import CLIPImageProcessor
5
+ from .convnext import convnext_xxlarge
6
+ from torch.utils.checkpoint import checkpoint
8
+ from torchvision import transforms as T
9
+ from PIL import Image
10
+
11
+
12
+
13
+ cfg={
14
+ "crop_size": 256,
15
+ "do_center_crop": True,
16
+ "do_normalize": True,
17
+ "do_resize": True,
18
+ "feature_extractor_type": "CLIPFeatureExtractor",
19
+ "image_mean": [
20
+ 0.48145466,
21
+ 0.4578275,
22
+ 0.40821073
23
+ ],
24
+ "image_std": [
25
+ 0.26862954,
26
+ 0.26130258,
27
+ 0.27577711
28
+ ],
29
+ "resample": 3,
30
+ "size": 256
31
+ }
32
+
33
+
34
+
35
+ MEAN_SLIP = [0.5, 0.5, 0.5]
36
+ STD_SLIP = [0.5, 0.5, 0.5]
37
+
38
+ MEAN_CLIP = [0.48145466, 0.4578275, 0.40821073]
39
+ STD_CLIP = [0.26862954, 0.26130258, 0.27577711]
40
+
41
+
42
+ a = [s_slip / s_clip for s_slip, s_clip in zip(STD_SLIP, STD_CLIP)]
43
+ b = [(m_slip - m_clip) / s_clip for m_slip, m_clip, s_clip in zip(MEAN_SLIP, MEAN_CLIP, STD_CLIP)]
44
+
45
+
46
+ class SlipToClipTransform:
47
+ def __init__(self, a, b):
48
+ self.a = torch.tensor(a).view(-1, 1, 1)
49
+ self.b = torch.tensor(b).view(-1, 1, 1)
50
+
51
+ def __call__(self, x_slip):
52
+ return x_slip * self.a.to(x_slip.device) + self.b.to(x_slip.device)
53
+ slip_to_clip = SlipToClipTransform(a, b)
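# Note (illustrative, not part of the committed file): if x_slip = (x - m_siglip) / s_siglip,
# the CLIP-normalized tensor can be recovered without the raw pixels, because
#   (x - m_clip) / s_clip = x_slip * (s_siglip / s_clip) + (m_siglip - m_clip) / s_clip,
# which is exactly x_slip * a + b with the per-channel a and b computed above.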
54
+
55
+ class ConvNextVisionTower(nn.Module):
56
+ def __init__(self, vision_tower, args, delay_load=False, normalize_type=None):
57
+ super().__init__()
58
+
59
+ self.is_loaded = False
60
+ self.freeze_vision=args.freeze_vision
61
+ self.input_image_size=args.input_image_size
62
+ self.vision_tower_name = vision_tower
63
+ self.name = 'convnext'
64
+ self.select_layer = args.mm_vision_select_layer
65
+ self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
66
+ self.pre_norm = normalize_type
67
+
68
+ print('pre_norm: ', self.pre_norm)
69
+ self.delay_load = delay_load
70
+ self.load_model()
71
+
72
+ def load_model(self):
73
+ if 'xxlarge' in self.vision_tower_name:
74
+ if self.delay_load:
75
+ self.vision_tower = convnext_xxlarge(pretrained=False)
76
+ else:
77
+ self.vision_tower = convnext_xxlarge(self.vision_tower_name)
78
+ setattr(self.vision_tower, 'hidden_size', 3072)
79
+ elif os.path.exists(self.vision_tower_name):
80
+ self.vision_tower = torch.load(self.vision_tower_name)
81
+ else:
82
+ assert False, 'Not implemented'
83
+
84
+
85
+ self.vision_tower = self.vision_tower.to(torch.bfloat16)
86
+
87
+ if self.freeze_vision:
88
+ self.vision_tower.requires_grad_(False)
89
+
90
+ # if self.vision_tower.grad_checkpointing:
91
+ for s in self.vision_tower.stages:
92
+ s.grad_checkpointing = True
93
+
94
+ self.is_loaded = True
95
+
96
+ def feature_select(self, image_forward_outs):
97
+
98
+ if self.select_layer>100:
99
+ image_features = image_forward_outs[-4:]
100
+ else:
101
+ image_features = image_forward_outs[-1]
102
+ return image_features
103
+
104
+ def forward_features(self, x):
105
+ x = self.vision_tower.stem(x)
106
+ image_forward_out=[]
107
+ for blk in self.vision_tower.stages:
108
+ x = blk(x)
109
+ b,c,h,w=x.shape
110
+ image_forward_out.append(x.view(b,c,-1).transpose(1,2))
111
+ return image_forward_out
112
+
113
+ def forward(self, images):
114
+ if self.freeze_vision:
115
+ with torch.no_grad():
116
+ image_features = self._forward_images(images)
117
+ else:
118
+ image_features = self._forward_images(images)
119
+
120
+ return image_features
121
+
122
+ def _forward_images(self, images):
123
+
124
+ if type(images) is list:
125
+ image_features = []
126
+ for image in images:
127
+ if self.pre_norm == 'siglip':
128
+ dtype = image.dtype
129
+ image = slip_to_clip(image.to(torch.float32)).to(dtype)
130
+ image_forward_out = self.forward_features(image.to(device=self.device, dtype=self.dtype).unsqueeze(0))
131
+ image_feature = self.feature_select(image_forward_out)
132
+ image_features.append(image_feature)
133
+ else:
134
+ if self.pre_norm == 'siglip':
135
+ dtype = images.dtype
136
+ images = slip_to_clip(images.to(torch.float32)).to(dtype)
137
+ image_forward_outs = self.forward_features(images.to(device=self.device, dtype=self.dtype))
138
+ image_features = self.feature_select(image_forward_outs)
139
+
140
+ return image_features
141
+
142
+ @property
143
+ def dummy_feature(self):
144
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
145
+
146
+ @property
147
+ def dtype(self):
148
+ return next(self.vision_tower.parameters()).dtype
149
+
150
+ @property
151
+ def device(self):
152
+ return next(self.vision_tower.parameters()).device
153
+
154
+ @property
155
+ def config(self):
156
+ raise NotImplementedError
158
+
159
+ @property
160
+ def num_attention_heads(self):
161
+ # as constant
162
+ return 16
163
+ @property
164
+ def num_layers(self):
165
+ # as constant
166
+ return 4
167
+ @property
168
+ def hidden_size(self):
169
+ return self.vision_tower.hidden_size
170
+
171
+ @property
172
+ def num_patches(self):
173
+ return (self.input_image_size // self.patch_embed.patch_size[0]) ** 2
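# Illustrative sketch (not part of the committed file): the tower only reads a few fields
# from an argparse-style namespace; the values below are hypothetical.
from types import SimpleNamespace
args = SimpleNamespace(freeze_vision=True, input_image_size=512, mm_vision_select_layer=-1)
tower = ConvNextVisionTower('convnext_xxlarge', args, delay_load=True, normalize_type='siglip')
# tower(images) returns last-stage tokens of shape (B, (512 // 32) ** 2, 3072) = (B, 256, 3072)
# for 512x512 inputs; with mm_vision_select_layer > 100 it returns the last four stages instead.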
174
+
175
+
176
+ class ConvNextFPNVisionTower(nn.Module):
177
+ def __init__(self,
178
+ vision_tower,
179
+ args,
180
+ fpn_target_level=1,
181
+ fpn_layer_idx=[1,2,3],
182
+ fpn_input_dim=[768,1536,3072],
183
+ delay_load=False):
184
+
185
+ super().__init__()
186
+
187
+ self.is_loaded = False
188
+ self.vision_tower_name = vision_tower.replace('-fpn', 'fpn')
189
+ self.freeze_vision = getattr(args, "frozen_backbone", True)
190
+ # self.input_image_size = getattr(args, "vision_tower_input_size", 1024)
191
+ self.input_image_size = 1024 # hardcode
192
+ self.select_layer = args.mm_vision_select_layer # no effect
193
+ self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
194
+
195
+ self.need_fpn = True
196
+ self.fpn_layer_idx = fpn_layer_idx # [1, 2, 3] # x8, x16, x32
197
+ self.fpn_input_dim = [768, 1536, 3072]
198
+ self.delay_load = delay_load
199
+ self.load_model()
200
+
201
+ def load_model(self):
202
+ if self.is_loaded:
203
+ return
204
+
205
+ self.image_processor = CLIPImageProcessor(**cfg)
206
+ if 'xxlarge' in self.vision_tower_name:
207
+ self.vision_tower = convnext_xxlarge(self.vision_tower_name)
208
+ setattr(self.vision_tower, 'hidden_size', self.fpn_input_dim)
209
+ # setattr(self.vision_tower, 'hidden_size', 3072)
210
+ else:
211
+ self.vision_tower = convnext_large_mlp(self.vision_tower_name)
212
+ setattr(self.vision_tower, 'hidden_size', 1536)
213
+ if self.freeze_vision:
214
+ self.vision_tower.requires_grad_(False)
215
+
216
+ # if self.vision_tower.grad_checkpointing:
217
+ for s in self.vision_tower.stages:
218
+ s.grad_checkpointing = True
219
+
220
+ if self.input_image_size is not None:
221
+ self.image_processor.size=self.input_image_size
222
+ self.image_processor.crop_size={
223
+ 'height':self.input_image_size,
224
+ 'width': self.input_image_size
225
+ }
226
+
227
+ self.is_loaded = True
228
+
229
+ @torch.no_grad()
230
+ def forward_features(self, x):
231
+ x = self.vision_tower.stem(x)
232
+ image_forward_out=[]
233
+ for blk in self.vision_tower.stages:
234
+ x = blk(x)
235
+ image_forward_out.append(x)
236
+ return image_forward_out
237
+
238
+ @torch.no_grad()
239
+ def forward(self, images):
240
+ if type(images) is list:
241
+ image_features = []
242
+ for image in images:
243
+ image_feature = self.forward_features(image.to(device=self.device, dtype=self.dtype).unsqueeze(0))
244
+ image_features.append(image_feature)
245
+ else:
246
+ image_features = self.forward_features(images.to(device=self.device, dtype=self.dtype))
247
+ image_features = [image_features[idx] for idx in self.fpn_layer_idx]
248
+
249
+ return image_features
250
+
251
+ @property
252
+ def dummy_feature(self):
253
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
254
+
255
+ @property
256
+ def dtype(self):
257
+ return next(self.vision_tower.parameters()).dtype
258
+
259
+ @property
260
+ def device(self):
261
+ return next(self.vision_tower.parameters()).device
262
+
263
+ @property
264
+ def config(self):
265
+ raise NotImplementedError
267
+
268
+ @property
269
+ def num_attention_heads(self):
270
+ # as constant
271
+ return 16
272
+ @property
273
+ def num_layers(self):
274
+ # as constant
275
+ return 4
276
+ @property
277
+ def hidden_size(self):
278
+ return self.vision_tower.hidden_size
279
+
280
+ @property
281
+ def num_patches(self):
282
+ return (cfg['image_size'] // self.patch_embed.patch_size[0]) ** 2
283
+
284
+ if __name__ == '__main__':
285
+ # Sanity check: verify that SlipToClipTransform maps SigLIP-normalized tensors onto
+ # the CLIP normalization without going back to raw pixels.
+ normalize_clip = T.Normalize(mean=MEAN_CLIP, std=STD_CLIP)
+ normalize_siglip = T.Normalize(mean=MEAN_SLIP, std=STD_SLIP)
+ x = torch.rand(1, 3, 256, 256)
+ a_clip = normalize_clip(x)
+ b_siglip = normalize_siglip(x)
+ c = slip_to_clip(b_siglip)
+ print((c - a_clip).abs().max())  # should be close to 0 (floating point error only)
demo.py ADDED
@@ -0,0 +1,421 @@
1
+
2
+ """
3
+ A model worker executes the model.
4
+ """
5
+ from transformers import AutoModel, AutoTokenizer, TextIteratorStreamer, AutoConfig
6
+ import argparse
7
+ import base64
8
+ import json
9
+ import os
10
+ import decord
11
+ import threading
12
+ import time
13
+ from io import BytesIO
14
+ from threading import Thread
15
+ import math
16
+ import requests
17
+ import torch
18
+ import torchvision.transforms as T
19
+ from PIL import Image
20
+ from torchvision.transforms.functional import InterpolationMode
21
+
22
+ import numpy as np
23
+
24
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
25
+ IMAGENET_STD = (0.229, 0.224, 0.225)
26
+
27
+ SIGLIP_MEAN = (0.5, 0.5, 0.5)
28
+ SIGLIP_STD = (0.5, 0.5, 0.5)
29
+
30
+
31
+ def get_seq_frames(total_num_frames, desired_num_frames=-1, stride=-1):
32
+ """
33
+ Calculate the indices of frames to extract from a video.
34
+
35
+ Parameters:
36
+ total_num_frames (int): Total number of frames in the video.
37
+ desired_num_frames (int): Desired number of frames to extract.
38
+
39
+ Returns:
40
+ list: List of indices of frames to extract.
41
+ """
42
+
43
+ assert (desired_num_frames > 0 or stride > 0) and not (desired_num_frames > 0 and stride > 0)
44
+
45
+ if stride > 0:
46
+ return list(range(0, total_num_frames, stride))
47
+
48
+ # Calculate the size of each segment from which a frame will be extracted
49
+ seg_size = float(total_num_frames - 1) / desired_num_frames
50
+
51
+ seq = []
52
+ for i in range(desired_num_frames):
53
+ # Calculate the start and end indices of each segment
54
+ start = int(np.round(seg_size * i))
55
+ end = int(np.round(seg_size * (i + 1)))
56
+
57
+ # Append the middle index of the segment to the list
58
+ seq.append((start + end) // 2)
59
+
60
+ return seq
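# Worked example (illustrative): get_seq_frames(total_num_frames=100, desired_num_frames=4)
# uses seg_size = 99 / 4 = 24.75 and returns the segment midpoints [12, 37, 62, 86];
# get_seq_frames(100, stride=30) instead returns [0, 30, 60, 90].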
61
+
62
+ def build_video_prompt(meta_list, num_frames, time_position=False):
63
+ # if time_position is True, the frame_timestamp is used.
64
+ # 1. pass time_position, 2. use env TIME_POSITION
65
+ time_position = os.environ.get("TIME_POSITION", time_position)
66
+ prefix = f"This is a video:\n"
67
+ for i in range(num_frames):
68
+ if time_position:
69
+ frame_txt = f"Frame {i+1} sampled at {meta_list[i]:.2f} seconds: <image>\n"
70
+ else:
71
+ frame_txt = f"Frame {i+1}: <image>\n"
72
+ prefix += frame_txt
73
+ return prefix
74
+
75
+ def load_video(video_path, num_frames=64, frame_cache_root=None):
76
+ if isinstance(video_path, str):
77
+ video = decord.VideoReader(video_path)
78
+ elif isinstance(video_path, dict):
79
+ assert False, 'we do not support a dict "video_path" as input'
80
+ fps = video.get_avg_fps()
81
+ sampled_frames = get_seq_frames(len(video), num_frames)
82
+ sampled_timestamps = [i / fps for i in sampled_frames]
83
+ frames = video.get_batch(sampled_frames).asnumpy()
84
+ images = [Image.fromarray(frame) for frame in frames]
85
+
86
+ return images, build_video_prompt(sampled_timestamps, len(images), time_position=True)
87
+
88
+ def load_image(image):
89
+ if isinstance(image, str) and os.path.exists(image):
90
+ return Image.open(image)
91
+ elif isinstance(image, dict):
92
+ if 'disk_path' in image:
93
+ return Image.open(image['disk_path'])
94
+ elif 'base64' in image:
95
+ return Image.open(BytesIO(base64.b64decode(image['base64'])))
96
+ elif 'url' in image:
97
+ response = requests.get(image['url'])
98
+ return Image.open(BytesIO(response.content))
99
+ elif 'bytes' in image:
100
+ return Image.open(BytesIO(image['bytes']))
101
+ else:
102
+ raise ValueError(f'Invalid image: {image}')
103
+ else:
104
+ raise ValueError(f'Invalid image: {image}')
105
+
106
+ def build_transform(input_size, norm_type='imagenet'):
107
+ if norm_type == 'imagenet':
108
+ MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
109
+ elif norm_type == 'siglip':
110
+ MEAN, STD = SIGLIP_MEAN, SIGLIP_STD
111
+
112
+ transform = T.Compose([
113
+ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
114
+ T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
115
+ T.ToTensor(),
116
+ T.Normalize(mean=MEAN, std=STD)
117
+ ])
118
+ return transform
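# Illustrative: build_transform(448, norm_type='siglip') resizes to 448x448 (bicubic) and
# normalizes with mean/std (0.5, 0.5, 0.5); norm_type='imagenet' uses the ImageNet statistics.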
119
+
120
+
121
+ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
122
+ """
123
+ The previous version mainly focused on the aspect ratio.
124
+ We also consider the area ratio here.
125
+ """
126
+ best_factor = float('-inf')
127
+ best_ratio = (1, 1)
128
+ area = width * height
129
+ for ratio in target_ratios:
130
+ target_aspect_ratio = ratio[0] / ratio[1]
131
+ ratio_diff = abs(aspect_ratio - target_aspect_ratio)
132
+ area_ratio = (ratio[0]*ratio[1]*image_size*image_size)/ area
133
+ """
134
+ A tiled area covering more than 60% of the original image area is considered enough.
135
+ """
136
+ factor_based_on_area_n_ratio = min(area_ratio, 0.6) * \
137
+ min(target_aspect_ratio/aspect_ratio, aspect_ratio/target_aspect_ratio)
138
+
139
+ if factor_based_on_area_n_ratio > best_factor:
140
+ best_factor = factor_based_on_area_n_ratio
141
+ best_ratio = ratio
142
+
143
+ return best_ratio
144
+
145
+
146
+ def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
147
+ orig_width, orig_height = image.size
148
+ aspect_ratio = orig_width / orig_height
149
+
150
+ # calculate the existing image aspect ratio
151
+ target_ratios = set(
152
+ (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
153
+ i * j <= max_num and i * j >= min_num)
154
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
155
+
156
+ # find the closest aspect ratio to the target
157
+ target_aspect_ratio = find_closest_aspect_ratio(
158
+ aspect_ratio, target_ratios, orig_width, orig_height, image_size)
159
+
160
+ # calculate the target width and height
161
+ target_width = image_size * target_aspect_ratio[0]
162
+ target_height = image_size * target_aspect_ratio[1]
163
+ blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
164
+
165
+ # resize the image
166
+ resized_img = image.resize((target_width, target_height))
167
+ processed_images = []
168
+ for i in range(blocks):
169
+ box = (
170
+ (i % (target_width // image_size)) * image_size,
171
+ (i // (target_width // image_size)) * image_size,
172
+ ((i % (target_width // image_size)) + 1) * image_size,
173
+ ((i // (target_width // image_size)) + 1) * image_size
174
+ )
175
+ # split the image
176
+ split_img = resized_img.crop(box)
177
+ processed_images.append(split_img)
178
+ assert len(processed_images) == blocks
179
+ if use_thumbnail and len(processed_images) != 1:
180
+ thumbnail_img = image.resize((image_size, image_size))
181
+ processed_images.append(thumbnail_img)
182
+ return processed_images
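# Worked example (illustrative): for an 800x600 image with image_size=448 and max_num=6,
# find_closest_aspect_ratio() selects the (3, 2) grid, so the image is resized to 1344x896
# and split into six 448x448 tiles, plus one 448x448 thumbnail when use_thumbnail=True.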
183
+
184
+ def split_model(model_path, device):
185
+
186
+ device_map = {}
187
+ world_size = torch.cuda.device_count()
188
+ config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
189
+ num_layers = config.llm_config.num_hidden_layers
190
+
191
+ num_layers_per_gpu_ = math.floor(num_layers / (world_size - 1))
192
+ num_layers_per_gpu = [num_layers_per_gpu_] * world_size
193
+ num_layers_per_gpu[device] = num_layers - num_layers_per_gpu_ * (world_size-1)
194
+ layer_cnt = 0
195
+ for i, num_layer in enumerate(num_layers_per_gpu):
196
+ for j in range(num_layer):
197
+ device_map[f'language_model.model.layers.{layer_cnt}'] = i
198
+ layer_cnt += 1
199
+ device_map['vision_model'] = device
200
+ device_map['mlp1'] = device
201
+ device_map['language_model.model.tok_embeddings'] = device
202
+ device_map['language_model.model.embed_tokens'] = device
203
+ device_map['language_model.output'] = device
204
+ device_map['language_model.model.norm'] = device
205
+ device_map['language_model.lm_head'] = device
206
+ device_map['language_model.model.rotary_emb'] = device
207
+ device_map[f'language_model.model.layers.{num_layers - 1}'] = device
208
+ return device_map
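# Illustrative example (not part of the committed file): with 4 GPUs, 28 LLM layers and
# device=0, GPUs 1-3 are each assigned floor(28 / 3) = 9 decoder layers, while GPU 0 keeps
# the remainder plus the vision tower, embeddings, final norm, lm_head and the re-pinned
# last decoder layer.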
209
+
210
+ class ModelWorker:
211
+ def __init__(self, model_path, model_name,
212
+ load_8bit, device):
213
+
214
+ if model_path.endswith('/'):
215
+ model_path = model_path[:-1]
216
+ if model_name is None:
217
+ model_paths = model_path.split('/')
218
+ if model_paths[-1].startswith('checkpoint-'):
219
+ self.model_name = model_paths[-2] + '_' + model_paths[-1]
220
+ else:
221
+ self.model_name = model_paths[-1]
222
+ else:
223
+ self.model_name = model_name
224
+
225
+ print(f'Loading the model {self.model_name}')
226
+
227
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False)
228
+ tokens_to_keep = ['<box>', '</box>', '<ref>', '</ref>']
229
+ tokenizer.additional_special_tokens = [item for item in tokenizer.additional_special_tokens if item not in tokens_to_keep]
230
+ self.tokenizer = tokenizer
231
+ config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
232
+ model_type = config.vision_config.model_type
233
+ self.device = torch.cuda.current_device()
234
+ if model_type == 'siglip_vision_model':
235
+ self.norm_type = 'siglip'
236
+ elif model_type == 'MOB':
237
+ self.norm_type = 'siglip'
238
+ else:
239
+ self.norm_type = 'imagenet'
240
+
241
+ if any(x in model_path.lower() for x in ['34b']):
242
+ device_map = split_model(model_path, self.device)
243
+ else:
244
+ device_map = None
245
+
246
+ if device_map is not None:
247
+ self.model = AutoModel.from_pretrained(model_path, torch_dtype=torch.bfloat16,
248
+ low_cpu_mem_usage=True,
249
+ device_map=device_map,
250
+ trust_remote_code=True,
251
+ load_in_8bit=load_8bit).eval()
252
+ else:
253
+ self.model = AutoModel.from_pretrained(model_path, torch_dtype=torch.bfloat16,
254
+ trust_remote_code=True,
255
+ load_in_8bit=load_8bit).eval()
256
+ if not load_8bit and device_map is None:
257
+ self.model = self.model.to(device)
258
+ self.load_8bit = load_8bit
259
+
260
+ self.model_path = model_path
261
+ self.image_size = self.model.config.force_image_size
262
+ self.context_len = tokenizer.model_max_length
263
+ self.per_tile_len = 256
264
+
265
+ def reload_model(self):
266
+ del self.model
267
+ torch.cuda.empty_cache()
268
+ if self.device == 'auto':
269
+ os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
270
+ # This can make distributed deployment work properly
271
+ self.model = AutoModel.from_pretrained(
272
+ self.model_path,
273
+ load_in_8bit=self.load_8bit,
274
+ torch_dtype=torch.bfloat16,
275
+ device_map=self.device_map,
276
+ trust_remote_code=True).eval()
277
+ else:
278
+ self.model = AutoModel.from_pretrained(
279
+ self.model_path,
280
+ load_in_8bit=self.load_8bit,
281
+ torch_dtype=torch.bfloat16,
282
+ trust_remote_code=True).eval()
283
+ if not self.load_8bit and not self.device == 'auto':
284
+ self.model = self.model.cuda()
285
+
286
+ @torch.inference_mode()
287
+ def generate(self, params):
288
+ system_message = params['prompt'][0]['content']
289
+ send_messages = params['prompt'][1:]
290
+ max_input_tiles = params['max_input_tiles']
291
+ temperature = params['temperature']
292
+ top_p = params['top_p']
293
+ max_new_tokens = params['max_new_tokens']
294
+ repetition_penalty = params['repetition_penalty']
295
+ video_frame_num = params.get('video_frame_num', 64)
296
+ do_sample = True if temperature > 0.0 else False
297
+
298
+ global_image_cnt = 0
299
+ history, pil_images, max_input_tile_list = [], [], []
300
+ for message in send_messages:
301
+ if message['role'] == 'user':
302
+ prefix = ''
303
+ if 'image' in message:
304
+ for image_data in message['image']:
305
+ pil_images.append(load_image(image_data))
306
+ prefix = prefix + f'<image {global_image_cnt + 1}><image>\n'
307
+ global_image_cnt += 1
308
+ max_input_tile_list.append(max_input_tiles)
309
+ if 'video' in message:
310
+ for video_data in message['video']:
311
+ video_frames, tmp_prefix = load_video(video_data, num_frames=video_frame_num)
312
+ pil_images.extend(video_frames)
313
+ prefix = prefix + tmp_prefix
314
+ global_image_cnt += len(video_frames)
315
+ max_input_tile_list.extend([1] * len(video_frames))
316
+ content = prefix + message['content']
317
+ history.append([content, ])
318
+ else:
319
+ history[-1].append(message['content'])
320
+ question, history = history[-1][0], history[:-1]
321
+
322
+ if global_image_cnt == 1:
323
+ question = question.replace('<image 1><image>\n', '<image>\n')
324
+ history = [[item[0].replace('<image 1><image>\n', '<image>\n'), item[1]] for item in history]
325
+
326
+
327
+ try:
328
+ assert len(max_input_tile_list) == len(pil_images), 'The number of max_input_tile_list and pil_images should be the same.'
329
+ except Exception as e:
330
+ print(f'Error: {e}')
+ print(f'max_input_tile_list: {max_input_tile_list}, pil_images: {pil_images}')
+ raise e
335
+
336
+ old_system_message = self.model.system_message
337
+ self.model.system_message = system_message
338
+
339
+ transform = build_transform(input_size=self.image_size, norm_type=self.norm_type)
340
+ if len(pil_images) > 0:
341
+ max_input_tiles_limited_by_context = params['max_input_tiles']
342
+ while True:
343
+ image_tiles = []
344
+ for current_max_input_tiles, pil_image in zip(max_input_tile_list, pil_images):
345
+ if self.model.config.dynamic_image_size:
346
+ tiles = dynamic_preprocess(
347
+ pil_image, image_size=self.image_size, max_num=min(current_max_input_tiles, max_input_tiles_limited_by_context),
348
+ use_thumbnail=self.model.config.use_thumbnail)
349
+ else:
350
+ tiles = [pil_image]
351
+ image_tiles += tiles
352
+ if (len(image_tiles) * self.per_tile_len < self.context_len):
353
+ break
354
+ else:
355
+ max_input_tiles_limited_by_context -= 2
356
+
357
+ if max_input_tiles_limited_by_context < 1:
358
+ break
359
+
360
+ pixel_values = [transform(item) for item in image_tiles]
361
+ pixel_values = torch.stack(pixel_values).to(self.model.device, dtype=torch.bfloat16)
362
+
363
+ else:
364
+ pixel_values = None
365
+
366
+ generation_config = dict(
367
+ num_beams=1,
368
+ max_new_tokens=max_new_tokens,
369
+ do_sample=do_sample,
370
+ temperature=temperature,
371
+ repetition_penalty=repetition_penalty,
372
+ max_length=self.context_len,
373
+ top_p=top_p,
374
+ )
375
+
376
+ response = self.model.chat(
377
+ tokenizer=self.tokenizer,
378
+ pixel_values=pixel_values,
379
+ question=question,
380
+ history=history,
381
+ return_history=False,
382
+ generation_config=generation_config,
383
+ )
384
+ self.model.system_message = old_system_message
385
+ return {'text': response, 'error_code': 0}
386
+
387
+
388
+
389
+
390
+
391
+ if __name__ == '__main__':
392
+ parser = argparse.ArgumentParser()
393
+ parser.add_argument('--model-path', type=str, default='/home/zhidingy/workspace/eagle-next/internvl_chat/work_dirs/release/test/eagle2-commercial_llama3-2_3b_data-v11_gl_16k')
394
+ parser.add_argument('--model-name', type=str, default='Eagle2-1B')
395
+ parser.add_argument('--device', type=str, default='cuda')
396
+ parser.add_argument('--load-8bit', action='store_true')
397
+ args = parser.parse_args()
398
+ print(f'args: {args}')
399
+
400
+ worker = ModelWorker(
401
+ args.model_path,
402
+ args.model_name,
403
+ args.load_8bit,
404
+ args.device)
405
+ prompt = [
406
+ {'role': 'system', 'content': 'You are a helpful assistant.'},
407
+ {'role': 'user', 'content': 'Describe this image in detail.',
408
+ 'image':[
409
+ {'url': 'https://www.nvidia.com/content/dam/en-zz/Solutions/about-nvidia/logo-and-brand/[email protected]'}
410
+ ]
411
+ }
412
+ ]
413
+ params = {
414
+ 'prompt': prompt,
415
+ 'max_input_tiles': 24,
416
+ 'temperature': 0.7,
417
+ 'top_p': 1.0,
418
+ 'max_new_tokens': 4096,
419
+ 'repetition_penalty': 1.0,
420
+ }
421
+ print(worker.generate(params))
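
For reference, the tile-budget loop in generate() above keeps the visual token count under the context window: it builds the tile list with dynamic_preprocess, and whenever len(image_tiles) * per_tile_len reaches context_len it lowers the per-image tile cap by 2 and retries. Below is a minimal standalone sketch of that budgeting rule; fit_tile_budget, per_image_caps, and the 16384-token context are illustrative names and values, not part of the repository.

def fit_tile_budget(per_image_caps, per_tile_len=256, context_len=16384):
    """Largest per-image cap whose total tile count stays under the context budget."""
    cap = max(per_image_caps)
    while cap >= 1:
        # stand-in for dynamic_preprocess: assume each image yields min(its own cap, cap) tiles
        total_tiles = sum(min(c, cap) for c in per_image_caps)
        if total_tiles * per_tile_len < context_len:
            return cap
        cap -= 2  # same step size as the loop in generate()
    return 1

# e.g. three images, each allowed up to 24 tiles, with a 16k-token context
print(fit_tile_budget([24, 24, 24]))  # -> 20
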
done.txt ADDED
@@ -0,0 +1 @@
1
+ done: Mon Feb 10 12:17:46 2025
flash_attention.py ADDED
@@ -0,0 +1,76 @@
1
+ # https://github.com/Dao-AILab/flash-attention/blob/v0.2.8/flash_attn/flash_attention.py
2
+ import torch
3
+ import torch.nn as nn
4
+ from einops import rearrange
5
+
6
+ try: # v1
7
+ from flash_attn.flash_attn_interface import \
8
+ flash_attn_unpadded_qkvpacked_func
9
+ except ImportError:  # v2
10
+ from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
11
+
12
+ from flash_attn.bert_padding import pad_input, unpad_input
13
+
14
+
15
+ class FlashAttention(nn.Module):
16
+ """Implement the scaled dot product attention with softmax.
17
+ Arguments
18
+ ---------
19
+ softmax_scale: The temperature to use for the softmax attention.
20
+ (default: 1/sqrt(d_keys) where d_keys is computed at
21
+ runtime)
22
+ attention_dropout: The dropout rate to apply to the attention
23
+ (default: 0.0)
24
+ """
25
+
26
+ def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
27
+ super().__init__()
28
+ self.softmax_scale = softmax_scale
29
+ self.dropout_p = attention_dropout
30
+
31
+ def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None,
32
+ max_s=None, need_weights=False):
33
+ """Implements the multihead softmax attention.
34
+ Arguments
35
+ ---------
36
+ qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
37
+ if unpadded: (nnz, 3, h, d)
38
+ key_padding_mask: a bool tensor of shape (B, S)
39
+ """
40
+ assert not need_weights
41
+ assert qkv.dtype in [torch.float16, torch.bfloat16]
42
+ assert qkv.is_cuda
43
+
44
+ if cu_seqlens is None:
45
+ batch_size = qkv.shape[0]
46
+ seqlen = qkv.shape[1]
47
+ if key_padding_mask is None:
48
+ qkv = rearrange(qkv, 'b s ... -> (b s) ...')
49
+ max_s = seqlen
50
+ cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
51
+ device=qkv.device)
52
+ output = flash_attn_unpadded_qkvpacked_func(
53
+ qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
54
+ softmax_scale=self.softmax_scale, causal=causal
55
+ )
56
+ output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
57
+ else:
58
+ nheads = qkv.shape[-2]
59
+ x = rearrange(qkv, 'b s three h d -> b s (three h d)')
60
+ x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
61
+ x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
62
+ output_unpad = flash_attn_unpadded_qkvpacked_func(
63
+ x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
64
+ softmax_scale=self.softmax_scale, causal=causal
65
+ )
66
+ output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
67
+ indices, batch_size, seqlen),
68
+ 'b s (h d) -> b s h d', h=nheads)
69
+ else:
70
+ assert max_s is not None
71
+ output = flash_attn_unpadded_qkvpacked_func(
72
+ qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
73
+ softmax_scale=self.softmax_scale, causal=causal
74
+ )
75
+
76
+ return output, None
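
A short usage sketch for the FlashAttention wrapper above (not part of the commit). It assumes a CUDA GPU with the flash-attn package installed and that the class is imported from this flash_attention.py module:

import torch
from flash_attention import FlashAttention  # this file

attn = FlashAttention(attention_dropout=0.0)
B, S, H, D = 2, 128, 8, 64  # batch, sequence length, heads, head dim
qkv = torch.randn(B, S, 3, H, D, device='cuda', dtype=torch.bfloat16)
out, _ = attn(qkv, key_padding_mask=None, causal=True)  # padded path builds cu_seqlens internally
print(out.shape)  # torch.Size([2, 128, 8, 64])
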
generation_config.json ADDED
@@ -0,0 +1,4 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "transformers_version": "4.37.2"
4
+ }
model-00001-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4ba159dd20093f2e8a75058730df9b72760fee21cc964cee15f5ba12fb72afab
3
+ size 4967370080
model-00002-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bcf6bc7c9a7a1024c9986403270461b1bc39537953420aa0a25fd8143c4a75f1
3
+ size 3150712792
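
The two LFS pointers above are the sharded weights, and the model.safetensors.index.json below maps each parameter name to its shard. A hedged sketch of how a single tensor can be read from the correct shard once the files are downloaded (the weight name is just one example taken from the index):

import json
from safetensors import safe_open

with open('model.safetensors.index.json') as f:
    index = json.load(f)

name = 'language_model.lm_head.weight'
shard = index['weight_map'][name]            # e.g. 'model-00002-of-00002.safetensors'
with safe_open(shard, framework='pt', device='cpu') as f:
    tensor = f.get_tensor(name)
print(shard, tuple(tensor.shape))
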
model.safetensors.index.json ADDED
@@ -0,0 +1,716 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 8117987200
4
+ },
5
+ "weight_map": {
6
+ "language_model.lm_head.weight": "model-00002-of-00002.safetensors",
7
+ "language_model.model.embed_tokens.weight": "model-00001-of-00002.safetensors",
8
+ "language_model.model.layers.0.input_layernorm.weight": "model-00001-of-00002.safetensors",
9
+ "language_model.model.layers.0.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
10
+ "language_model.model.layers.0.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
11
+ "language_model.model.layers.0.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
12
+ "language_model.model.layers.0.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
13
+ "language_model.model.layers.0.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
14
+ "language_model.model.layers.0.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
15
+ "language_model.model.layers.0.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
16
+ "language_model.model.layers.0.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
17
+ "language_model.model.layers.1.input_layernorm.weight": "model-00001-of-00002.safetensors",
18
+ "language_model.model.layers.1.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
19
+ "language_model.model.layers.1.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
20
+ "language_model.model.layers.1.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
21
+ "language_model.model.layers.1.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
22
+ "language_model.model.layers.1.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
23
+ "language_model.model.layers.1.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
24
+ "language_model.model.layers.1.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
25
+ "language_model.model.layers.1.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
26
+ "language_model.model.layers.10.input_layernorm.weight": "model-00001-of-00002.safetensors",
27
+ "language_model.model.layers.10.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
28
+ "language_model.model.layers.10.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
29
+ "language_model.model.layers.10.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
30
+ "language_model.model.layers.10.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
31
+ "language_model.model.layers.10.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
32
+ "language_model.model.layers.10.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
33
+ "language_model.model.layers.10.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
34
+ "language_model.model.layers.10.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
35
+ "language_model.model.layers.11.input_layernorm.weight": "model-00001-of-00002.safetensors",
36
+ "language_model.model.layers.11.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
37
+ "language_model.model.layers.11.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
38
+ "language_model.model.layers.11.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
39
+ "language_model.model.layers.11.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
40
+ "language_model.model.layers.11.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
41
+ "language_model.model.layers.11.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
42
+ "language_model.model.layers.11.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
43
+ "language_model.model.layers.11.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
44
+ "language_model.model.layers.12.input_layernorm.weight": "model-00001-of-00002.safetensors",
45
+ "language_model.model.layers.12.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
46
+ "language_model.model.layers.12.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
47
+ "language_model.model.layers.12.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
48
+ "language_model.model.layers.12.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
49
+ "language_model.model.layers.12.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
50
+ "language_model.model.layers.12.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
51
+ "language_model.model.layers.12.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
52
+ "language_model.model.layers.12.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
53
+ "language_model.model.layers.13.input_layernorm.weight": "model-00001-of-00002.safetensors",
54
+ "language_model.model.layers.13.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
55
+ "language_model.model.layers.13.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
56
+ "language_model.model.layers.13.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
57
+ "language_model.model.layers.13.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
58
+ "language_model.model.layers.13.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
59
+ "language_model.model.layers.13.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
60
+ "language_model.model.layers.13.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
61
+ "language_model.model.layers.13.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
62
+ "language_model.model.layers.14.input_layernorm.weight": "model-00001-of-00002.safetensors",
63
+ "language_model.model.layers.14.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
64
+ "language_model.model.layers.14.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
65
+ "language_model.model.layers.14.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
66
+ "language_model.model.layers.14.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
67
+ "language_model.model.layers.14.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
68
+ "language_model.model.layers.14.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
69
+ "language_model.model.layers.14.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
70
+ "language_model.model.layers.14.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
71
+ "language_model.model.layers.15.input_layernorm.weight": "model-00001-of-00002.safetensors",
72
+ "language_model.model.layers.15.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
73
+ "language_model.model.layers.15.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
74
+ "language_model.model.layers.15.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
75
+ "language_model.model.layers.15.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
76
+ "language_model.model.layers.15.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
77
+ "language_model.model.layers.15.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
78
+ "language_model.model.layers.15.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
79
+ "language_model.model.layers.15.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
80
+ "language_model.model.layers.16.input_layernorm.weight": "model-00002-of-00002.safetensors",
81
+ "language_model.model.layers.16.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
82
+ "language_model.model.layers.16.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
83
+ "language_model.model.layers.16.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
84
+ "language_model.model.layers.16.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
85
+ "language_model.model.layers.16.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
86
+ "language_model.model.layers.16.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
87
+ "language_model.model.layers.16.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
88
+ "language_model.model.layers.16.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
89
+ "language_model.model.layers.17.input_layernorm.weight": "model-00002-of-00002.safetensors",
90
+ "language_model.model.layers.17.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
91
+ "language_model.model.layers.17.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
92
+ "language_model.model.layers.17.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
93
+ "language_model.model.layers.17.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
94
+ "language_model.model.layers.17.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
95
+ "language_model.model.layers.17.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
96
+ "language_model.model.layers.17.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
97
+ "language_model.model.layers.17.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
98
+ "language_model.model.layers.18.input_layernorm.weight": "model-00002-of-00002.safetensors",
99
+ "language_model.model.layers.18.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
100
+ "language_model.model.layers.18.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
101
+ "language_model.model.layers.18.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
102
+ "language_model.model.layers.18.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
103
+ "language_model.model.layers.18.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
104
+ "language_model.model.layers.18.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
105
+ "language_model.model.layers.18.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
106
+ "language_model.model.layers.18.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
107
+ "language_model.model.layers.19.input_layernorm.weight": "model-00002-of-00002.safetensors",
108
+ "language_model.model.layers.19.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
109
+ "language_model.model.layers.19.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
110
+ "language_model.model.layers.19.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
111
+ "language_model.model.layers.19.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
112
+ "language_model.model.layers.19.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
113
+ "language_model.model.layers.19.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
114
+ "language_model.model.layers.19.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
115
+ "language_model.model.layers.19.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
116
+ "language_model.model.layers.2.input_layernorm.weight": "model-00001-of-00002.safetensors",
117
+ "language_model.model.layers.2.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
118
+ "language_model.model.layers.2.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
119
+ "language_model.model.layers.2.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
120
+ "language_model.model.layers.2.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
121
+ "language_model.model.layers.2.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
122
+ "language_model.model.layers.2.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
123
+ "language_model.model.layers.2.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
124
+ "language_model.model.layers.2.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
125
+ "language_model.model.layers.20.input_layernorm.weight": "model-00002-of-00002.safetensors",
126
+ "language_model.model.layers.20.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
127
+ "language_model.model.layers.20.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
128
+ "language_model.model.layers.20.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
129
+ "language_model.model.layers.20.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
130
+ "language_model.model.layers.20.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
131
+ "language_model.model.layers.20.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
132
+ "language_model.model.layers.20.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
133
+ "language_model.model.layers.20.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
134
+ "language_model.model.layers.21.input_layernorm.weight": "model-00002-of-00002.safetensors",
135
+ "language_model.model.layers.21.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
136
+ "language_model.model.layers.21.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
137
+ "language_model.model.layers.21.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
138
+ "language_model.model.layers.21.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
139
+ "language_model.model.layers.21.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
140
+ "language_model.model.layers.21.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
141
+ "language_model.model.layers.21.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
142
+ "language_model.model.layers.21.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
143
+ "language_model.model.layers.22.input_layernorm.weight": "model-00002-of-00002.safetensors",
144
+ "language_model.model.layers.22.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
145
+ "language_model.model.layers.22.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
146
+ "language_model.model.layers.22.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
147
+ "language_model.model.layers.22.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
148
+ "language_model.model.layers.22.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
149
+ "language_model.model.layers.22.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
150
+ "language_model.model.layers.22.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
151
+ "language_model.model.layers.22.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
152
+ "language_model.model.layers.23.input_layernorm.weight": "model-00002-of-00002.safetensors",
153
+ "language_model.model.layers.23.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
154
+ "language_model.model.layers.23.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
155
+ "language_model.model.layers.23.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
156
+ "language_model.model.layers.23.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
157
+ "language_model.model.layers.23.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
158
+ "language_model.model.layers.23.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
159
+ "language_model.model.layers.23.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
160
+ "language_model.model.layers.23.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
161
+ "language_model.model.layers.24.input_layernorm.weight": "model-00002-of-00002.safetensors",
162
+ "language_model.model.layers.24.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
163
+ "language_model.model.layers.24.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
164
+ "language_model.model.layers.24.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
165
+ "language_model.model.layers.24.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
166
+ "language_model.model.layers.24.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
167
+ "language_model.model.layers.24.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
168
+ "language_model.model.layers.24.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
169
+ "language_model.model.layers.24.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
170
+ "language_model.model.layers.25.input_layernorm.weight": "model-00002-of-00002.safetensors",
171
+ "language_model.model.layers.25.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
172
+ "language_model.model.layers.25.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
173
+ "language_model.model.layers.25.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
174
+ "language_model.model.layers.25.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
175
+ "language_model.model.layers.25.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
176
+ "language_model.model.layers.25.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
177
+ "language_model.model.layers.25.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
178
+ "language_model.model.layers.25.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
179
+ "language_model.model.layers.26.input_layernorm.weight": "model-00002-of-00002.safetensors",
180
+ "language_model.model.layers.26.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
181
+ "language_model.model.layers.26.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
182
+ "language_model.model.layers.26.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
183
+ "language_model.model.layers.26.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
184
+ "language_model.model.layers.26.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
185
+ "language_model.model.layers.26.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
186
+ "language_model.model.layers.26.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
187
+ "language_model.model.layers.26.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
188
+ "language_model.model.layers.27.input_layernorm.weight": "model-00002-of-00002.safetensors",
189
+ "language_model.model.layers.27.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
190
+ "language_model.model.layers.27.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
191
+ "language_model.model.layers.27.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
192
+ "language_model.model.layers.27.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
193
+ "language_model.model.layers.27.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
194
+ "language_model.model.layers.27.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
195
+ "language_model.model.layers.27.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
196
+ "language_model.model.layers.27.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
197
+ "language_model.model.layers.3.input_layernorm.weight": "model-00001-of-00002.safetensors",
198
+ "language_model.model.layers.3.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
199
+ "language_model.model.layers.3.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
200
+ "language_model.model.layers.3.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
201
+ "language_model.model.layers.3.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
202
+ "language_model.model.layers.3.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
203
+ "language_model.model.layers.3.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
204
+ "language_model.model.layers.3.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
205
+ "language_model.model.layers.3.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
206
+ "language_model.model.layers.4.input_layernorm.weight": "model-00001-of-00002.safetensors",
207
+ "language_model.model.layers.4.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
208
+ "language_model.model.layers.4.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
209
+ "language_model.model.layers.4.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
210
+ "language_model.model.layers.4.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
211
+ "language_model.model.layers.4.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
212
+ "language_model.model.layers.4.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
213
+ "language_model.model.layers.4.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
214
+ "language_model.model.layers.4.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
215
+ "language_model.model.layers.5.input_layernorm.weight": "model-00001-of-00002.safetensors",
216
+ "language_model.model.layers.5.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
217
+ "language_model.model.layers.5.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
218
+ "language_model.model.layers.5.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
219
+ "language_model.model.layers.5.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
220
+ "language_model.model.layers.5.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
221
+ "language_model.model.layers.5.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
222
+ "language_model.model.layers.5.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
223
+ "language_model.model.layers.5.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
224
+ "language_model.model.layers.6.input_layernorm.weight": "model-00001-of-00002.safetensors",
225
+ "language_model.model.layers.6.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
226
+ "language_model.model.layers.6.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
227
+ "language_model.model.layers.6.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
228
+ "language_model.model.layers.6.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
229
+ "language_model.model.layers.6.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
230
+ "language_model.model.layers.6.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
231
+ "language_model.model.layers.6.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
232
+ "language_model.model.layers.6.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
233
+ "language_model.model.layers.7.input_layernorm.weight": "model-00001-of-00002.safetensors",
234
+ "language_model.model.layers.7.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
235
+ "language_model.model.layers.7.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
236
+ "language_model.model.layers.7.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
237
+ "language_model.model.layers.7.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
238
+ "language_model.model.layers.7.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
239
+ "language_model.model.layers.7.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
240
+ "language_model.model.layers.7.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
241
+ "language_model.model.layers.7.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
242
+ "language_model.model.layers.8.input_layernorm.weight": "model-00001-of-00002.safetensors",
243
+ "language_model.model.layers.8.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
244
+ "language_model.model.layers.8.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
245
+ "language_model.model.layers.8.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
246
+ "language_model.model.layers.8.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
247
+ "language_model.model.layers.8.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
248
+ "language_model.model.layers.8.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
249
+ "language_model.model.layers.8.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
250
+ "language_model.model.layers.8.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
251
+ "language_model.model.layers.9.input_layernorm.weight": "model-00001-of-00002.safetensors",
252
+ "language_model.model.layers.9.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
253
+ "language_model.model.layers.9.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
254
+ "language_model.model.layers.9.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
255
+ "language_model.model.layers.9.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
256
+ "language_model.model.layers.9.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
257
+ "language_model.model.layers.9.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
258
+ "language_model.model.layers.9.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
259
+ "language_model.model.layers.9.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
260
+ "language_model.model.norm.weight": "model-00002-of-00002.safetensors",
261
+ "mlp1.0.bias": "model-00002-of-00002.safetensors",
262
+ "mlp1.0.weight": "model-00002-of-00002.safetensors",
263
+ "mlp1.1.bias": "model-00002-of-00002.safetensors",
264
+ "mlp1.1.weight": "model-00002-of-00002.safetensors",
265
+ "mlp1.3.bias": "model-00002-of-00002.safetensors",
266
+ "mlp1.3.weight": "model-00002-of-00002.safetensors",
267
+ "vision_model.vision_model.embeddings.patch_embedding.bias": "model-00001-of-00002.safetensors",
268
+ "vision_model.vision_model.embeddings.patch_embedding.weight": "model-00001-of-00002.safetensors",
269
+ "vision_model.vision_model.embeddings.position_embedding.weight": "model-00001-of-00002.safetensors",
270
+ "vision_model.vision_model.encoder.layers.0.layer_norm1.bias": "model-00001-of-00002.safetensors",
271
+ "vision_model.vision_model.encoder.layers.0.layer_norm1.weight": "model-00001-of-00002.safetensors",
272
+ "vision_model.vision_model.encoder.layers.0.layer_norm2.bias": "model-00001-of-00002.safetensors",
273
+ "vision_model.vision_model.encoder.layers.0.layer_norm2.weight": "model-00001-of-00002.safetensors",
274
+ "vision_model.vision_model.encoder.layers.0.mlp.fc1.bias": "model-00001-of-00002.safetensors",
275
+ "vision_model.vision_model.encoder.layers.0.mlp.fc1.weight": "model-00001-of-00002.safetensors",
276
+ "vision_model.vision_model.encoder.layers.0.mlp.fc2.bias": "model-00001-of-00002.safetensors",
277
+ "vision_model.vision_model.encoder.layers.0.mlp.fc2.weight": "model-00001-of-00002.safetensors",
278
+ "vision_model.vision_model.encoder.layers.0.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
279
+ "vision_model.vision_model.encoder.layers.0.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
280
+ "vision_model.vision_model.encoder.layers.0.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
281
+ "vision_model.vision_model.encoder.layers.0.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
282
+ "vision_model.vision_model.encoder.layers.0.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
283
+ "vision_model.vision_model.encoder.layers.0.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
284
+ "vision_model.vision_model.encoder.layers.0.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
285
+ "vision_model.vision_model.encoder.layers.0.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
286
+ "vision_model.vision_model.encoder.layers.1.layer_norm1.bias": "model-00001-of-00002.safetensors",
287
+ "vision_model.vision_model.encoder.layers.1.layer_norm1.weight": "model-00001-of-00002.safetensors",
288
+ "vision_model.vision_model.encoder.layers.1.layer_norm2.bias": "model-00001-of-00002.safetensors",
289
+ "vision_model.vision_model.encoder.layers.1.layer_norm2.weight": "model-00001-of-00002.safetensors",
290
+ "vision_model.vision_model.encoder.layers.1.mlp.fc1.bias": "model-00001-of-00002.safetensors",
291
+ "vision_model.vision_model.encoder.layers.1.mlp.fc1.weight": "model-00001-of-00002.safetensors",
292
+ "vision_model.vision_model.encoder.layers.1.mlp.fc2.bias": "model-00001-of-00002.safetensors",
293
+ "vision_model.vision_model.encoder.layers.1.mlp.fc2.weight": "model-00001-of-00002.safetensors",
294
+ "vision_model.vision_model.encoder.layers.1.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
295
+ "vision_model.vision_model.encoder.layers.1.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
296
+ "vision_model.vision_model.encoder.layers.1.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
297
+ "vision_model.vision_model.encoder.layers.1.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
298
+ "vision_model.vision_model.encoder.layers.1.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
299
+ "vision_model.vision_model.encoder.layers.1.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
300
+ "vision_model.vision_model.encoder.layers.1.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
301
+ "vision_model.vision_model.encoder.layers.1.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
302
+ "vision_model.vision_model.encoder.layers.10.layer_norm1.bias": "model-00001-of-00002.safetensors",
303
+ "vision_model.vision_model.encoder.layers.10.layer_norm1.weight": "model-00001-of-00002.safetensors",
304
+ "vision_model.vision_model.encoder.layers.10.layer_norm2.bias": "model-00001-of-00002.safetensors",
305
+ "vision_model.vision_model.encoder.layers.10.layer_norm2.weight": "model-00001-of-00002.safetensors",
306
+ "vision_model.vision_model.encoder.layers.10.mlp.fc1.bias": "model-00001-of-00002.safetensors",
307
+ "vision_model.vision_model.encoder.layers.10.mlp.fc1.weight": "model-00001-of-00002.safetensors",
308
+ "vision_model.vision_model.encoder.layers.10.mlp.fc2.bias": "model-00001-of-00002.safetensors",
309
+ "vision_model.vision_model.encoder.layers.10.mlp.fc2.weight": "model-00001-of-00002.safetensors",
310
+ "vision_model.vision_model.encoder.layers.10.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
311
+ "vision_model.vision_model.encoder.layers.10.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
312
+ "vision_model.vision_model.encoder.layers.10.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
313
+ "vision_model.vision_model.encoder.layers.10.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
314
+ "vision_model.vision_model.encoder.layers.10.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
315
+ "vision_model.vision_model.encoder.layers.10.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
316
+ "vision_model.vision_model.encoder.layers.10.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
317
+ "vision_model.vision_model.encoder.layers.10.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
318
+ "vision_model.vision_model.encoder.layers.11.layer_norm1.bias": "model-00001-of-00002.safetensors",
319
+ "vision_model.vision_model.encoder.layers.11.layer_norm1.weight": "model-00001-of-00002.safetensors",
320
+ "vision_model.vision_model.encoder.layers.11.layer_norm2.bias": "model-00001-of-00002.safetensors",
321
+ "vision_model.vision_model.encoder.layers.11.layer_norm2.weight": "model-00001-of-00002.safetensors",
322
+ "vision_model.vision_model.encoder.layers.11.mlp.fc1.bias": "model-00001-of-00002.safetensors",
323
+ "vision_model.vision_model.encoder.layers.11.mlp.fc1.weight": "model-00001-of-00002.safetensors",
324
+ "vision_model.vision_model.encoder.layers.11.mlp.fc2.bias": "model-00001-of-00002.safetensors",
325
+ "vision_model.vision_model.encoder.layers.11.mlp.fc2.weight": "model-00001-of-00002.safetensors",
326
+ "vision_model.vision_model.encoder.layers.11.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
327
+ "vision_model.vision_model.encoder.layers.11.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
328
+ "vision_model.vision_model.encoder.layers.11.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
329
+ "vision_model.vision_model.encoder.layers.11.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
330
+ "vision_model.vision_model.encoder.layers.11.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
331
+ "vision_model.vision_model.encoder.layers.11.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
332
+ "vision_model.vision_model.encoder.layers.11.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
333
+ "vision_model.vision_model.encoder.layers.11.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
334
+ "vision_model.vision_model.encoder.layers.12.layer_norm1.bias": "model-00001-of-00002.safetensors",
335
+ "vision_model.vision_model.encoder.layers.12.layer_norm1.weight": "model-00001-of-00002.safetensors",
336
+ "vision_model.vision_model.encoder.layers.12.layer_norm2.bias": "model-00001-of-00002.safetensors",
337
+ "vision_model.vision_model.encoder.layers.12.layer_norm2.weight": "model-00001-of-00002.safetensors",
338
+ "vision_model.vision_model.encoder.layers.12.mlp.fc1.bias": "model-00001-of-00002.safetensors",
339
+ "vision_model.vision_model.encoder.layers.12.mlp.fc1.weight": "model-00001-of-00002.safetensors",
340
+ "vision_model.vision_model.encoder.layers.12.mlp.fc2.bias": "model-00001-of-00002.safetensors",
341
+ "vision_model.vision_model.encoder.layers.12.mlp.fc2.weight": "model-00001-of-00002.safetensors",
342
+ "vision_model.vision_model.encoder.layers.12.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
343
+ "vision_model.vision_model.encoder.layers.12.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
344
+ "vision_model.vision_model.encoder.layers.12.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
345
+ "vision_model.vision_model.encoder.layers.12.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
346
+ "vision_model.vision_model.encoder.layers.12.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
347
+ "vision_model.vision_model.encoder.layers.12.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
348
+ "vision_model.vision_model.encoder.layers.12.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
349
+ "vision_model.vision_model.encoder.layers.12.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
350
+ "vision_model.vision_model.encoder.layers.13.layer_norm1.bias": "model-00001-of-00002.safetensors",
351
+ "vision_model.vision_model.encoder.layers.13.layer_norm1.weight": "model-00001-of-00002.safetensors",
352
+ "vision_model.vision_model.encoder.layers.13.layer_norm2.bias": "model-00001-of-00002.safetensors",
353
+ "vision_model.vision_model.encoder.layers.13.layer_norm2.weight": "model-00001-of-00002.safetensors",
354
+ "vision_model.vision_model.encoder.layers.13.mlp.fc1.bias": "model-00001-of-00002.safetensors",
355
+ "vision_model.vision_model.encoder.layers.13.mlp.fc1.weight": "model-00001-of-00002.safetensors",
356
+ "vision_model.vision_model.encoder.layers.13.mlp.fc2.bias": "model-00001-of-00002.safetensors",
357
+ "vision_model.vision_model.encoder.layers.13.mlp.fc2.weight": "model-00001-of-00002.safetensors",
358
+ "vision_model.vision_model.encoder.layers.13.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
359
+ "vision_model.vision_model.encoder.layers.13.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
360
+ "vision_model.vision_model.encoder.layers.13.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
361
+ "vision_model.vision_model.encoder.layers.13.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
362
+ "vision_model.vision_model.encoder.layers.13.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
363
+ "vision_model.vision_model.encoder.layers.13.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
364
+ "vision_model.vision_model.encoder.layers.13.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
365
+ "vision_model.vision_model.encoder.layers.13.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
366
+ "vision_model.vision_model.encoder.layers.14.layer_norm1.bias": "model-00001-of-00002.safetensors",
367
+ "vision_model.vision_model.encoder.layers.14.layer_norm1.weight": "model-00001-of-00002.safetensors",
368
+ "vision_model.vision_model.encoder.layers.14.layer_norm2.bias": "model-00001-of-00002.safetensors",
369
+ "vision_model.vision_model.encoder.layers.14.layer_norm2.weight": "model-00001-of-00002.safetensors",
370
+ "vision_model.vision_model.encoder.layers.14.mlp.fc1.bias": "model-00001-of-00002.safetensors",
371
+ "vision_model.vision_model.encoder.layers.14.mlp.fc1.weight": "model-00001-of-00002.safetensors",
372
+ "vision_model.vision_model.encoder.layers.14.mlp.fc2.bias": "model-00001-of-00002.safetensors",
373
+ "vision_model.vision_model.encoder.layers.14.mlp.fc2.weight": "model-00001-of-00002.safetensors",
374
+ "vision_model.vision_model.encoder.layers.14.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
375
+ "vision_model.vision_model.encoder.layers.14.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
376
+ "vision_model.vision_model.encoder.layers.14.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
377
+ "vision_model.vision_model.encoder.layers.14.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
378
+ "vision_model.vision_model.encoder.layers.14.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
379
+ "vision_model.vision_model.encoder.layers.14.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
380
+ "vision_model.vision_model.encoder.layers.14.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
381
+ "vision_model.vision_model.encoder.layers.14.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
382
+ "vision_model.vision_model.encoder.layers.15.layer_norm1.bias": "model-00001-of-00002.safetensors",
383
+ "vision_model.vision_model.encoder.layers.15.layer_norm1.weight": "model-00001-of-00002.safetensors",
384
+ "vision_model.vision_model.encoder.layers.15.layer_norm2.bias": "model-00001-of-00002.safetensors",
385
+ "vision_model.vision_model.encoder.layers.15.layer_norm2.weight": "model-00001-of-00002.safetensors",
386
+ "vision_model.vision_model.encoder.layers.15.mlp.fc1.bias": "model-00001-of-00002.safetensors",
387
+ "vision_model.vision_model.encoder.layers.15.mlp.fc1.weight": "model-00001-of-00002.safetensors",
388
+ "vision_model.vision_model.encoder.layers.15.mlp.fc2.bias": "model-00001-of-00002.safetensors",
389
+ "vision_model.vision_model.encoder.layers.15.mlp.fc2.weight": "model-00001-of-00002.safetensors",
390
+ "vision_model.vision_model.encoder.layers.15.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
391
+ "vision_model.vision_model.encoder.layers.15.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
392
+ "vision_model.vision_model.encoder.layers.15.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
393
+ "vision_model.vision_model.encoder.layers.15.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
394
+ "vision_model.vision_model.encoder.layers.15.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
395
+ "vision_model.vision_model.encoder.layers.15.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
396
+ "vision_model.vision_model.encoder.layers.15.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
397
+ "vision_model.vision_model.encoder.layers.15.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
398
+ "vision_model.vision_model.encoder.layers.16.layer_norm1.bias": "model-00001-of-00002.safetensors",
399
+ "vision_model.vision_model.encoder.layers.16.layer_norm1.weight": "model-00001-of-00002.safetensors",
400
+ "vision_model.vision_model.encoder.layers.16.layer_norm2.bias": "model-00001-of-00002.safetensors",
401
+ "vision_model.vision_model.encoder.layers.16.layer_norm2.weight": "model-00001-of-00002.safetensors",
402
+ "vision_model.vision_model.encoder.layers.16.mlp.fc1.bias": "model-00001-of-00002.safetensors",
403
+ "vision_model.vision_model.encoder.layers.16.mlp.fc1.weight": "model-00001-of-00002.safetensors",
404
+ "vision_model.vision_model.encoder.layers.16.mlp.fc2.bias": "model-00001-of-00002.safetensors",
405
+ "vision_model.vision_model.encoder.layers.16.mlp.fc2.weight": "model-00001-of-00002.safetensors",
406
+ "vision_model.vision_model.encoder.layers.16.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
407
+ "vision_model.vision_model.encoder.layers.16.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
408
+ "vision_model.vision_model.encoder.layers.16.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
409
+ "vision_model.vision_model.encoder.layers.16.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
410
+ "vision_model.vision_model.encoder.layers.16.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
411
+ "vision_model.vision_model.encoder.layers.16.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
412
+ "vision_model.vision_model.encoder.layers.16.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
413
+ "vision_model.vision_model.encoder.layers.16.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
414
+ "vision_model.vision_model.encoder.layers.17.layer_norm1.bias": "model-00001-of-00002.safetensors",
415
+ "vision_model.vision_model.encoder.layers.17.layer_norm1.weight": "model-00001-of-00002.safetensors",
416
+ "vision_model.vision_model.encoder.layers.17.layer_norm2.bias": "model-00001-of-00002.safetensors",
417
+ "vision_model.vision_model.encoder.layers.17.layer_norm2.weight": "model-00001-of-00002.safetensors",
418
+ "vision_model.vision_model.encoder.layers.17.mlp.fc1.bias": "model-00001-of-00002.safetensors",
419
+ "vision_model.vision_model.encoder.layers.17.mlp.fc1.weight": "model-00001-of-00002.safetensors",
420
+ "vision_model.vision_model.encoder.layers.17.mlp.fc2.bias": "model-00001-of-00002.safetensors",
421
+ "vision_model.vision_model.encoder.layers.17.mlp.fc2.weight": "model-00001-of-00002.safetensors",
422
+ "vision_model.vision_model.encoder.layers.17.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
423
+ "vision_model.vision_model.encoder.layers.17.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
424
+ "vision_model.vision_model.encoder.layers.17.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
425
+ "vision_model.vision_model.encoder.layers.17.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
426
+ "vision_model.vision_model.encoder.layers.17.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
427
+ "vision_model.vision_model.encoder.layers.17.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
428
+ "vision_model.vision_model.encoder.layers.17.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
429
+ "vision_model.vision_model.encoder.layers.17.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
430
+ "vision_model.vision_model.encoder.layers.18.layer_norm1.bias": "model-00001-of-00002.safetensors",
431
+ "vision_model.vision_model.encoder.layers.18.layer_norm1.weight": "model-00001-of-00002.safetensors",
432
+ "vision_model.vision_model.encoder.layers.18.layer_norm2.bias": "model-00001-of-00002.safetensors",
433
+ "vision_model.vision_model.encoder.layers.18.layer_norm2.weight": "model-00001-of-00002.safetensors",
434
+ "vision_model.vision_model.encoder.layers.18.mlp.fc1.bias": "model-00001-of-00002.safetensors",
435
+ "vision_model.vision_model.encoder.layers.18.mlp.fc1.weight": "model-00001-of-00002.safetensors",
436
+ "vision_model.vision_model.encoder.layers.18.mlp.fc2.bias": "model-00001-of-00002.safetensors",
437
+ "vision_model.vision_model.encoder.layers.18.mlp.fc2.weight": "model-00001-of-00002.safetensors",
438
+ "vision_model.vision_model.encoder.layers.18.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
439
+ "vision_model.vision_model.encoder.layers.18.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
440
+ "vision_model.vision_model.encoder.layers.18.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
441
+ "vision_model.vision_model.encoder.layers.18.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
442
+ "vision_model.vision_model.encoder.layers.18.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
443
+ "vision_model.vision_model.encoder.layers.18.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
444
+ "vision_model.vision_model.encoder.layers.18.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
445
+ "vision_model.vision_model.encoder.layers.18.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
446
+ "vision_model.vision_model.encoder.layers.19.layer_norm1.bias": "model-00001-of-00002.safetensors",
447
+ "vision_model.vision_model.encoder.layers.19.layer_norm1.weight": "model-00001-of-00002.safetensors",
448
+ "vision_model.vision_model.encoder.layers.19.layer_norm2.bias": "model-00001-of-00002.safetensors",
449
+ "vision_model.vision_model.encoder.layers.19.layer_norm2.weight": "model-00001-of-00002.safetensors",
450
+ "vision_model.vision_model.encoder.layers.19.mlp.fc1.bias": "model-00001-of-00002.safetensors",
451
+ "vision_model.vision_model.encoder.layers.19.mlp.fc1.weight": "model-00001-of-00002.safetensors",
452
+ "vision_model.vision_model.encoder.layers.19.mlp.fc2.bias": "model-00001-of-00002.safetensors",
453
+ "vision_model.vision_model.encoder.layers.19.mlp.fc2.weight": "model-00001-of-00002.safetensors",
454
+ "vision_model.vision_model.encoder.layers.19.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
455
+ "vision_model.vision_model.encoder.layers.19.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
456
+ "vision_model.vision_model.encoder.layers.19.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
457
+ "vision_model.vision_model.encoder.layers.19.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
458
+ "vision_model.vision_model.encoder.layers.19.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
459
+ "vision_model.vision_model.encoder.layers.19.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
460
+ "vision_model.vision_model.encoder.layers.19.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
461
+ "vision_model.vision_model.encoder.layers.19.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
462
+ "vision_model.vision_model.encoder.layers.2.layer_norm1.bias": "model-00001-of-00002.safetensors",
463
+ "vision_model.vision_model.encoder.layers.2.layer_norm1.weight": "model-00001-of-00002.safetensors",
464
+ "vision_model.vision_model.encoder.layers.2.layer_norm2.bias": "model-00001-of-00002.safetensors",
465
+ "vision_model.vision_model.encoder.layers.2.layer_norm2.weight": "model-00001-of-00002.safetensors",
466
+ "vision_model.vision_model.encoder.layers.2.mlp.fc1.bias": "model-00001-of-00002.safetensors",
467
+ "vision_model.vision_model.encoder.layers.2.mlp.fc1.weight": "model-00001-of-00002.safetensors",
468
+ "vision_model.vision_model.encoder.layers.2.mlp.fc2.bias": "model-00001-of-00002.safetensors",
469
+ "vision_model.vision_model.encoder.layers.2.mlp.fc2.weight": "model-00001-of-00002.safetensors",
470
+ "vision_model.vision_model.encoder.layers.2.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
471
+ "vision_model.vision_model.encoder.layers.2.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
472
+ "vision_model.vision_model.encoder.layers.2.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
473
+ "vision_model.vision_model.encoder.layers.2.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
474
+ "vision_model.vision_model.encoder.layers.2.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
475
+ "vision_model.vision_model.encoder.layers.2.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
476
+ "vision_model.vision_model.encoder.layers.2.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
477
+ "vision_model.vision_model.encoder.layers.2.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
478
+ "vision_model.vision_model.encoder.layers.20.layer_norm1.bias": "model-00001-of-00002.safetensors",
479
+ "vision_model.vision_model.encoder.layers.20.layer_norm1.weight": "model-00001-of-00002.safetensors",
480
+ "vision_model.vision_model.encoder.layers.20.layer_norm2.bias": "model-00001-of-00002.safetensors",
481
+ "vision_model.vision_model.encoder.layers.20.layer_norm2.weight": "model-00001-of-00002.safetensors",
482
+ "vision_model.vision_model.encoder.layers.20.mlp.fc1.bias": "model-00001-of-00002.safetensors",
483
+ "vision_model.vision_model.encoder.layers.20.mlp.fc1.weight": "model-00001-of-00002.safetensors",
484
+ "vision_model.vision_model.encoder.layers.20.mlp.fc2.bias": "model-00001-of-00002.safetensors",
485
+ "vision_model.vision_model.encoder.layers.20.mlp.fc2.weight": "model-00001-of-00002.safetensors",
486
+ "vision_model.vision_model.encoder.layers.20.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
487
+ "vision_model.vision_model.encoder.layers.20.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
488
+ "vision_model.vision_model.encoder.layers.20.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
489
+ "vision_model.vision_model.encoder.layers.20.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
490
+ "vision_model.vision_model.encoder.layers.20.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
491
+ "vision_model.vision_model.encoder.layers.20.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
492
+ "vision_model.vision_model.encoder.layers.20.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
493
+ "vision_model.vision_model.encoder.layers.20.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
494
+ "vision_model.vision_model.encoder.layers.21.layer_norm1.bias": "model-00001-of-00002.safetensors",
495
+ "vision_model.vision_model.encoder.layers.21.layer_norm1.weight": "model-00001-of-00002.safetensors",
496
+ "vision_model.vision_model.encoder.layers.21.layer_norm2.bias": "model-00001-of-00002.safetensors",
497
+ "vision_model.vision_model.encoder.layers.21.layer_norm2.weight": "model-00001-of-00002.safetensors",
498
+ "vision_model.vision_model.encoder.layers.21.mlp.fc1.bias": "model-00001-of-00002.safetensors",
499
+ "vision_model.vision_model.encoder.layers.21.mlp.fc1.weight": "model-00001-of-00002.safetensors",
500
+ "vision_model.vision_model.encoder.layers.21.mlp.fc2.bias": "model-00001-of-00002.safetensors",
501
+ "vision_model.vision_model.encoder.layers.21.mlp.fc2.weight": "model-00001-of-00002.safetensors",
502
+ "vision_model.vision_model.encoder.layers.21.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
503
+ "vision_model.vision_model.encoder.layers.21.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
504
+ "vision_model.vision_model.encoder.layers.21.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
505
+ "vision_model.vision_model.encoder.layers.21.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
506
+ "vision_model.vision_model.encoder.layers.21.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
507
+ "vision_model.vision_model.encoder.layers.21.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
508
+ "vision_model.vision_model.encoder.layers.21.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
509
+ "vision_model.vision_model.encoder.layers.21.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
510
+ "vision_model.vision_model.encoder.layers.22.layer_norm1.bias": "model-00001-of-00002.safetensors",
511
+ "vision_model.vision_model.encoder.layers.22.layer_norm1.weight": "model-00001-of-00002.safetensors",
512
+ "vision_model.vision_model.encoder.layers.22.layer_norm2.bias": "model-00001-of-00002.safetensors",
513
+ "vision_model.vision_model.encoder.layers.22.layer_norm2.weight": "model-00001-of-00002.safetensors",
514
+ "vision_model.vision_model.encoder.layers.22.mlp.fc1.bias": "model-00001-of-00002.safetensors",
515
+ "vision_model.vision_model.encoder.layers.22.mlp.fc1.weight": "model-00001-of-00002.safetensors",
516
+ "vision_model.vision_model.encoder.layers.22.mlp.fc2.bias": "model-00001-of-00002.safetensors",
517
+ "vision_model.vision_model.encoder.layers.22.mlp.fc2.weight": "model-00001-of-00002.safetensors",
518
+ "vision_model.vision_model.encoder.layers.22.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
519
+ "vision_model.vision_model.encoder.layers.22.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
520
+ "vision_model.vision_model.encoder.layers.22.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
521
+ "vision_model.vision_model.encoder.layers.22.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
522
+ "vision_model.vision_model.encoder.layers.22.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
523
+ "vision_model.vision_model.encoder.layers.22.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
524
+ "vision_model.vision_model.encoder.layers.22.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
525
+ "vision_model.vision_model.encoder.layers.22.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
526
+ "vision_model.vision_model.encoder.layers.23.layer_norm1.bias": "model-00001-of-00002.safetensors",
527
+ "vision_model.vision_model.encoder.layers.23.layer_norm1.weight": "model-00001-of-00002.safetensors",
528
+ "vision_model.vision_model.encoder.layers.23.layer_norm2.bias": "model-00001-of-00002.safetensors",
529
+ "vision_model.vision_model.encoder.layers.23.layer_norm2.weight": "model-00001-of-00002.safetensors",
530
+ "vision_model.vision_model.encoder.layers.23.mlp.fc1.bias": "model-00001-of-00002.safetensors",
531
+ "vision_model.vision_model.encoder.layers.23.mlp.fc1.weight": "model-00001-of-00002.safetensors",
532
+ "vision_model.vision_model.encoder.layers.23.mlp.fc2.bias": "model-00001-of-00002.safetensors",
533
+ "vision_model.vision_model.encoder.layers.23.mlp.fc2.weight": "model-00001-of-00002.safetensors",
534
+ "vision_model.vision_model.encoder.layers.23.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
535
+ "vision_model.vision_model.encoder.layers.23.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
536
+ "vision_model.vision_model.encoder.layers.23.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
537
+ "vision_model.vision_model.encoder.layers.23.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
538
+ "vision_model.vision_model.encoder.layers.23.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
539
+ "vision_model.vision_model.encoder.layers.23.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
540
+ "vision_model.vision_model.encoder.layers.23.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
541
+ "vision_model.vision_model.encoder.layers.23.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
542
+ "vision_model.vision_model.encoder.layers.24.layer_norm1.bias": "model-00001-of-00002.safetensors",
543
+ "vision_model.vision_model.encoder.layers.24.layer_norm1.weight": "model-00001-of-00002.safetensors",
544
+ "vision_model.vision_model.encoder.layers.24.layer_norm2.bias": "model-00001-of-00002.safetensors",
545
+ "vision_model.vision_model.encoder.layers.24.layer_norm2.weight": "model-00001-of-00002.safetensors",
546
+ "vision_model.vision_model.encoder.layers.24.mlp.fc1.bias": "model-00001-of-00002.safetensors",
547
+ "vision_model.vision_model.encoder.layers.24.mlp.fc1.weight": "model-00001-of-00002.safetensors",
548
+ "vision_model.vision_model.encoder.layers.24.mlp.fc2.bias": "model-00001-of-00002.safetensors",
549
+ "vision_model.vision_model.encoder.layers.24.mlp.fc2.weight": "model-00001-of-00002.safetensors",
550
+ "vision_model.vision_model.encoder.layers.24.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
551
+ "vision_model.vision_model.encoder.layers.24.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
552
+ "vision_model.vision_model.encoder.layers.24.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
553
+ "vision_model.vision_model.encoder.layers.24.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
554
+ "vision_model.vision_model.encoder.layers.24.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
555
+ "vision_model.vision_model.encoder.layers.24.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
556
+ "vision_model.vision_model.encoder.layers.24.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
557
+ "vision_model.vision_model.encoder.layers.24.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
558
+ "vision_model.vision_model.encoder.layers.25.layer_norm1.bias": "model-00001-of-00002.safetensors",
559
+ "vision_model.vision_model.encoder.layers.25.layer_norm1.weight": "model-00001-of-00002.safetensors",
560
+ "vision_model.vision_model.encoder.layers.25.layer_norm2.bias": "model-00001-of-00002.safetensors",
561
+ "vision_model.vision_model.encoder.layers.25.layer_norm2.weight": "model-00001-of-00002.safetensors",
562
+ "vision_model.vision_model.encoder.layers.25.mlp.fc1.bias": "model-00001-of-00002.safetensors",
563
+ "vision_model.vision_model.encoder.layers.25.mlp.fc1.weight": "model-00001-of-00002.safetensors",
564
+ "vision_model.vision_model.encoder.layers.25.mlp.fc2.bias": "model-00001-of-00002.safetensors",
565
+ "vision_model.vision_model.encoder.layers.25.mlp.fc2.weight": "model-00001-of-00002.safetensors",
566
+ "vision_model.vision_model.encoder.layers.25.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
567
+ "vision_model.vision_model.encoder.layers.25.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
568
+ "vision_model.vision_model.encoder.layers.25.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
569
+ "vision_model.vision_model.encoder.layers.25.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
570
+ "vision_model.vision_model.encoder.layers.25.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
571
+ "vision_model.vision_model.encoder.layers.25.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
572
+ "vision_model.vision_model.encoder.layers.25.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
573
+ "vision_model.vision_model.encoder.layers.25.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
574
+ "vision_model.vision_model.encoder.layers.26.layer_norm1.bias": "model-00001-of-00002.safetensors",
575
+ "vision_model.vision_model.encoder.layers.26.layer_norm1.weight": "model-00001-of-00002.safetensors",
576
+ "vision_model.vision_model.encoder.layers.26.layer_norm2.bias": "model-00001-of-00002.safetensors",
577
+ "vision_model.vision_model.encoder.layers.26.layer_norm2.weight": "model-00001-of-00002.safetensors",
578
+ "vision_model.vision_model.encoder.layers.26.mlp.fc1.bias": "model-00001-of-00002.safetensors",
579
+ "vision_model.vision_model.encoder.layers.26.mlp.fc1.weight": "model-00001-of-00002.safetensors",
580
+ "vision_model.vision_model.encoder.layers.26.mlp.fc2.bias": "model-00001-of-00002.safetensors",
581
+ "vision_model.vision_model.encoder.layers.26.mlp.fc2.weight": "model-00001-of-00002.safetensors",
582
+ "vision_model.vision_model.encoder.layers.26.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
583
+ "vision_model.vision_model.encoder.layers.26.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
584
+ "vision_model.vision_model.encoder.layers.26.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
585
+ "vision_model.vision_model.encoder.layers.26.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
586
+ "vision_model.vision_model.encoder.layers.26.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
587
+ "vision_model.vision_model.encoder.layers.26.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
588
+ "vision_model.vision_model.encoder.layers.26.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
589
+ "vision_model.vision_model.encoder.layers.26.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
590
+ "vision_model.vision_model.encoder.layers.3.layer_norm1.bias": "model-00001-of-00002.safetensors",
591
+ "vision_model.vision_model.encoder.layers.3.layer_norm1.weight": "model-00001-of-00002.safetensors",
592
+ "vision_model.vision_model.encoder.layers.3.layer_norm2.bias": "model-00001-of-00002.safetensors",
593
+ "vision_model.vision_model.encoder.layers.3.layer_norm2.weight": "model-00001-of-00002.safetensors",
594
+ "vision_model.vision_model.encoder.layers.3.mlp.fc1.bias": "model-00001-of-00002.safetensors",
595
+ "vision_model.vision_model.encoder.layers.3.mlp.fc1.weight": "model-00001-of-00002.safetensors",
596
+ "vision_model.vision_model.encoder.layers.3.mlp.fc2.bias": "model-00001-of-00002.safetensors",
597
+ "vision_model.vision_model.encoder.layers.3.mlp.fc2.weight": "model-00001-of-00002.safetensors",
598
+ "vision_model.vision_model.encoder.layers.3.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
599
+ "vision_model.vision_model.encoder.layers.3.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
600
+ "vision_model.vision_model.encoder.layers.3.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
601
+ "vision_model.vision_model.encoder.layers.3.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
602
+ "vision_model.vision_model.encoder.layers.3.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
603
+ "vision_model.vision_model.encoder.layers.3.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
604
+ "vision_model.vision_model.encoder.layers.3.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
605
+ "vision_model.vision_model.encoder.layers.3.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
606
+ "vision_model.vision_model.encoder.layers.4.layer_norm1.bias": "model-00001-of-00002.safetensors",
607
+ "vision_model.vision_model.encoder.layers.4.layer_norm1.weight": "model-00001-of-00002.safetensors",
608
+ "vision_model.vision_model.encoder.layers.4.layer_norm2.bias": "model-00001-of-00002.safetensors",
609
+ "vision_model.vision_model.encoder.layers.4.layer_norm2.weight": "model-00001-of-00002.safetensors",
610
+ "vision_model.vision_model.encoder.layers.4.mlp.fc1.bias": "model-00001-of-00002.safetensors",
611
+ "vision_model.vision_model.encoder.layers.4.mlp.fc1.weight": "model-00001-of-00002.safetensors",
612
+ "vision_model.vision_model.encoder.layers.4.mlp.fc2.bias": "model-00001-of-00002.safetensors",
613
+ "vision_model.vision_model.encoder.layers.4.mlp.fc2.weight": "model-00001-of-00002.safetensors",
614
+ "vision_model.vision_model.encoder.layers.4.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
615
+ "vision_model.vision_model.encoder.layers.4.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
616
+ "vision_model.vision_model.encoder.layers.4.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
617
+ "vision_model.vision_model.encoder.layers.4.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
618
+ "vision_model.vision_model.encoder.layers.4.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
619
+ "vision_model.vision_model.encoder.layers.4.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
620
+ "vision_model.vision_model.encoder.layers.4.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
621
+ "vision_model.vision_model.encoder.layers.4.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
622
+ "vision_model.vision_model.encoder.layers.5.layer_norm1.bias": "model-00001-of-00002.safetensors",
623
+ "vision_model.vision_model.encoder.layers.5.layer_norm1.weight": "model-00001-of-00002.safetensors",
624
+ "vision_model.vision_model.encoder.layers.5.layer_norm2.bias": "model-00001-of-00002.safetensors",
625
+ "vision_model.vision_model.encoder.layers.5.layer_norm2.weight": "model-00001-of-00002.safetensors",
626
+ "vision_model.vision_model.encoder.layers.5.mlp.fc1.bias": "model-00001-of-00002.safetensors",
627
+ "vision_model.vision_model.encoder.layers.5.mlp.fc1.weight": "model-00001-of-00002.safetensors",
628
+ "vision_model.vision_model.encoder.layers.5.mlp.fc2.bias": "model-00001-of-00002.safetensors",
629
+ "vision_model.vision_model.encoder.layers.5.mlp.fc2.weight": "model-00001-of-00002.safetensors",
630
+ "vision_model.vision_model.encoder.layers.5.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
631
+ "vision_model.vision_model.encoder.layers.5.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
632
+ "vision_model.vision_model.encoder.layers.5.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
633
+ "vision_model.vision_model.encoder.layers.5.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
634
+ "vision_model.vision_model.encoder.layers.5.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
635
+ "vision_model.vision_model.encoder.layers.5.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
636
+ "vision_model.vision_model.encoder.layers.5.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
637
+ "vision_model.vision_model.encoder.layers.5.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
638
+ "vision_model.vision_model.encoder.layers.6.layer_norm1.bias": "model-00001-of-00002.safetensors",
639
+ "vision_model.vision_model.encoder.layers.6.layer_norm1.weight": "model-00001-of-00002.safetensors",
640
+ "vision_model.vision_model.encoder.layers.6.layer_norm2.bias": "model-00001-of-00002.safetensors",
641
+ "vision_model.vision_model.encoder.layers.6.layer_norm2.weight": "model-00001-of-00002.safetensors",
642
+ "vision_model.vision_model.encoder.layers.6.mlp.fc1.bias": "model-00001-of-00002.safetensors",
643
+ "vision_model.vision_model.encoder.layers.6.mlp.fc1.weight": "model-00001-of-00002.safetensors",
644
+ "vision_model.vision_model.encoder.layers.6.mlp.fc2.bias": "model-00001-of-00002.safetensors",
645
+ "vision_model.vision_model.encoder.layers.6.mlp.fc2.weight": "model-00001-of-00002.safetensors",
646
+ "vision_model.vision_model.encoder.layers.6.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
647
+ "vision_model.vision_model.encoder.layers.6.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
648
+ "vision_model.vision_model.encoder.layers.6.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
649
+ "vision_model.vision_model.encoder.layers.6.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
650
+ "vision_model.vision_model.encoder.layers.6.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
651
+ "vision_model.vision_model.encoder.layers.6.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
652
+ "vision_model.vision_model.encoder.layers.6.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
653
+ "vision_model.vision_model.encoder.layers.6.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
654
+ "vision_model.vision_model.encoder.layers.7.layer_norm1.bias": "model-00001-of-00002.safetensors",
655
+ "vision_model.vision_model.encoder.layers.7.layer_norm1.weight": "model-00001-of-00002.safetensors",
656
+ "vision_model.vision_model.encoder.layers.7.layer_norm2.bias": "model-00001-of-00002.safetensors",
657
+ "vision_model.vision_model.encoder.layers.7.layer_norm2.weight": "model-00001-of-00002.safetensors",
658
+ "vision_model.vision_model.encoder.layers.7.mlp.fc1.bias": "model-00001-of-00002.safetensors",
659
+ "vision_model.vision_model.encoder.layers.7.mlp.fc1.weight": "model-00001-of-00002.safetensors",
660
+ "vision_model.vision_model.encoder.layers.7.mlp.fc2.bias": "model-00001-of-00002.safetensors",
661
+ "vision_model.vision_model.encoder.layers.7.mlp.fc2.weight": "model-00001-of-00002.safetensors",
662
+ "vision_model.vision_model.encoder.layers.7.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
663
+ "vision_model.vision_model.encoder.layers.7.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
664
+ "vision_model.vision_model.encoder.layers.7.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
665
+ "vision_model.vision_model.encoder.layers.7.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
666
+ "vision_model.vision_model.encoder.layers.7.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
667
+ "vision_model.vision_model.encoder.layers.7.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
668
+ "vision_model.vision_model.encoder.layers.7.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
669
+ "vision_model.vision_model.encoder.layers.7.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
670
+ "vision_model.vision_model.encoder.layers.8.layer_norm1.bias": "model-00001-of-00002.safetensors",
671
+ "vision_model.vision_model.encoder.layers.8.layer_norm1.weight": "model-00001-of-00002.safetensors",
672
+ "vision_model.vision_model.encoder.layers.8.layer_norm2.bias": "model-00001-of-00002.safetensors",
673
+ "vision_model.vision_model.encoder.layers.8.layer_norm2.weight": "model-00001-of-00002.safetensors",
674
+ "vision_model.vision_model.encoder.layers.8.mlp.fc1.bias": "model-00001-of-00002.safetensors",
675
+ "vision_model.vision_model.encoder.layers.8.mlp.fc1.weight": "model-00001-of-00002.safetensors",
676
+ "vision_model.vision_model.encoder.layers.8.mlp.fc2.bias": "model-00001-of-00002.safetensors",
677
+ "vision_model.vision_model.encoder.layers.8.mlp.fc2.weight": "model-00001-of-00002.safetensors",
678
+ "vision_model.vision_model.encoder.layers.8.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
679
+ "vision_model.vision_model.encoder.layers.8.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
680
+ "vision_model.vision_model.encoder.layers.8.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
681
+ "vision_model.vision_model.encoder.layers.8.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
682
+ "vision_model.vision_model.encoder.layers.8.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
683
+ "vision_model.vision_model.encoder.layers.8.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
684
+ "vision_model.vision_model.encoder.layers.8.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
685
+ "vision_model.vision_model.encoder.layers.8.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
686
+ "vision_model.vision_model.encoder.layers.9.layer_norm1.bias": "model-00001-of-00002.safetensors",
687
+ "vision_model.vision_model.encoder.layers.9.layer_norm1.weight": "model-00001-of-00002.safetensors",
688
+ "vision_model.vision_model.encoder.layers.9.layer_norm2.bias": "model-00001-of-00002.safetensors",
689
+ "vision_model.vision_model.encoder.layers.9.layer_norm2.weight": "model-00001-of-00002.safetensors",
690
+ "vision_model.vision_model.encoder.layers.9.mlp.fc1.bias": "model-00001-of-00002.safetensors",
691
+ "vision_model.vision_model.encoder.layers.9.mlp.fc1.weight": "model-00001-of-00002.safetensors",
692
+ "vision_model.vision_model.encoder.layers.9.mlp.fc2.bias": "model-00001-of-00002.safetensors",
693
+ "vision_model.vision_model.encoder.layers.9.mlp.fc2.weight": "model-00001-of-00002.safetensors",
694
+ "vision_model.vision_model.encoder.layers.9.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
695
+ "vision_model.vision_model.encoder.layers.9.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
696
+ "vision_model.vision_model.encoder.layers.9.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
697
+ "vision_model.vision_model.encoder.layers.9.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
698
+ "vision_model.vision_model.encoder.layers.9.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
699
+ "vision_model.vision_model.encoder.layers.9.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
700
+ "vision_model.vision_model.encoder.layers.9.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
701
+ "vision_model.vision_model.encoder.layers.9.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
702
+ "vision_model.vision_model.head.attention.in_proj_bias": "model-00001-of-00002.safetensors",
703
+ "vision_model.vision_model.head.attention.in_proj_weight": "model-00001-of-00002.safetensors",
704
+ "vision_model.vision_model.head.attention.out_proj.bias": "model-00001-of-00002.safetensors",
705
+ "vision_model.vision_model.head.attention.out_proj.weight": "model-00001-of-00002.safetensors",
706
+ "vision_model.vision_model.head.layernorm.bias": "model-00001-of-00002.safetensors",
707
+ "vision_model.vision_model.head.layernorm.weight": "model-00001-of-00002.safetensors",
708
+ "vision_model.vision_model.head.mlp.fc1.bias": "model-00001-of-00002.safetensors",
709
+ "vision_model.vision_model.head.mlp.fc1.weight": "model-00001-of-00002.safetensors",
710
+ "vision_model.vision_model.head.mlp.fc2.bias": "model-00001-of-00002.safetensors",
711
+ "vision_model.vision_model.head.mlp.fc2.weight": "model-00001-of-00002.safetensors",
712
+ "vision_model.vision_model.head.probe": "model-00001-of-00002.safetensors",
713
+ "vision_model.vision_model.post_layernorm.bias": "model-00001-of-00002.safetensors",
714
+ "vision_model.vision_model.post_layernorm.weight": "model-00001-of-00002.safetensors"
715
+ }
716
+ }
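
The entries above close out the "weight_map" of model.safetensors.index.json, which maps every parameter name in the checkpoint to the shard file that stores it (here the vision-tower weights all live in model-00001-of-00002.safetensors). As a minimal sketch of how such an index is typically consumed, assuming the index and its shards sit in the working directory (the tensor name below is taken from the map above; everything else is illustrative, not part of this commit):

import json
from safetensors import safe_open

# Load the index and look up which shard holds a given parameter.
with open("model.safetensors.index.json") as f:
    index = json.load(f)

name = "vision_model.vision_model.post_layernorm.weight"
shard_file = index["weight_map"][name]  # e.g. "model-00001-of-00002.safetensors"

# Open only that shard and read only that tensor.
with safe_open(shard_file, framework="pt") as shard:
    tensor = shard.get_tensor(name)
print(name, tuple(tensor.shape))

In practice transformers resolves the whole index automatically during from_pretrained; the lookup above only shows what the mapping encodes.
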
modeling_eagle_chat.py ADDED
@@ -0,0 +1,458 @@
1
+ # --------------------------------------------------------
2
+ # Eagle2
3
+ # Copyright (c) 2025 NVIDIA
4
+ # Licensed under The Apache License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+
7
+ import warnings
8
+ from typing import Any, List, Optional, Tuple, Union
9
+
10
+ import torch.utils.checkpoint
11
+ import transformers
12
+ from torch import nn
13
+ from torch.nn import CrossEntropyLoss
14
+ from transformers import (AutoModel, GenerationConfig,
15
+ LlamaTokenizer)
16
+ from transformers.modeling_outputs import CausalLMOutputWithPast
17
+ from transformers.modeling_utils import PreTrainedModel
18
+ from transformers.utils import ModelOutput, logging
19
+ from peft import LoraConfig, get_peft_model
20
+ from .configuration_eagle_chat import Eagle2ChatConfig
21
+ from .conversation import get_conv_template
22
+ from .modeling_siglip import SiglipVisionModel
23
+ from .modeling_qwen2 import Qwen2ForCausalLM
24
+ from .flash_attention import *
25
+ from .multi_backbone_channel_concatentation_model import MultiBackboneChannelConcatenationVisionModel
26
+ from .multi_backbone_channel_concatenation_encoder import MultiBackboneChannelConcatenationVisionTower
27
+ from .configuration_multi_backbone_channel_concatentation_model import MultiBackboneChannelConcatenationVisionModelConfig
28
+ from .siglip_vision_tower import SiglipVisionTower
29
+ from .convnext_encoder import ConvNextVisionTower
30
+ from .convnext import ConvNeXt
31
+ from .modeling_llama import LlamaForCausalLM
32
+
33
+ logger = logging.get_logger(__name__)
34
+
35
+
36
+ def version_cmp(v1, v2, op='eq'):
37
+ import operator
38
+
39
+ from packaging import version
40
+ op_func = getattr(operator, op)
41
+ return op_func(version.parse(v1), version.parse(v2))
42
+
43
+
44
+ class Eagle2ChatModel(PreTrainedModel):
45
+ config_class = Eagle2ChatConfig
46
+ main_input_name = 'pixel_values'
47
+ _no_split_modules = ['LlamaDecoderLayer']
48
+
49
+ def __init__(self, config: Eagle2ChatConfig, vision_model=None, language_model=None):
50
+ super().__init__(config)
51
+
52
+ assert version_cmp(transformers.__version__, '4.37.2', 'ge')
53
+ assert version_cmp(transformers.__version__, '4.39.2', 'le')
54
+ image_size = config.force_image_size or config.vision_config.image_size
55
+ if hasattr(config.vision_config, 'grid_size'):
56
+ grid_size = config.vision_config.grid_size
57
+ self.patch_size = 14
58
+ self.num_image_token = int((grid_size * config.downsample_ratio) ** 2)
59
+ else:
60
+ patch_size = config.vision_config.patch_size
61
+ self.patch_size = patch_size
62
+ self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio ** 2))
63
+
64
+ self.select_layer = config.select_layer
65
+ self.template = config.template
66
+
67
+ self.downsample_ratio = config.downsample_ratio
68
+
69
+ logger.info(f'num_image_token: {self.num_image_token}')
70
+ if vision_model is not None:
71
+ self.vision_model = vision_model
72
+ else:
73
+ if config.vision_config.model_type == 'siglip_vision_model':
74
+ self.vision_model = SiglipVisionModel(config.vision_config)
75
+ elif config.vision_config.model_type.startswith("MOB"):
76
+ self.vision_model = MultiBackboneChannelConcatenationVisionModel(config.vision_config, config)
+ else:
+ raise NotImplementedError(f'{config.vision_config.model_type} is not implemented.')
77
+
78
+ if language_model is not None:
79
+ self.language_model = language_model
80
+ else:
81
+ if config.llm_config.architectures[0] == 'LlamaForCausalLM':
82
+ self.language_model = LlamaForCausalLM(config.llm_config)
83
+ elif config.llm_config.architectures[0] == 'Qwen2ForCausalLM':
84
+ self.language_model = Qwen2ForCausalLM(config.llm_config)
85
+ else:
86
+ raise NotImplementedError(f'{config.llm_config.architectures[0]} is not implemented.')
87
+
88
+ vit_hidden_size = config.vision_config.hidden_size
89
+ if vit_hidden_size == 'lazy_calculation':
90
+ # a hack for Mixture of Backbones
91
+ vit_hidden_size = self.vision_model.hidden_size
92
+ print("The lazy calculated hidden_size: {} .. ".format(vit_hidden_size))
93
+ llm_hidden_size = config.llm_config.hidden_size
94
+ self.moe_version_type = getattr(config.vision_config, 'moe_version_type', None)
95
+
96
+ if self.moe_version_type in ['seq_concat', 'feat_concat']:
97
+ raise NotImplementedError
98
+ elif self.moe_version_type == 'convnext_512_siglip_448':
99
+ convnext_hidden_size = vit_hidden_size['convnext']
100
+ siglip_hidden_size = vit_hidden_size['siglip']
101
+ feature_concat_hidden_size = convnext_hidden_size + siglip_hidden_size * int(1 / self.downsample_ratio) ** 2
102
+ self.mlp1 = nn.Sequential(
103
+ nn.LayerNorm(feature_concat_hidden_size),
104
+ nn.Linear(feature_concat_hidden_size, llm_hidden_size),
105
+ nn.GELU(),
106
+ nn.Linear(llm_hidden_size, llm_hidden_size)
107
+ )
108
+ else:
109
+ self.mlp1 = nn.Sequential(
110
+ nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
111
+ nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size),
112
+ nn.GELU(),
113
+ nn.Linear(llm_hidden_size, llm_hidden_size)
114
+ )
115
+ self.img_context_token_id = None
116
+ self.conv_template = get_conv_template(self.template)
117
+ self.system_message = self.conv_template.system_message
118
+
119
+ if config.use_backbone_lora:
120
+ self.wrap_backbone_lora(r=config.use_backbone_lora, lora_alpha=2 * config.use_backbone_lora)
121
+
122
+ if config.use_llm_lora:
123
+ self.wrap_llm_lora(r=config.use_llm_lora, lora_alpha=2 * config.use_llm_lora)
124
+
125
+ def wrap_backbone_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
126
+ lora_config = LoraConfig(
127
+ r=r,
128
+ target_modules=['attn.qkv', 'attn.proj', 'mlp.fc1', 'mlp.fc2'],
129
+ lora_alpha=lora_alpha,
130
+ lora_dropout=lora_dropout,
131
+ )
132
+ self.vision_model = get_peft_model(self.vision_model, lora_config)
133
+ self.vision_model.print_trainable_parameters()
134
+
135
+ def wrap_llm_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
136
+ lora_config = LoraConfig(
137
+ r=r,
138
+ target_modules=['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.o_proj',
139
+ 'mlp.gate_proj', 'mlp.down_proj', 'mlp.up_proj'],
140
+ lora_alpha=lora_alpha,
141
+ lora_dropout=lora_dropout,
142
+ task_type='CAUSAL_LM'
143
+ )
144
+ self.language_model = get_peft_model(self.language_model, lora_config)
145
+ self.language_model.enable_input_require_grads()
146
+ self.language_model.print_trainable_parameters()
147
+
148
+
149
+ def forward(
150
+ self,
151
+ pixel_values: torch.FloatTensor,
152
+ input_ids: torch.LongTensor = None,
153
+ attention_mask: Optional[torch.Tensor] = None,
154
+ position_ids: Optional[torch.LongTensor] = None,
155
+ image_flags: Optional[torch.LongTensor] = None,
156
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
157
+ labels: Optional[torch.LongTensor] = None,
158
+ use_cache: Optional[bool] = None,
159
+ output_attentions: Optional[bool] = None,
160
+ output_hidden_states: Optional[bool] = None,
161
+ return_dict: Optional[bool] = None,
162
+ num_patches_list: Optional[List[torch.Tensor]] = None,
163
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
164
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
165
+
166
167
+ input_embeds = self.language_model.get_input_embeddings()(input_ids)
168
+
169
+
170
+ if self.moe_version_type in ['seq_concat', 'feat_concat'] and not isinstance(pixel_values, dict):
171
+ raise NotImplementedError
172
+ vit_embeds = self.extract_feature(pixel_values)
173
+
174
+ if not isinstance(image_flags, list):
175
+ image_flags = image_flags.squeeze(-1)
176
+ vit_embeds = vit_embeds[image_flags == 1]
177
+ if isinstance(pixel_values, dict):
178
+ # for MOE
179
+ vit_batch_size = sum(pixel_values['num_patches'])
180
+ else:
181
+ vit_batch_size = pixel_values.shape[0]
182
+
183
+ B, N, C = input_embeds.shape
184
+ input_embeds = input_embeds.reshape(B * N, C)
185
+
186
+ if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
187
+ print(f'dynamic ViT batch size: {vit_batch_size}, images per sample: {vit_batch_size / B}, dynamic token length: {N}')
188
+
189
+ input_ids = input_ids.reshape(B * N)
190
+ selected = (input_ids == self.img_context_token_id)
191
+ try:
192
+ input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, C)
193
+ except Exception as e:
194
+ vit_embeds = vit_embeds.reshape(-1, C)
195
+ print(f'warning: {e}, input_embeds[selected].shape={input_embeds[selected].shape}, '
196
+ f'vit_embeds.shape={vit_embeds.shape}')
197
+ n_token = selected.sum()
198
+ input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:n_token]
199
+
200
+ input_embeds = input_embeds.reshape(B, N, C)
201
+
202
+ outputs = self.language_model(
203
+ inputs_embeds=input_embeds,
204
+ attention_mask=attention_mask,
205
+ position_ids=position_ids,
206
+ past_key_values=past_key_values,
207
+ use_cache=use_cache,
208
+ output_attentions=output_attentions,
209
+ output_hidden_states=output_hidden_states,
210
+ return_dict=return_dict,
211
+ )
212
+ logits = outputs.logits
213
+
214
+ loss = None
215
+ if labels is not None:
216
+ # Shift so that tokens < n predict n
217
+ shift_logits = logits[..., :-1, :].contiguous()
218
+ shift_labels = labels[..., 1:].contiguous()
219
+ # Flatten the tokens
220
+ loss_fct = CrossEntropyLoss()
221
+ shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
222
+ shift_labels = shift_labels.view(-1)
223
+ # Enable model parallelism
224
+ shift_labels = shift_labels.to(shift_logits.device)
225
+ loss = loss_fct(shift_logits, shift_labels)
226
+
227
+ if not return_dict:
228
+ output = (logits,) + outputs[1:]
229
+ return (loss,) + output if loss is not None else output
230
+
231
+ return CausalLMOutputWithPast(
232
+ loss=loss,
233
+ logits=logits,
234
+ past_key_values=outputs.past_key_values,
235
+ hidden_states=outputs.hidden_states,
236
+ attentions=outputs.attentions,
237
+ )
238
+
239
+ def pixel_shuffle(self, x, scale_factor=0.5):
240
+ n, w, h, c = x.size()
241
+ # N, W, H, C --> N, W, H * scale, C // scale
242
+ x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
243
+ # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
244
+ x = x.permute(0, 2, 1, 3).contiguous()
245
+ # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
246
+ x = x.view(n, int(h * scale_factor), int(w * scale_factor),
247
+ int(c / (scale_factor * scale_factor)))
248
+ x = x.permute(0, 2, 1, 3).contiguous()
249
+ return x
250
+
251
+ def extract_feature(self, pixel_values):
252
+
253
+ """
+ Extract vision features from `pixel_values` with the vision tower, optionally add
+ NEFTune noise during training, pixel-shuffle them by `downsample_ratio`, and project
+ them to the LLM hidden size with `mlp1`.
+ """
255
+
256
+ if self.select_layer == -1:
257
+ vit_embeds = self.vision_model(
258
+ pixel_values=pixel_values,
259
+ output_hidden_states=False,
260
+ return_dict=True).last_hidden_state # torch.Size([B, 1025, 1024])
261
+
262
+ else:
263
+ vit_embeds = self.vision_model(
264
+ pixel_values=pixel_values,
265
+ output_hidden_states=True,
266
+ return_dict=True).hidden_states[self.select_layer]
267
+ if type(self.vision_model) == SiglipVisionModel:
268
+ pass
269
+ elif type(self.vision_model) == MultiBackboneChannelConcatenationVisionModel:
270
+ pass
271
+ else:
272
+ vit_embeds = vit_embeds[:, 1:, :] # torch.Size([B, 1024, 1024])
273
+
274
+ if self.training and getattr(self, 'neftune_alpha', None) is not None:
275
+ vit_embeds = self.noised_embed(vit_embeds, self.neftune_alpha)
276
+
277
+ if self.moe_version_type in ['feat_concat', 'seq_concat']:
278
+ raise NotImplementedError
279
+ elif self.moe_version_type == 'convnext_512_siglip_448':
280
+ siglip_embeds = vit_embeds['siglip']
281
+ convnext_embeds = vit_embeds['convnext']
282
+ h = w = int(siglip_embeds.shape[1] ** 0.5)
283
+ siglip_embeds = siglip_embeds.reshape(siglip_embeds.shape[0], h, w, -1)
284
+ siglip_embeds = self.pixel_shuffle(siglip_embeds, scale_factor=self.downsample_ratio)
285
+ siglip_embeds = siglip_embeds.reshape(siglip_embeds.shape[0], -1, siglip_embeds.shape[-1])
286
+ vit_embeds = self.mlp1(torch.cat([siglip_embeds, convnext_embeds], dim=-1))
287
+ else:
288
+ h = w = int(vit_embeds.shape[1] ** 0.5)
289
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
290
+
291
+ vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio) # torch.Size([B, 1024, 1024]) -> torch.Size([B, 16, 16, 4096])
292
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1]) # torch.Size([B, 16, 16, 4096]) -> torch.Size([B, 256, 4096])
293
+ vit_embeds = self.mlp1(vit_embeds)#.to(pixel_values.device)
294
+
295
+ return vit_embeds
296
+
297
+ def batch_chat(self, tokenizer, pixel_values, questions, generation_config, num_patches_list=None,
298
+ history=None, return_history=False, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>',
299
+ IMG_CONTEXT_TOKEN='<IMG_CONTEXT>', verbose=False, image_counts=None):
300
+ if history is not None or return_history:
301
+ print('Multi-turn chat is not yet supported in batch_chat.')
302
+ raise NotImplementedError
303
+
304
+ if image_counts is not None:
305
+ num_patches_list = image_counts
306
+ print('Warning: `image_counts` is deprecated. Please use `num_patches_list` instead.')
307
+
308
+ img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
309
+ self.img_context_token_id = img_context_token_id
310
+
311
+ if verbose and pixel_values is not None:
312
+ image_bs = pixel_values.shape[0]
313
+ print(f'dynamic ViT batch size: {image_bs}')
314
+
315
+ queries = []
316
+ for idx, num_patches in enumerate(num_patches_list):
317
+ question = questions[idx]
318
+ if pixel_values is not None and '<image>' not in question:
319
+ question = '<image>\n' + question
320
+ template = get_conv_template(self.template)
321
+ template.append_message(template.roles[0], question)
322
+ template.append_message(template.roles[1], None)
323
+ query = template.get_prompt()
324
+
325
+ image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
326
+ query = query.replace('<image>', image_tokens, 1)
327
+ queries.append(query)
328
+
329
+ tokenizer.padding_side = 'left'
330
+ model_inputs = tokenizer(queries, return_tensors='pt', padding=True)
331
+ input_ids = model_inputs['input_ids'].cuda()
332
+ attention_mask = model_inputs['attention_mask'].cuda()
333
+ eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)
334
+ generation_config['eos_token_id'] = eos_token_id
335
+ generation_output = self.generate(
336
+ pixel_values=pixel_values,
337
+ input_ids=input_ids,
338
+ attention_mask=attention_mask,
339
+ **generation_config
340
+ )
341
+ responses = tokenizer.batch_decode(generation_output, skip_special_tokens=True)
342
+ responses = [response.split(template.sep)[0].strip() for response in responses]
343
+ return responses
344
+
345
+ def chat(self, tokenizer, pixel_values, question, generation_config, history=None, return_history=False,
346
+ num_patches_list=None, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>', IMG_CONTEXT_TOKEN='<IMG_CONTEXT>',
347
+ verbose=False, llm_only=False):
348
+
349
+ if history is None and pixel_values is not None and '<image>' not in question:
350
+ question = '<image>\n' + question
351
+
352
+ if num_patches_list is None:
353
+ num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []
354
+ assert pixel_values is None or len(pixel_values) == sum(num_patches_list)
355
+
356
+ img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
357
+ self.img_context_token_id = img_context_token_id
358
+
359
+ template = get_conv_template(self.template)
360
+ template.system_message = self.system_message
361
+ eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)
362
+
363
+ history = [] if history is None else history
364
+ for (old_question, old_answer) in history:
365
+ template.append_message(template.roles[0], old_question)
366
+ template.append_message(template.roles[1], old_answer)
367
+ template.append_message(template.roles[0], question)
368
+ template.append_message(template.roles[1], None)
369
+ query = template.get_prompt()
370
+
371
+ if verbose and pixel_values is not None:
372
+ image_bs = pixel_values.shape[0]
373
+ print(f'dynamic ViT batch size: {image_bs}')
374
+
375
+ for num_patches in num_patches_list:
376
+ image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
377
+ if llm_only:
378
+ query = query.replace('<image>', '', 1)
379
+ else:
380
+ query = query.replace('<image>', image_tokens, 1)
381
+
382
+ model_inputs = tokenizer(query, return_tensors='pt')
383
+ input_ids = model_inputs['input_ids'].cuda()
384
+ attention_mask = model_inputs['attention_mask'].cuda()
385
+ generation_config['eos_token_id'] = eos_token_id
386
+ if self.moe_version_type is not None and self.moe_version_type != 'all_tiling' and self.moe_version_type != 'convnext_512_siglip_448':
387
+ pixel_values = {
388
+ 'pixel_values': pixel_values,
389
+ 'num_patches': num_patches_list # num patch of each image.
390
+ }
391
+ generation_output = self.generate(
392
+ pixel_values=pixel_values,
393
+ input_ids=input_ids,
394
+ attention_mask=attention_mask,
395
+ **generation_config
396
+ )
397
+ response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
398
+ response = response.split(template.sep)[0].strip()
399
+ history.append((question, response))
400
+ if return_history:
401
+ return response, history
402
+ else:
403
+ query_to_print = query.replace(IMG_CONTEXT_TOKEN, '')
404
+ query_to_print = query_to_print.replace(f'{IMG_START_TOKEN}{IMG_END_TOKEN}', '<image>')
405
+ if verbose:
406
+ print(query_to_print, response)
407
+ return response
408
+
409
+ @torch.no_grad()
410
+ def generate(
411
+ self,
412
+ pixel_values: Optional[torch.FloatTensor] = None,
413
+ input_ids: Optional[torch.FloatTensor] = None,
414
+ attention_mask: Optional[torch.LongTensor] = None,
415
+ visual_features: Optional[torch.FloatTensor] = None,
416
+ generation_config: Optional[GenerationConfig] = None,
417
+ output_hidden_states: Optional[bool] = None,
418
+ return_dict: Optional[bool] = None,
419
+ **generate_kwargs,
420
+ ) -> torch.LongTensor:
421
+
422
+ assert self.img_context_token_id is not None
423
+ if pixel_values is not None:
424
+ if visual_features is not None:
425
+ vit_embeds = visual_features
426
+ else:
427
+ vit_embeds = self.extract_feature(pixel_values)
428
+
429
+ input_embeds = self.language_model.get_input_embeddings()(input_ids)
430
+ B, N, C = input_embeds.shape
431
+ input_embeds = input_embeds.reshape(B * N, C)
432
+
433
+ input_ids = input_ids.reshape(B * N)
434
+ selected = (input_ids == self.img_context_token_id)
435
+ assert selected.sum() != 0
436
+ input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)
437
+
438
+ input_embeds = input_embeds.reshape(B, N, C)
439
+ else:
440
+ input_embeds = self.language_model.get_input_embeddings()(input_ids)
441
+
442
+ outputs = self.language_model.generate(
443
+ inputs_embeds=input_embeds,
444
+ attention_mask=attention_mask,
445
+ generation_config=generation_config,
446
+ output_hidden_states=output_hidden_states,
447
+ return_dict=return_dict,
448
+ use_cache=True,
449
+ **generate_kwargs,
450
+ )
451
+
452
+ return outputs
453
+
454
+ def get_input_embeddings(self):
455
+ return self.language_model.get_input_embeddings()
456
+
457
+ def get_output_embeddings(self):
458
+ return self.language_model.get_output_embeddings()
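
The class above is the remote-code entry point added by this commit, so the checkpoint can be driven through its chat() helper. A hedged usage sketch follows; the repository id, dtype, and the dummy pixel_values tensor are placeholders for illustration, since this commit does not include the image preprocessing pipeline:

import torch
from transformers import AutoModel, AutoTokenizer

path = "nvidia/Eagle2-placeholder"  # placeholder repo id, not part of this commit
model = AutoModel.from_pretrained(path, torch_dtype=torch.bfloat16,
                                  trust_remote_code=True).cuda().eval()
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)

# Stand-in for a preprocessed image: [num_patches, 3, force_image_size, force_image_size].
pixel_values = torch.zeros(1, 3, 448, 448, dtype=torch.bfloat16).cuda()
generation_config = dict(max_new_tokens=128, do_sample=False)

# chat() prepends '<image>\n' to the question and expands it into IMG_CONTEXT tokens.
response = model.chat(tokenizer, pixel_values, 'Describe this image.', generation_config)
print(response)

The call mirrors the chat() signature defined above (tokenizer, pixel_values, question, generation_config); multi-image prompts would instead go through num_patches_list or batch_chat().
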
modeling_llama.py ADDED
@@ -0,0 +1,1774 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """ PyTorch LLaMA model."""
21
+ import math
22
+ import warnings
23
+ from typing import List, Optional, Tuple, Union
24
+
25
+ import torch
26
+ import torch.nn.functional as F
27
+ import torch.utils.checkpoint
28
+ from torch import nn
29
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
30
+
31
+ from transformers.activations import ACT2FN
32
+ from transformers.cache_utils import Cache, DynamicCache
33
+ from transformers.modeling_attn_mask_utils import (
34
+ AttentionMaskConverter,
35
+ _prepare_4d_attention_mask,
36
+ _prepare_4d_causal_attention_mask,
37
+ _prepare_4d_causal_attention_mask_for_sdpa,
38
+ )
39
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
40
+ from transformers.modeling_utils import PreTrainedModel
41
+ from transformers.configuration_utils import PretrainedConfig
42
+ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
43
+ from transformers.utils import (
44
+ add_start_docstrings,
45
+ add_start_docstrings_to_model_forward,
46
+ is_flash_attn_2_available,
47
+ is_flash_attn_greater_or_equal_2_10,
48
+ logging,
49
+ replace_return_docstrings,
50
+ )
51
+ from transformers.utils.import_utils import is_torch_fx_available
52
+ from .configuration_llama import LlamaConfig
53
+
54
+
55
+ if is_flash_attn_2_available():
56
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
57
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
58
+
59
+
60
+ # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
61
+ # It means that the function will not be traced through and simply appear as a node in the graph.
62
+ if is_torch_fx_available():
63
+ if not is_torch_greater_or_equal_than_1_13:
64
+ import torch.fx
65
+
66
+ _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
67
+
68
+
69
+ logger = logging.get_logger(__name__)
70
+
71
+ _CONFIG_FOR_DOC = "LlamaConfig"
72
+
73
+
74
+ def _get_unpad_data(attention_mask):
75
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
76
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
77
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
78
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
79
+ return (
80
+ indices,
81
+ cu_seqlens,
82
+ max_seqlen_in_batch,
83
+ )
84
+
85
+ def _get_unpad_data_packing(attention_mask, sub_sample_lengths):
86
+ seqlens_in_batch = []
87
+ for i, per_sub_sample_lengths in enumerate(sub_sample_lengths):
88
+ if (attention_mask[i]==0).sum() == per_sub_sample_lengths[-1]:
89
+ per_sub_sample_lengths = per_sub_sample_lengths[:-1]
90
+ seqlens_in_batch.extend(per_sub_sample_lengths)
91
+ seqlens_in_batch = torch.tensor(seqlens_in_batch, device=attention_mask.device, dtype=torch.int32)
92
+
93
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
94
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
95
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
96
+ return (
97
+ indices,
98
+ cu_seqlens,
99
+ max_seqlen_in_batch,
100
+ )
101
+
102
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
103
+ warnings.warn(
104
+ "Calling `transformers.models.llama.modeling_llama._prepare_4d_attention_mask` is deprecated and will be removed in v4.37. Use `transformers.modeling_attn_mask_utils._prepare_4d_attention_mask"
105
+ )
106
+ return _prepare_4d_attention_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
107
+
108
+
109
+ def _make_causal_mask(
110
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
111
+ ):
112
+ warnings.warn(
113
+ "Calling `transformers.models.llama.modeling_llama._make_causal_mask` is deprecated and will be removed in v4.37. Use `transformers.models.llama.modeling_llama.AttentionMaskConverter._make_causal_mask"
114
+ )
115
+ return AttentionMaskConverter._make_causal_mask(
116
+ input_ids_shape=input_ids_shape, dtype=dtype, device=device, past_key_values_length=past_key_values_length
117
+ )
118
+
119
+
120
+ class LlamaRMSNorm(nn.Module):
121
+ def __init__(self, hidden_size, eps=1e-6):
122
+ """
123
+ LlamaRMSNorm is equivalent to T5LayerNorm
124
+ """
125
+ super().__init__()
126
+ self.weight = nn.Parameter(torch.ones(hidden_size))
127
+ self.variance_epsilon = eps
128
+
129
+ def forward(self, hidden_states):
130
+ input_dtype = hidden_states.dtype
131
+ hidden_states = hidden_states.to(torch.float32)
132
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
133
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
134
+ return self.weight * hidden_states.to(input_dtype)
135
+
136
+
137
+ ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)
138
+
139
+ def _compute_default_rope_parameters(
140
+ config: Optional[PretrainedConfig] = None,
141
+ device: Optional["torch.device"] = None,
142
+ seq_len: Optional[int] = None,
143
+ **rope_kwargs,
144
+ ) -> Tuple["torch.Tensor", float]:
145
+ """
146
+ Computes the inverse frequencies according to the original RoPE implementation
147
+ Args:
148
+ config ([`~transformers.PretrainedConfig`]):
149
+ The model configuration.
150
+ device (`torch.device`):
151
+ The device to use for initialization of the inverse frequencies.
152
+ seq_len (`int`, *optional*):
153
+ The current sequence length. Unused for this type of RoPE.
154
+ rope_kwargs (`Dict`, *optional*):
155
+ BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
156
+ Returns:
157
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
158
+ post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
159
+ """
160
+ if config is not None and len(rope_kwargs) > 0:
161
+ raise ValueError(
162
+ "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
163
+ f"`_compute_default_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
164
+ )
165
+ if len(rope_kwargs) > 0:
166
+ base = rope_kwargs["base"]
167
+ dim = rope_kwargs["dim"]
168
+ elif config is not None:
169
+ base = config.rope_theta
170
+ partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
171
+ head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
172
+ dim = int(head_dim * partial_rotary_factor)
173
+
174
+ attention_factor = 1.0 # Unused in this type of RoPE
175
+
176
+ # Compute the inverse frequencies
177
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
178
+ return inv_freq, attention_factor
179
+
180
+
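+ # Illustrative sketch (not part of the upstream modeling code): via the keyword path, the
+ # default inverse frequencies are 1 / base**(i / dim) for the even channel indices
+ # i = 0, 2, ..., dim - 2. The helper name below is hypothetical.
+ def _sketch_default_rope_inv_freq():
+     import torch
+     inv_freq, attention_factor = _compute_default_rope_parameters(base=10000.0, dim=8)
+     expected = 1.0 / (10000.0 ** (torch.arange(0, 8, 2).float() / 8))
+     assert torch.allclose(inv_freq, expected)
+     assert attention_factor == 1.0
+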
181
+ def _compute_llama3_parameters(
182
+ config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs
183
+ ) -> Tuple["torch.Tensor", float]:
184
+ """
185
+ Computes the inverse frequencies for llama 3.1.
186
+
187
+ Args:
188
+ config ([`~transformers.PretrainedConfig`]):
189
+ The model configuration.
190
+ device (`torch.device`):
191
+ The device to use for initialization of the inverse frequencies.
192
+ seq_len (`int`, *optional*):
193
+ The current sequence length. Unused for this type of RoPE.
194
+ rope_kwargs (`Dict`, *optional*):
195
+ BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
196
+ Returns:
197
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
198
+ post-processing scaling factor applied to the computed cos/sin.
199
+ """
200
+ # Gets the default RoPE parameters
201
+ inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs)
202
+
203
+ factor = config.rope_scaling["factor"] # `8` in the original implementation
204
+ low_freq_factor = config.rope_scaling["low_freq_factor"] # `1` in the original implementation
205
+ high_freq_factor = config.rope_scaling["high_freq_factor"] # `4` in the original implementation
206
+ old_context_len = config.rope_scaling["original_max_position_embeddings"] # `8192` in the original implementation
207
+
208
+ low_freq_wavelen = old_context_len / low_freq_factor
209
+ high_freq_wavelen = old_context_len / high_freq_factor
210
+
211
+ wavelen = 2 * math.pi / inv_freq
212
+ # wavelen < high_freq_wavelen: do nothing
213
+ # wavelen > low_freq_wavelen: divide by factor
214
+ inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
215
+ # otherwise: interpolate between the two, using a smooth factor
216
+ smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
217
+ smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
218
+ is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
219
+ inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
220
+
221
+ return inv_freq_llama, attention_factor
222
+
223
+
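+ # Illustrative sketch (not part of the upstream modeling code): the llama3 rule leaves
+ # high-frequency (short-wavelength) channels untouched, divides low-frequency channels by
+ # `factor`, and smoothly interpolates in between. The stand-in config below is a
+ # hypothetical minimal object, not an actual transformers config.
+ def _sketch_llama3_rope_scaling():
+     import math
+     import torch
+     from types import SimpleNamespace
+     config = SimpleNamespace(
+         rope_theta=500000.0,
+         head_dim=64,
+         hidden_size=512,
+         num_attention_heads=8,
+         rope_scaling={
+             "factor": 8.0,
+             "low_freq_factor": 1.0,
+             "high_freq_factor": 4.0,
+             "original_max_position_embeddings": 8192,
+         },
+     )
+     inv_freq_default, _ = _compute_default_rope_parameters(config, device=None)
+     inv_freq_llama3, _ = _compute_llama3_parameters(config, device=None)
+     wavelen = 2 * math.pi / inv_freq_default
+     high = wavelen < 8192 / 4.0   # short wavelengths: kept as-is
+     low = wavelen > 8192 / 1.0    # long wavelengths: divided by `factor`
+     assert torch.allclose(inv_freq_llama3[high], inv_freq_default[high])
+     assert torch.allclose(inv_freq_llama3[low], inv_freq_default[low] / 8.0)
+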
224
+ class LlamaRotaryEmbedding(nn.Module):
225
+ def __init__(
226
+ self,
227
+ config: LlamaConfig,
228
+ device=None,
229
+ ):
230
+ super().__init__()
231
+ self.rope_kwargs = {}
232
+ # BC: "rope_type" was originally "type"
233
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
234
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
235
+ else:
236
+ self.rope_type = "default"
237
+ self.max_seq_len_cached = config.max_position_embeddings
238
+ self.original_max_seq_len = config.max_position_embeddings
239
+
240
+ self.config = config
241
+ self.rope_init_fn = _compute_llama3_parameters
242
+
243
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
244
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
245
+ self.original_inv_freq = self.inv_freq
246
+
247
+ def _dynamic_frequency_update(self, position_ids, device):
248
+ """
249
+ dynamic RoPE layers should recompute `inv_freq` in the following situations:
250
+ 1 - growing beyond the cached sequence length (allow scaling)
251
+ 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
252
+ """
253
+ seq_len = torch.max(position_ids) + 1
254
+ if seq_len > self.max_seq_len_cached: # growth
255
+ inv_freq, self.attention_scaling = self.rope_init_fn(
256
+ self.config, device, seq_len=seq_len, **self.rope_kwargs
257
+ )
258
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
259
+ self.max_seq_len_cached = seq_len
260
+
261
+ if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
262
+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
263
+ self.max_seq_len_cached = self.original_max_seq_len
264
+
265
+ @torch.no_grad()
266
+ def forward(self, x, position_ids):
267
+ if "dynamic" in self.rope_type:
268
+ self._dynamic_frequency_update(position_ids, device=x.device)
269
+
270
+ # Core RoPE block
271
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
272
+ position_ids_expanded = position_ids[:, None, :].float()
273
+ # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
274
+ device_type = x.device.type
275
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
276
+ with torch.autocast(device_type=device_type, enabled=False):
277
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
278
+ emb = torch.cat((freqs, freqs), dim=-1)
279
+ cos = emb.cos()
280
+ sin = emb.sin()
281
+
282
+ # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
283
+ cos = cos * self.attention_scaling
284
+ sin = sin * self.attention_scaling
285
+
286
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
287
+
288
+
289
+ class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
290
+ """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
291
+
292
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
293
+ self.scaling_factor = scaling_factor
294
+ # NOTE: `LlamaRotaryEmbedding` above now takes a `config` object, so this legacy
+ # positional call (here and in the NTK subclass below) is only reachable through the
+ # "linear"/"dynamic" branches of `_init_rope` and would need updating before use.
+ super().__init__(dim, max_position_embeddings, base, device)
295
+
296
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
297
+ self.max_seq_len_cached = seq_len
298
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
299
+ t = t / self.scaling_factor
300
+
301
+ freqs = torch.outer(t, self.inv_freq)
302
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
303
+ emb = torch.cat((freqs, freqs), dim=-1)
304
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
305
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
306
+
307
+
308
+ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
309
+ """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
310
+
311
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
312
+ self.scaling_factor = scaling_factor
313
+ super().__init__(dim, max_position_embeddings, base, device)
314
+
315
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
316
+ self.max_seq_len_cached = seq_len
317
+
318
+ if seq_len > self.max_position_embeddings:
319
+ base = self.base * (
320
+ (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
321
+ ) ** (self.dim / (self.dim - 2))
322
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
323
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
324
+
325
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
326
+
327
+ freqs = torch.outer(t, self.inv_freq)
328
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
329
+ emb = torch.cat((freqs, freqs), dim=-1)
330
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
331
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
332
+
333
+
334
+ def rotate_half(x):
335
+ """Rotates half the hidden dims of the input."""
336
+ x1 = x[..., : x.shape[-1] // 2]
337
+ x2 = x[..., x.shape[-1] // 2 :]
338
+ return torch.cat((-x2, x1), dim=-1)
339
+
340
+
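+ # Illustrative sketch (not part of the upstream modeling code): rotate_half swaps the two
+ # halves of the last dimension and negates the new first half, i.e. [x1, x2] -> [-x2, x1].
+ def _sketch_rotate_half():
+     import torch
+     x = torch.tensor([1.0, 2.0, 3.0, 4.0])
+     assert torch.equal(rotate_half(x), torch.tensor([-3.0, -4.0, 1.0, 2.0]))
+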
341
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
342
+ """Applies Rotary Position Embedding to the query and key tensors.
343
+
344
+ Args:
345
+ q (`torch.Tensor`): The query tensor.
346
+ k (`torch.Tensor`): The key tensor.
347
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
348
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
349
+ position_ids (`torch.Tensor`, *optional*):
350
+ Deprecated and unused.
351
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
352
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
353
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
354
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
355
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
356
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
357
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
358
+ Returns:
359
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
360
+ """
361
+ cos = cos.unsqueeze(unsqueeze_dim)
362
+ sin = sin.unsqueeze(unsqueeze_dim)
363
+ q_embed = (q * cos) + (rotate_half(q) * sin)
364
+ k_embed = (k * cos) + (rotate_half(k) * sin)
365
+ return q_embed, k_embed
366
+
367
+
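+ # Illustrative sketch (not part of the upstream modeling code): with cos/sin built from a
+ # set of inverse frequencies, apply_rotary_pos_emb rotates each (i, i + head_dim/2) channel
+ # pair of q and k, so per-token norms are preserved. The helper name below is hypothetical.
+ def _sketch_apply_rotary_pos_emb():
+     import torch
+     bsz, n_heads, seq_len, head_dim = 1, 2, 5, 8
+     q = torch.randn(bsz, n_heads, seq_len, head_dim)
+     k = torch.randn(bsz, n_heads, seq_len, head_dim)
+     inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
+     freqs = torch.outer(torch.arange(seq_len).float(), inv_freq)   # (seq_len, head_dim / 2)
+     emb = torch.cat((freqs, freqs), dim=-1)                        # (seq_len, head_dim)
+     cos, sin = emb.cos()[None, :, :], emb.sin()[None, :, :]        # (1, seq_len, head_dim)
+     q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)            # unsqueeze_dim=1 broadcasts over heads
+     assert torch.allclose(q_rot.norm(dim=-1), q.norm(dim=-1), atol=1e-5)
+     assert torch.allclose(k_rot.norm(dim=-1), k.norm(dim=-1), atol=1e-5)
+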
368
+ class LlamaMLP(nn.Module):
369
+ def __init__(self, config):
370
+ super().__init__()
371
+ self.config = config
372
+ self.hidden_size = config.hidden_size
373
+ self.intermediate_size = config.intermediate_size
374
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
375
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
376
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
377
+ self.act_fn = ACT2FN[config.hidden_act]
378
+
379
+ def forward(self, x):
380
+ if self.config.pretraining_tp > 1:
381
+ slice = self.intermediate_size // self.config.pretraining_tp
382
+ gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
383
+ up_proj_slices = self.up_proj.weight.split(slice, dim=0)
384
+ down_proj_slices = self.down_proj.weight.split(slice, dim=1)
385
+
386
+ gate_proj = torch.cat(
387
+ [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
388
+ )
389
+ up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
390
+
391
+ intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
392
+ down_proj = [
393
+ F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
394
+ ]
395
+ down_proj = sum(down_proj)
396
+ else:
397
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
398
+
399
+ return down_proj
400
+
401
+
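+ # Illustrative sketch (not part of the upstream modeling code): with pretraining_tp == 1 the
+ # MLP is a SwiGLU block, down_proj(silu(gate_proj(x)) * up_proj(x)). The stand-in config
+ # below is a hypothetical minimal object, not an actual transformers config.
+ def _sketch_llama_mlp():
+     import torch
+     from types import SimpleNamespace
+     config = SimpleNamespace(hidden_size=16, intermediate_size=32, hidden_act="silu", pretraining_tp=1)
+     mlp = LlamaMLP(config)
+     x = torch.randn(2, 5, 16)
+     manual = mlp.down_proj(torch.nn.functional.silu(mlp.gate_proj(x)) * mlp.up_proj(x))
+     assert torch.allclose(mlp(x), manual)
+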
402
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
403
+ """
404
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
405
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
406
+ """
407
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
408
+ if n_rep == 1:
409
+ return hidden_states
410
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
411
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
412
+
413
+
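+ # Illustrative sketch (not part of the upstream modeling code): repeat_kv expands grouped
+ # key/value heads so every query head has a matching KV head; it is equivalent to
+ # torch.repeat_interleave along the head dimension.
+ def _sketch_repeat_kv():
+     import torch
+     kv = torch.randn(2, 4, 6, 16)          # (batch, num_key_value_heads, seq_len, head_dim)
+     expanded = repeat_kv(kv, n_rep=3)      # -> (2, 12, 6, 16)
+     assert torch.equal(expanded, torch.repeat_interleave(kv, repeats=3, dim=1))
+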
414
+ class LlamaAttention(nn.Module):
415
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
416
+
417
+ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
418
+ super().__init__()
419
+ self.config = config
420
+ self.layer_idx = layer_idx
421
+ if layer_idx is None:
422
+ logger.warning_once(
423
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
424
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
425
+ "when creating this class."
426
+ )
427
+
428
+ self.attention_dropout = config.attention_dropout
429
+ self.hidden_size = config.hidden_size
430
+ self.num_heads = config.num_attention_heads
431
+ self.head_dim = self.hidden_size // self.num_heads
432
+ self.num_key_value_heads = config.num_key_value_heads
433
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
434
+ self.max_position_embeddings = config.max_position_embeddings
435
+ self.rope_theta = config.rope_theta
436
+ self.is_causal = True
437
+
438
+ if (self.head_dim * self.num_heads) != self.hidden_size:
439
+ raise ValueError(
440
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
441
+ f" and `num_heads`: {self.num_heads})."
442
+ )
443
+
444
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
445
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
446
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
447
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
448
+ self._init_rope()
449
+
450
+ def _init_rope(self):
451
+ if self.config.rope_scaling is None:
452
+ self.rotary_emb = LlamaRotaryEmbedding(
453
+ self.config
454
+ )
455
+ else:
456
+ scaling_type = self.config.rope_scaling["type"]
457
+ scaling_factor = self.config.rope_scaling["factor"]
458
+ if scaling_type == "linear":
459
+ self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
460
+ self.head_dim,
461
+ max_position_embeddings=self.max_position_embeddings,
462
+ scaling_factor=scaling_factor,
463
+ base=self.rope_theta,
464
+ )
465
+ elif scaling_type == "dynamic":
466
+ self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
467
+ self.head_dim,
468
+ max_position_embeddings=self.max_position_embeddings,
469
+ scaling_factor=scaling_factor,
470
+ base=self.rope_theta,
471
+ )
472
+ elif scaling_type == "llama3":
473
+ self.rotary_emb = LlamaRotaryEmbedding(
474
+ self.config
475
+ )
476
+ else:
477
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
478
+
479
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
480
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
481
+
482
+ def forward(
483
+ self,
484
+ hidden_states: torch.Tensor,
485
+ attention_mask: Optional[torch.Tensor] = None,
486
+ position_ids: Optional[torch.LongTensor] = None,
487
+ past_key_value: Optional[Cache] = None,
488
+ output_attentions: bool = False,
489
+ use_cache: bool = False,
490
+ **kwargs,
491
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
492
+ if "padding_mask" in kwargs:
493
+ warnings.warn(
494
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
495
+ )
496
+
497
+ bsz, q_len, _ = hidden_states.size()
498
+
499
+ if self.config.pretraining_tp > 1:
500
+ key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
501
+ query_slices = self.q_proj.weight.split(
502
+ (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
503
+ )
504
+ key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
505
+ value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
506
+
507
+ query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
508
+ query_states = torch.cat(query_states, dim=-1)
509
+
510
+ key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
511
+ key_states = torch.cat(key_states, dim=-1)
512
+
513
+ value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
514
+ value_states = torch.cat(value_states, dim=-1)
515
+
516
+ else:
517
+ query_states = self.q_proj(hidden_states)
518
+ key_states = self.k_proj(hidden_states)
519
+ value_states = self.v_proj(hidden_states)
520
+
521
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
522
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
523
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
524
+
525
+ kv_seq_len = key_states.shape[-2]
526
+ if past_key_value is not None:
527
+ if self.layer_idx is None:
528
+ raise ValueError(
529
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
530
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
531
+ "with a layer index."
532
+ )
533
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
534
+ cos, sin = self.rotary_emb(value_states, position_ids)
535
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
536
+
537
+ if past_key_value is not None:
538
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
539
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
540
+
541
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
542
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
543
+
544
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
545
+
546
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
547
+ raise ValueError(
548
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
549
+ f" {attn_weights.size()}"
550
+ )
551
+
552
+ if attention_mask is not None:
553
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
554
+ raise ValueError(
555
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
556
+ )
557
+ attn_weights = attn_weights + attention_mask
558
+
559
+ # upcast attention to fp32
560
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
561
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
562
+ attn_output = torch.matmul(attn_weights, value_states)
563
+
564
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
565
+ raise ValueError(
566
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
567
+ f" {attn_output.size()}"
568
+ )
569
+
570
+ attn_output = attn_output.transpose(1, 2).contiguous()
571
+
572
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
573
+
574
+ if self.config.pretraining_tp > 1:
575
+ attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
576
+ o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
577
+ attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
578
+ else:
579
+ attn_output = self.o_proj(attn_output)
580
+
581
+ if not output_attentions:
582
+ attn_weights = None
583
+
584
+ return attn_output, attn_weights, past_key_value
585
+
586
+
587
+ class LlamaFlashAttention2(LlamaAttention):
588
+ """
589
+ Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stay
590
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
591
+ flash attention and deal with padding tokens in case the input contains any of them.
592
+ """
593
+
594
+ def __init__(self, *args, **kwargs):
595
+ super().__init__(*args, **kwargs)
596
+
597
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
598
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, which became the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
599
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
600
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
601
+
602
+ def forward(
603
+ self,
604
+ hidden_states: torch.Tensor,
605
+ attention_mask: Optional[torch.LongTensor] = None,
606
+ position_ids: Optional[torch.LongTensor] = None,
607
+ past_key_value: Optional[Cache] = None,
608
+ output_attentions: bool = False,
609
+ use_cache: bool = False,
610
+ **kwargs,
611
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
612
+ # LlamaFlashAttention2 attention does not support output_attentions
613
+ if "padding_mask" in kwargs:
614
+ warnings.warn(
615
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
616
+ )
617
+
618
+ # overwrite attention_mask with padding_mask
619
+ attention_mask = kwargs.pop("padding_mask")
620
+
621
+ output_attentions = False
622
+
623
+ bsz, q_len, _ = hidden_states.size()
624
+
625
+ query_states = self.q_proj(hidden_states)
626
+ key_states = self.k_proj(hidden_states)
627
+ value_states = self.v_proj(hidden_states)
628
+
629
+ # Flash attention requires the input to have the shape
630
+ # batch_size x seq_length x head_dim x hidden_dim
631
+ # therefore we just need to keep the original shape
632
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
633
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
634
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
635
+
636
+ kv_seq_len = key_states.shape[-2]
637
+ if past_key_value is not None:
638
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
639
+ cos, sin = self.rotary_emb(value_states, position_ids)
640
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
641
+
642
+ if past_key_value is not None:
643
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
644
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
645
+
646
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
647
+ # to be able to avoid many of these transpose/reshape/view.
648
+ query_states = query_states.transpose(1, 2)
649
+ key_states = key_states.transpose(1, 2)
650
+ value_states = value_states.transpose(1, 2)
651
+
652
+ dropout_rate = self.attention_dropout if self.training else 0.0
653
+
654
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
655
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
656
+ # cast them back in the correct dtype just to be sure everything works as expected.
657
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
658
+ # in fp32. (LlamaRMSNorm handles it correctly)
659
+
660
+ input_dtype = query_states.dtype
661
+ if input_dtype == torch.float32:
662
+ if torch.is_autocast_enabled():
663
+ target_dtype = torch.get_autocast_gpu_dtype()
664
+ # Handle the case where the model is quantized
665
+ elif hasattr(self.config, "_pre_quantization_dtype"):
666
+ target_dtype = self.config._pre_quantization_dtype
667
+ else:
668
+ target_dtype = self.q_proj.weight.dtype
669
+
670
+ logger.warning_once(
671
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
672
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
673
+ f" {target_dtype}."
674
+ )
675
+
676
+ query_states = query_states.to(target_dtype)
677
+ key_states = key_states.to(target_dtype)
678
+ value_states = value_states.to(target_dtype)
679
+
680
+ attn_output = self._flash_attention_forward(
681
+ query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
682
+ )
683
+
684
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
685
+ attn_output = self.o_proj(attn_output)
686
+
687
+ if not output_attentions:
688
+ attn_weights = None
689
+
690
+ return attn_output, attn_weights, past_key_value
691
+
692
+ def _flash_attention_forward(
693
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
694
+ ):
695
+ """
696
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
697
+ first unpad the input, then computes the attention scores and pad the final attention scores.
698
+
699
+ Args:
700
+ query_states (`torch.Tensor`):
701
+ Input query states to be passed to Flash Attention API
702
+ key_states (`torch.Tensor`):
703
+ Input key states to be passed to Flash Attention API
704
+ value_states (`torch.Tensor`):
705
+ Input value states to be passed to Flash Attention API
706
+ attention_mask (`torch.Tensor`):
707
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
708
+ position of padding tokens and 1 for the position of non-padding tokens.
709
+ dropout (`int`, *optional*):
710
+ Attention dropout
711
+ softmax_scale (`float`, *optional*):
712
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
713
+ """
714
+ if not self._flash_attn_uses_top_left_mask:
715
+ causal = self.is_causal
716
+ else:
717
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
718
+ causal = self.is_causal and query_length != 1
719
+
720
+ # Contains at least one padding token in the sequence
721
+ if attention_mask is not None:
722
+ batch_size = query_states.shape[0]
723
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
724
+ query_states, key_states, value_states, attention_mask, query_length
725
+ )
726
+
727
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
728
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
729
+
730
+ attn_output_unpad = flash_attn_varlen_func(
731
+ query_states,
732
+ key_states,
733
+ value_states,
734
+ cu_seqlens_q=cu_seqlens_q,
735
+ cu_seqlens_k=cu_seqlens_k,
736
+ max_seqlen_q=max_seqlen_in_batch_q,
737
+ max_seqlen_k=max_seqlen_in_batch_k,
738
+ dropout_p=dropout,
739
+ softmax_scale=softmax_scale,
740
+ causal=causal,
741
+ )
742
+
743
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
744
+ else:
745
+ attn_output = flash_attn_func(
746
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
747
+ )
748
+
749
+ return attn_output
750
+
751
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
752
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
753
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
754
+
755
+ key_layer = index_first_axis(
756
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
757
+ )
758
+ value_layer = index_first_axis(
759
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
760
+ )
761
+ if query_length == kv_seq_len:
762
+ query_layer = index_first_axis(
763
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
764
+ )
765
+ cu_seqlens_q = cu_seqlens_k
766
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
767
+ indices_q = indices_k
768
+ elif query_length == 1:
769
+ max_seqlen_in_batch_q = 1
770
+ cu_seqlens_q = torch.arange(
771
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
772
+ ) # There is a memcpy here, that is very bad.
773
+ indices_q = cu_seqlens_q[:-1]
774
+ query_layer = query_layer.squeeze(1)
775
+ else:
776
+ # The -q_len: slice assumes left padding.
777
+ attention_mask = attention_mask[:, -query_length:]
778
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
779
+
780
+ return (
781
+ query_layer,
782
+ key_layer,
783
+ value_layer,
784
+ indices_q,
785
+ (cu_seqlens_q, cu_seqlens_k),
786
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
787
+ )
788
+
789
+
790
+ class LlamaFlashAttention2_packing(LlamaAttention):
791
+ """
792
+ Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stay
793
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
794
+ flash attention and deal with padding tokens in case the input contains any of them.
795
+ """
796
+
797
+ def __init__(self, *args, **kwargs):
798
+ super().__init__(*args, **kwargs)
799
+
800
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
801
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, which became the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
802
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
803
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
804
+
805
+ def forward(
806
+ self,
807
+ hidden_states: torch.Tensor,
808
+ attention_mask: Optional[torch.LongTensor] = None,
809
+ position_ids: Optional[torch.LongTensor] = None,
810
+ past_key_value: Optional[Cache] = None,
811
+ output_attentions: bool = False,
812
+ use_cache: bool = False,
813
+ sub_sample_lengths=None,
814
+ **kwargs,
815
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
816
+ # LlamaFlashAttention2 attention does not support output_attentions
817
+ if "padding_mask" in kwargs:
818
+ warnings.warn(
819
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
820
+ )
821
+
822
+ # overwrite attention_mask with padding_mask
823
+ attention_mask = kwargs.pop("padding_mask")
824
+
825
+ output_attentions = False
826
+
827
+ bsz, q_len, _ = hidden_states.size()
828
+
829
+ query_states = self.q_proj(hidden_states)
830
+ key_states = self.k_proj(hidden_states)
831
+ value_states = self.v_proj(hidden_states)
832
+
833
+ # Flash attention requires the input to have the shape
834
+ # batch_size x seq_length x head_dim x hidden_dim
835
+ # therefore we just need to keep the original shape
836
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
837
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
838
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
839
+
840
+ kv_seq_len = key_states.shape[-2]
841
+ if past_key_value is not None:
842
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
843
+ cos, sin = self.rotary_emb(value_states, position_ids)
844
+ if sub_sample_lengths is not None:
845
+ packing_position_ids = []
846
+ for b in range(bsz):
847
+ each_sub_sample_lengths = sub_sample_lengths[b]
848
+ packing_position_ids.append(torch.cat([torch.arange(each) for each in each_sub_sample_lengths]))
849
+ packing_position_ids = torch.stack(packing_position_ids)
850
+ packing_position_ids = packing_position_ids.to(query_states.device)  # Tensor.to is not in-place; keep the returned tensor
851
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, packing_position_ids)
852
+ else:
853
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
854
+
855
+ if past_key_value is not None:
856
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
857
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
858
+
859
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
860
+ # to be able to avoid many of these transpose/reshape/view.
861
+ query_states = query_states.transpose(1, 2)
862
+ key_states = key_states.transpose(1, 2)
863
+ value_states = value_states.transpose(1, 2)
864
+
865
+ dropout_rate = self.attention_dropout if self.training else 0.0
866
+
867
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
868
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
869
+ # cast them back in the correct dtype just to be sure everything works as expected.
870
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
871
+ # in fp32. (LlamaRMSNorm handles it correctly)
872
+
873
+ input_dtype = query_states.dtype
874
+ if input_dtype == torch.float32:
875
+ if torch.is_autocast_enabled():
876
+ target_dtype = torch.get_autocast_gpu_dtype()
877
+ # Handle the case where the model is quantized
878
+ elif hasattr(self.config, "_pre_quantization_dtype"):
879
+ target_dtype = self.config._pre_quantization_dtype
880
+ else:
881
+ target_dtype = self.q_proj.weight.dtype
882
+
883
+ logger.warning_once(
884
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
885
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
886
+ f" {target_dtype}."
887
+ )
888
+
889
+ query_states = query_states.to(target_dtype)
890
+ key_states = key_states.to(target_dtype)
891
+ value_states = value_states.to(target_dtype)
892
+
893
+ attn_output = self._flash_attention_forward(
894
+ query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate, sub_sample_lengths=sub_sample_lengths)
895
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
896
+ attn_output = self.o_proj(attn_output)
897
+
898
+ if not output_attentions:
899
+ attn_weights = None
900
+
901
+ return attn_output, attn_weights, past_key_value
902
+
903
+ def _flash_attention_forward(
904
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None, sub_sample_lengths=None,
905
+ ):
906
+ """
907
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
908
+ first unpad the input, then computes the attention scores and pad the final attention scores.
909
+
910
+ Args:
911
+ query_states (`torch.Tensor`):
912
+ Input query states to be passed to Flash Attention API
913
+ key_states (`torch.Tensor`):
914
+ Input key states to be passed to Flash Attention API
915
+ value_states (`torch.Tensor`):
916
+ Input value states to be passed to Flash Attention API
917
+ attention_mask (`torch.Tensor`):
918
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
919
+ position of padding tokens and 1 for the position of non-padding tokens.
920
+ dropout (`int`, *optional*):
921
+ Attention dropout
922
+ softmax_scale (`float`, *optional*):
923
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
924
+ """
925
+ if not self._flash_attn_uses_top_left_mask:
926
+ causal = self.is_causal
927
+ else:
928
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
929
+ causal = self.is_causal and query_length != 1
930
+
931
+ # Contains at least one padding token in the sequence
932
+ if attention_mask is not None:
933
+ batch_size = query_states.shape[0]
934
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._unpad_input_packing(
935
+ query_states, key_states, value_states, attention_mask, query_length, sub_sample_lengths
936
+ )
937
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
938
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
939
+
940
+ attn_output_unpad = flash_attn_varlen_func(
941
+ query_states,
942
+ key_states,
943
+ value_states,
944
+ cu_seqlens_q=cu_seqlens_q,
945
+ cu_seqlens_k=cu_seqlens_k,
946
+ max_seqlen_q=max_seqlen_in_batch_q,
947
+ max_seqlen_k=max_seqlen_in_batch_k,
948
+ dropout_p=dropout,
949
+ softmax_scale=softmax_scale,
950
+ causal=causal,
951
+ )
952
+
953
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
954
+ else:
955
+ attn_output = flash_attn_func(
956
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
957
+ )
958
+
959
+ return attn_output
960
+
961
+ def _unpad_input_packing(self, query_layer, key_layer, value_layer, attention_mask, query_length, sub_sample_lengths):
962
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data_packing(attention_mask, sub_sample_lengths)
963
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
964
+
965
+ key_layer = index_first_axis(
966
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
967
+ )
968
+ value_layer = index_first_axis(
969
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
970
+ )
971
+ if query_length == kv_seq_len:
972
+ query_layer = index_first_axis(
973
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
974
+ )
975
+ cu_seqlens_q = cu_seqlens_k
976
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
977
+ indices_q = indices_k
978
+ elif query_length == 1:
979
+ max_seqlen_in_batch_q = 1
980
+ cu_seqlens_q = torch.arange(
981
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
982
+ ) # There is a memcpy here, that is very bad.
983
+ indices_q = cu_seqlens_q[:-1]
984
+ query_layer = query_layer.squeeze(1)
985
+ else:
986
+ # The -q_len: slice assumes left padding.
987
+ attention_mask = attention_mask[:, -query_length:]
988
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
989
+
990
+ return (
991
+ query_layer,
992
+ key_layer,
993
+ value_layer,
994
+ indices_q,
995
+ (cu_seqlens_q, cu_seqlens_k),
996
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
997
+ )
998
+
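+ # Illustrative sketch (not part of the upstream modeling code): when several samples are
+ # packed into one row, the packing path above rebuilds position ids so that every
+ # sub-sample restarts from 0 before RoPE is applied.
+ def _sketch_packed_position_ids():
+     import torch
+     sub_sample_lengths = [[3, 2, 4]]       # one packed row made of three sub-samples
+     packed = torch.cat([torch.arange(length) for length in sub_sample_lengths[0]])
+     assert packed.tolist() == [0, 1, 2, 0, 1, 0, 1, 2, 3]
+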
999
+
1000
+
1001
+ class LlamaSdpaAttention(LlamaAttention):
1002
+ """
1003
+ Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
1004
+ `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
1005
+ SDPA API.
1006
+ """
1007
+
1008
+ # Adapted from LlamaAttention.forward
1009
+ def forward(
1010
+ self,
1011
+ hidden_states: torch.Tensor,
1012
+ attention_mask: Optional[torch.Tensor] = None,
1013
+ position_ids: Optional[torch.LongTensor] = None,
1014
+ past_key_value: Optional[Cache] = None,
1015
+ output_attentions: bool = False,
1016
+ use_cache: bool = False,
1017
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
1018
+ if output_attentions:
1019
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
1020
+ logger.warning_once(
1021
+ "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
1022
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
1023
+ )
1024
+ return super().forward(
1025
+ hidden_states=hidden_states,
1026
+ attention_mask=attention_mask,
1027
+ position_ids=position_ids,
1028
+ past_key_value=past_key_value,
1029
+ output_attentions=output_attentions,
1030
+ use_cache=use_cache,
1031
+ )
1032
+
1033
+ bsz, q_len, _ = hidden_states.size()
1034
+
1035
+ query_states = self.q_proj(hidden_states)
1036
+ key_states = self.k_proj(hidden_states)
1037
+ value_states = self.v_proj(hidden_states)
1038
+
1039
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
1040
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
1041
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
1042
+
1043
+ kv_seq_len = key_states.shape[-2]
1044
+ if past_key_value is not None:
1045
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
1046
+ cos, sin = self.rotary_emb(value_states, position_ids)
1047
+
1048
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
1049
+
1050
+ if past_key_value is not None:
1051
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
1052
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
1053
+
1054
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
1055
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
1056
+
1057
+ if attention_mask is not None:
1058
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
1059
+ raise ValueError(
1060
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
1061
+ )
1062
+
1063
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
1064
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
1065
+ if query_states.device.type == "cuda" and attention_mask is not None:
1066
+ query_states = query_states.contiguous()
1067
+ key_states = key_states.contiguous()
1068
+ value_states = value_states.contiguous()
1069
+
1070
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
1071
+ query_states,
1072
+ key_states,
1073
+ value_states,
1074
+ attn_mask=attention_mask,
1075
+ dropout_p=self.attention_dropout if self.training else 0.0,
1076
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
1077
+ is_causal=self.is_causal and attention_mask is None and q_len > 1,
1078
+ )
1079
+
1080
+ attn_output = attn_output.transpose(1, 2).contiguous()
1081
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
1082
+
1083
+ attn_output = self.o_proj(attn_output)
1084
+
1085
+ return attn_output, None, past_key_value
1086
+
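+ # Illustrative sketch (not part of the upstream modeling code): the SDPA path above defers
+ # masking and softmax to torch; with is_causal=True it matches an eager
+ # softmax(QK^T / sqrt(d) + causal_mask) @ V computation. The helper name is hypothetical.
+ def _sketch_sdpa_vs_eager():
+     import math
+     import torch
+     q = torch.randn(1, 2, 4, 8)            # (batch, heads, seq_len, head_dim)
+     k = torch.randn(1, 2, 4, 8)
+     v = torch.randn(1, 2, 4, 8)
+     sdpa = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
+     scores = q @ k.transpose(-1, -2) / math.sqrt(8)
+     scores = scores.masked_fill(torch.triu(torch.ones(4, 4, dtype=torch.bool), diagonal=1), float("-inf"))
+     eager = torch.softmax(scores, dim=-1) @ v
+     assert torch.allclose(sdpa, eager, atol=1e-5)
+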
1087
+
1088
+ LLAMA_ATTENTION_CLASSES = {
1089
+ "eager": LlamaAttention,
1090
+ "flash_attention_2": LlamaFlashAttention2,
1091
+ "flash_attention_2_packing": LlamaFlashAttention2_packing,
1092
+ "sdpa": LlamaSdpaAttention,
1093
+ }
1094
+
1095
+
1096
+ class LlamaDecoderLayer(nn.Module):
1097
+ def __init__(self, config: LlamaConfig, layer_idx: int):
1098
+ super().__init__()
1099
+ self.hidden_size = config.hidden_size
1100
+
1101
+ self.self_attn = LLAMA_ATTENTION_CLASSES[config.attn_implementation](config=config, layer_idx=layer_idx)
1102
+
1103
+ self.mlp = LlamaMLP(config)
1104
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1105
+ self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1106
+
1107
+ def forward(
1108
+ self,
1109
+ hidden_states: torch.Tensor,
1110
+ attention_mask: Optional[torch.Tensor] = None,
1111
+ position_ids: Optional[torch.LongTensor] = None,
1112
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
1113
+ sub_sample_lengths=None,
1114
+ output_attentions: Optional[bool] = False,
1115
+ use_cache: Optional[bool] = False,
1116
+ **kwargs,
1117
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
1118
+ """
1119
+ Args:
1120
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
1121
+ attention_mask (`torch.FloatTensor`, *optional*):
1122
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
1123
+ query_sequence_length, key_sequence_length)` if default attention is used.
1124
+ output_attentions (`bool`, *optional*):
1125
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1126
+ returned tensors for more detail.
1127
+ use_cache (`bool`, *optional*):
1128
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
1129
+ (see `past_key_values`).
1130
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
1131
+ """
1132
+ if "padding_mask" in kwargs:
1133
+ warnings.warn(
1134
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
1135
+ )
1136
+
1137
+ residual = hidden_states
1138
+
1139
+ hidden_states = self.input_layernorm(hidden_states)
1140
+
1141
+ # Self Attention
1142
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
1143
+ hidden_states=hidden_states,
1144
+ attention_mask=attention_mask,
1145
+ position_ids=position_ids,
1146
+ past_key_value=past_key_value,
1147
+ output_attentions=output_attentions,
1148
+ use_cache=use_cache,
1149
+ sub_sample_lengths=sub_sample_lengths,
1150
+ **kwargs,
1151
+ )
1152
+ hidden_states = residual + hidden_states
1153
+
1154
+ # Fully Connected
1155
+ residual = hidden_states
1156
+ hidden_states = self.post_attention_layernorm(hidden_states)
1157
+ hidden_states = self.mlp(hidden_states)
1158
+ hidden_states = residual + hidden_states
1159
+
1160
+ outputs = (hidden_states,)
1161
+
1162
+ if output_attentions:
1163
+ outputs += (self_attn_weights,)
1164
+
1165
+ if use_cache:
1166
+ outputs += (present_key_value,)
1167
+
1168
+ return outputs
1169
+
1170
+
1171
+ LLAMA_START_DOCSTRING = r"""
1172
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1173
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
1174
+ etc.)
1175
+
1176
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
1177
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
1178
+ and behavior.
1179
+
1180
+ Parameters:
1181
+ config ([`LlamaConfig`]):
1182
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
1183
+ load the weights associated with the model, only the configuration. Check out the
1184
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
1185
+ """
1186
+
1187
+
1188
+ @add_start_docstrings(
1189
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
1190
+ LLAMA_START_DOCSTRING,
1191
+ )
1192
+ class LlamaPreTrainedModel(PreTrainedModel):
1193
+ config_class = LlamaConfig
1194
+ base_model_prefix = "model"
1195
+ supports_gradient_checkpointing = True
1196
+ _no_split_modules = ["LlamaDecoderLayer"]
1197
+ _skip_keys_device_placement = "past_key_values"
1198
+ _supports_flash_attn_2 = True
1199
+ _supports_sdpa = True
1200
+ _supports_cache_class = True
1201
+
1202
+ def _init_weights(self, module):
1203
+ std = self.config.initializer_range
1204
+ if isinstance(module, nn.Linear):
1205
+ module.weight.data.normal_(mean=0.0, std=std)
1206
+ if module.bias is not None:
1207
+ module.bias.data.zero_()
1208
+ elif isinstance(module, nn.Embedding):
1209
+ module.weight.data.normal_(mean=0.0, std=std)
1210
+ if module.padding_idx is not None:
1211
+ module.weight.data[module.padding_idx].zero_()
1212
+
1213
+
1214
+ LLAMA_INPUTS_DOCSTRING = r"""
1215
+ Args:
1216
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1217
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
1218
+ it.
1219
+
1220
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1221
+ [`PreTrainedTokenizer.__call__`] for details.
1222
+
1223
+ [What are input IDs?](../glossary#input-ids)
1224
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1225
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1226
+
1227
+ - 1 for tokens that are **not masked**,
1228
+ - 0 for tokens that are **masked**.
1229
+
1230
+ [What are attention masks?](../glossary#attention-mask)
1231
+
1232
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1233
+ [`PreTrainedTokenizer.__call__`] for details.
1234
+
1235
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
1236
+ `past_key_values`).
1237
+
1238
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
1239
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
1240
+ information on the default strategy.
1241
+
1242
+ - 1 indicates the head is **not masked**,
1243
+ - 0 indicates the head is **masked**.
1244
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1245
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
1246
+ config.n_positions - 1]`.
1247
+
1248
+ [What are position IDs?](../glossary#position-ids)
1249
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
1250
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
1251
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
1252
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
1253
+
1254
+ Two formats are allowed:
1255
+ - a [`~cache_utils.Cache`] instance;
1256
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
1257
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
1258
+ cache format.
1259
+
1260
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
1261
+ legacy cache format will be returned.
1262
+
1263
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
1264
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
1265
+ of shape `(batch_size, sequence_length)`.
1266
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1267
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
1268
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
1269
+ model's internal embedding lookup matrix.
1270
+ use_cache (`bool`, *optional*):
1271
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1272
+ `past_key_values`).
1273
+ output_attentions (`bool`, *optional*):
1274
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1275
+ tensors for more detail.
1276
+ output_hidden_states (`bool`, *optional*):
1277
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1278
+ more detail.
1279
+ return_dict (`bool`, *optional*):
1280
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1281
+ """
1282
+
1283
+
1284
+ @add_start_docstrings(
1285
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
1286
+ LLAMA_START_DOCSTRING,
1287
+ )
1288
+ class LlamaModel(LlamaPreTrainedModel):
1289
+ """
1290
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
1291
+
1292
+ Args:
1293
+ config: LlamaConfig
1294
+ """
1295
+
1296
+ def __init__(self, config: LlamaConfig):
1297
+ super().__init__(config)
1298
+ self.padding_idx = config.pad_token_id
1299
+ self.vocab_size = config.vocab_size
1300
+
1301
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1302
+ self.layers = nn.ModuleList(
1303
+ [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
1304
+ )
1305
+ self._use_sdpa = config.attn_implementation == "sdpa"
1306
+ self.attn_implementation = config.attn_implementation
1307
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1308
+
1309
+ self.gradient_checkpointing = False
1310
+ # Initialize weights and apply final processing
1311
+ self.post_init()
1312
+
1313
+ def get_input_embeddings(self):
1314
+ return self.embed_tokens
1315
+
1316
+ def set_input_embeddings(self, value):
1317
+ self.embed_tokens = value
1318
+
1319
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1320
+ def forward(
1321
+ self,
1322
+ input_ids: torch.LongTensor = None,
1323
+ attention_mask: Optional[torch.Tensor] = None,
1324
+ position_ids: Optional[torch.LongTensor] = None,
1325
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1326
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1327
+ use_cache: Optional[bool] = None,
1328
+ output_attentions: Optional[bool] = None,
1329
+ output_hidden_states: Optional[bool] = None,
1330
+ return_dict: Optional[bool] = None,
1331
+ sub_sample_lengths=None,
1332
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
1333
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1334
+ output_hidden_states = (
1335
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1336
+ )
1337
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1338
+
1339
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1340
+
1341
+ # retrieve input_ids and inputs_embeds
1342
+ if input_ids is not None and inputs_embeds is not None:
1343
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
1344
+ elif input_ids is not None:
1345
+ batch_size, seq_length = input_ids.shape[:2]
1346
+ elif inputs_embeds is not None:
1347
+ batch_size, seq_length = inputs_embeds.shape[:2]
1348
+ else:
1349
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
1350
+
1351
+ if self.gradient_checkpointing and self.training:
1352
+ if use_cache:
1353
+ logger.warning_once(
1354
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1355
+ )
1356
+ use_cache = False
1357
+
1358
+ past_key_values_length = 0
1359
+ if use_cache:
1360
+ use_legacy_cache = not isinstance(past_key_values, Cache)
1361
+ if use_legacy_cache:
1362
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1363
+ past_key_values_length = past_key_values.get_usable_length(seq_length)
1364
+
1365
+ if position_ids is None:
1366
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1367
+ position_ids = torch.arange(
1368
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
1369
+ )
1370
+ position_ids = position_ids.unsqueeze(0)
1371
+
1372
+ if inputs_embeds is None:
1373
+ inputs_embeds = self.embed_tokens(input_ids)
1374
+
1375
+ if self.attn_implementation == "flash_attention_2" or self.config.attn_implementation =='flash_attention_2_packing':
1376
+ # 2d mask is passed through the layers
1377
+ # attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
1378
+ if attention_mask is not None:
1379
+ # Keep a long-dtype mask as-is; otherwise drop the mask when it contains no padding (no zeros).
1380
+ if attention_mask.dtype != torch.long:
1381
+ attention_mask = attention_mask if (0 in attention_mask) else None
1382
1383
+ elif self._use_sdpa and not output_attentions:
1384
+ # output_attentions=True can not be supported when using SDPA, and we fall back on
1385
+ # the manual implementation that requires a 4D causal mask in all cases.
1386
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
1387
+ attention_mask,
1388
+ (batch_size, seq_length),
1389
+ inputs_embeds,
1390
+ past_key_values_length,
1391
+ )
1392
+ else:
1393
+ # 4d mask is passed through the layers
1394
+ attention_mask = _prepare_4d_causal_attention_mask(
1395
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
1396
+ )
1397
+
1398
+ # embed positions
1399
+ hidden_states = inputs_embeds
1400
+
1401
+ # decoder layers
1402
+ all_hidden_states = () if output_hidden_states else None
1403
+ all_self_attns = () if output_attentions else None
1404
+ next_decoder_cache = None
1405
+
1406
+ for decoder_layer in self.layers:
1407
+ if output_hidden_states:
1408
+ all_hidden_states += (hidden_states,)
1409
+
1410
+ if self.gradient_checkpointing and self.training:
1411
+ layer_outputs = self._gradient_checkpointing_func(
1412
+ decoder_layer.__call__,
1413
+ hidden_states,
1414
+ attention_mask,
1415
+ position_ids,
1416
+ past_key_values,
1417
+ sub_sample_lengths,
1418
+ output_attentions,
1419
+ use_cache,
1420
+ )
1421
+ else:
1422
+ layer_outputs = decoder_layer(
1423
+ hidden_states,
1424
+ attention_mask=attention_mask,
1425
+ position_ids=position_ids,
1426
+ past_key_value=past_key_values,
1427
+ sub_sample_lengths=sub_sample_lengths,
1428
+ output_attentions=output_attentions,
1429
+ use_cache=use_cache,
1430
+ )
1431
+
1432
+ hidden_states = layer_outputs[0]
1433
+
1434
+ if use_cache:
1435
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1436
+
1437
+ if output_attentions:
1438
+ all_self_attns += (layer_outputs[1],)
1439
+
1440
+ hidden_states = self.norm(hidden_states)
1441
+
1442
+ # add hidden states from the last decoder layer
1443
+ if output_hidden_states:
1444
+ all_hidden_states += (hidden_states,)
1445
+
1446
+ next_cache = None
1447
+ if use_cache:
1448
+ next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
1449
+ if not return_dict:
1450
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1451
+ return BaseModelOutputWithPast(
1452
+ last_hidden_state=hidden_states,
1453
+ past_key_values=next_cache,
1454
+ hidden_states=all_hidden_states,
1455
+ attentions=all_self_attns,
1456
+ )
1457
+
1458
+
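When `use_cache` is set, the forward pass above converts a tuple-style ("legacy") cache into a `DynamicCache` on entry and back to tuples on exit. A minimal illustration of that round-trip, assuming a transformers version that ships `DynamicCache` (>= 4.36); the shapes are made up:

import torch
from transformers.cache_utils import DynamicCache

# One decoder layer, batch=1, 2 KV heads, 3 cached tokens, head_dim=4 (toy shapes).
legacy = ((torch.zeros(1, 2, 3, 4), torch.zeros(1, 2, 3, 4)),)
cache = DynamicCache.from_legacy_cache(legacy)
print(cache.get_usable_length(5))       # 3 tokens are already cached
legacy_again = cache.to_legacy_cache()  # back to the tuple-of-tuples format
assert legacy_again[0][0].shape == (1, 2, 3, 4)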
1459
+ class LlamaForCausalLM(LlamaPreTrainedModel):
1460
+ _tied_weights_keys = ["lm_head.weight"]
1461
+
1462
+ def __init__(self, config):
1463
+ super().__init__(config)
1464
+ self.model = LlamaModel(config)
1465
+ self.vocab_size = config.vocab_size
1466
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1467
+
1468
+ # Initialize weights and apply final processing
1469
+ self.post_init()
1470
+ self.support_packing = True
1471
+
1472
+ def get_input_embeddings(self):
1473
+ return self.model.embed_tokens
1474
+
1475
+ def set_input_embeddings(self, value):
1476
+ self.model.embed_tokens = value
1477
+
1478
+ def get_output_embeddings(self):
1479
+ return self.lm_head
1480
+
1481
+ def set_output_embeddings(self, new_embeddings):
1482
+ self.lm_head = new_embeddings
1483
+
1484
+ def set_decoder(self, decoder):
1485
+ self.model = decoder
1486
+
1487
+ def get_decoder(self):
1488
+ return self.model
1489
+
1490
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1491
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1492
+ def forward(
1493
+ self,
1494
+ input_ids: torch.LongTensor = None,
1495
+ attention_mask: Optional[torch.Tensor] = None,
1496
+ position_ids: Optional[torch.LongTensor] = None,
1497
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1498
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1499
+ labels: Optional[torch.LongTensor] = None,
1500
+ use_cache: Optional[bool] = None,
1501
+ output_attentions: Optional[bool] = None,
1502
+ output_hidden_states: Optional[bool] = None,
1503
+ return_dict: Optional[bool] = None,
1504
+ sub_sample_lengths=None,
1505
+ efficient_loss=False,
1506
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1507
+ r"""
1508
+ Args:
1509
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1510
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1511
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1512
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1513
+
1514
+ Returns:
1515
+
1516
+ Example:
1517
+
1518
+ ```python
1519
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
1520
+
1521
+ >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
1522
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
1523
+
1524
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1525
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1526
+
1527
+ >>> # Generate
1528
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1529
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1530
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1531
+ ```"""
1532
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1533
+ output_hidden_states = (
1534
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1535
+ )
1536
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1537
+
1538
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1539
+ outputs = self.model(
1540
+ input_ids=input_ids,
1541
+ attention_mask=attention_mask,
1542
+ position_ids=position_ids,
1543
+ past_key_values=past_key_values,
1544
+ inputs_embeds=inputs_embeds,
1545
+ use_cache=use_cache,
1546
+ output_attentions=output_attentions,
1547
+ output_hidden_states=output_hidden_states,
1548
+ return_dict=return_dict,
1549
+ sub_sample_lengths=sub_sample_lengths
1550
+ )
1551
+ if efficient_loss:
1552
+ return outputs, self.lm_head.weight
1553
+
1554
+ hidden_states = outputs[0]
1555
+ if self.config.pretraining_tp > 1:
1556
+ lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
1557
+ logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
1558
+ logits = torch.cat(logits, dim=-1)
1559
+ else:
1560
+ logits = self.lm_head(hidden_states)
1561
+ logits = logits.float()
1562
+
1563
+ loss = None
1564
+ if labels is not None:
1565
+ # Shift so that tokens < n predict n
1566
+ shift_logits = logits[..., :-1, :].contiguous()
1567
+ shift_labels = labels[..., 1:].contiguous()
1568
+ # Flatten the tokens
1569
+ loss_fct = CrossEntropyLoss()
1570
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1571
+ shift_labels = shift_labels.view(-1)
1572
+ # Enable model parallelism
1573
+ shift_labels = shift_labels.to(shift_logits.device)
1574
+ loss = loss_fct(shift_logits, shift_labels)
1575
+
1576
+ if not return_dict:
1577
+ output = (logits,) + outputs[1:]
1578
+ return (loss,) + output if loss is not None else output
1579
+
1580
+ return CausalLMOutputWithPast(
1581
+ loss=loss,
1582
+ logits=logits,
1583
+ past_key_values=outputs.past_key_values,
1584
+ hidden_states=outputs.hidden_states,
1585
+ attentions=outputs.attentions,
1586
+ )
1587
+
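In the loss block above, logits and labels are shifted by one position so the prediction at step t is scored against the token at t+1, and `-100` labels are ignored by `CrossEntropyLoss`. A small self-contained illustration with hypothetical tensors:

import torch
from torch.nn import CrossEntropyLoss

vocab = 10
logits = torch.randn(1, 5, vocab)            # (batch, seq_len, vocab)
labels = torch.tensor([[3, 7, 2, -100, 4]])  # -100 marks positions excluded from the loss

shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab)  # predictions for tokens 1..4
shift_labels = labels[..., 1:].contiguous().view(-1)             # tokens 1..4 as targets
loss = CrossEntropyLoss()(shift_logits, shift_labels)            # ignore_index defaults to -100
print(loss.item())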
1588
+ def prepare_inputs_for_generation(
1589
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1590
+ ):
1591
+ if past_key_values is not None:
1592
+ if isinstance(past_key_values, Cache):
1593
+ cache_length = past_key_values.get_seq_length()
1594
+ past_length = past_key_values.seen_tokens
1595
+ max_cache_length = past_key_values.get_max_length()
1596
+ else:
1597
+ cache_length = past_length = past_key_values[0][0].shape[2]
1598
+ max_cache_length = None
1599
+
1600
+ # Keep only the unprocessed tokens:
1601
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1602
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1603
+ # input)
1604
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1605
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1606
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1607
+ # input_ids based on the past_length.
1608
+ elif past_length < input_ids.shape[1]:
1609
+ input_ids = input_ids[:, past_length:]
1610
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1611
+
1612
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1613
+ if (
1614
+ max_cache_length is not None
1615
+ and attention_mask is not None
1616
+ and cache_length + input_ids.shape[1] > max_cache_length
1617
+ ):
1618
+ attention_mask = attention_mask[:, -max_cache_length:]
1619
+
1620
+ position_ids = kwargs.get("position_ids", None)
1621
+ if attention_mask is not None and position_ids is None:
1622
+ # create position_ids on the fly for batch generation
1623
+ position_ids = attention_mask.long().cumsum(-1) - 1
1624
+ position_ids.masked_fill_(attention_mask == 0, 1)
1625
+ if past_key_values:
1626
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1627
+
1628
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1629
+ if inputs_embeds is not None and past_key_values is None:
1630
+ model_inputs = {"inputs_embeds": inputs_embeds}
1631
+ else:
1632
+ model_inputs = {"input_ids": input_ids}
1633
+
1634
+ model_inputs.update(
1635
+ {
1636
+ "position_ids": position_ids,
1637
+ "past_key_values": past_key_values,
1638
+ "use_cache": kwargs.get("use_cache"),
1639
+ "attention_mask": attention_mask,
1640
+ }
1641
+ )
1642
+ return model_inputs
1643
+
1644
+ @staticmethod
1645
+ def _reorder_cache(past_key_values, beam_idx):
1646
+ reordered_past = ()
1647
+ for layer_past in past_key_values:
1648
+ reordered_past += (
1649
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1650
+ )
1651
+ return reordered_past
1652
+
1653
+
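`_reorder_cache` above re-indexes every cached key/value tensor along the batch dimension so the cache follows whichever beams survive a beam-search step. A toy sketch of that `index_select` (hypothetical shapes):

import torch

past_state = torch.arange(4.0).view(4, 1, 1, 1)  # cache slice for a batch of 4 beams
beam_idx = torch.tensor([2, 2, 0, 3])            # beams kept after one search step
reordered = past_state.index_select(0, beam_idx)
print(reordered.flatten())                       # tensor([2., 2., 0., 3.])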
1654
+ @add_start_docstrings(
1655
+ """
1656
+ The LLaMa Model transformer with a sequence classification head on top (linear layer).
1657
+
1658
+ [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1659
+ (e.g. GPT-2) do.
1660
+
1661
+ Since it does classification on the last token, it requires to know the position of the last token. If a
1662
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1663
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1664
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1665
+ each row of the batch).
1666
+ """,
1667
+ LLAMA_START_DOCSTRING,
1668
+ )
1669
+ class LlamaForSequenceClassification(LlamaPreTrainedModel):
1670
+ def __init__(self, config):
1671
+ super().__init__(config)
1672
+ self.num_labels = config.num_labels
1673
+ self.model = LlamaModel(config)
1674
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1675
+
1676
+ # Initialize weights and apply final processing
1677
+ self.post_init()
1678
+
1679
+ def get_input_embeddings(self):
1680
+ return self.model.embed_tokens
1681
+
1682
+ def set_input_embeddings(self, value):
1683
+ self.model.embed_tokens = value
1684
+
1685
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1686
+ def forward(
1687
+ self,
1688
+ input_ids: torch.LongTensor = None,
1689
+ attention_mask: Optional[torch.Tensor] = None,
1690
+ position_ids: Optional[torch.LongTensor] = None,
1691
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1692
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1693
+ labels: Optional[torch.LongTensor] = None,
1694
+ use_cache: Optional[bool] = None,
1695
+ output_attentions: Optional[bool] = None,
1696
+ output_hidden_states: Optional[bool] = None,
1697
+ return_dict: Optional[bool] = None,
1698
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1699
+ r"""
1700
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1701
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1702
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1703
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1704
+ """
1705
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1706
+
1707
+ transformer_outputs = self.model(
1708
+ input_ids,
1709
+ attention_mask=attention_mask,
1710
+ position_ids=position_ids,
1711
+ past_key_values=past_key_values,
1712
+ inputs_embeds=inputs_embeds,
1713
+ use_cache=use_cache,
1714
+ output_attentions=output_attentions,
1715
+ output_hidden_states=output_hidden_states,
1716
+ return_dict=return_dict,
1717
+ )
1718
+ hidden_states = transformer_outputs[0]
1719
+ logits = self.score(hidden_states)
1720
+
1721
+ if input_ids is not None:
1722
+ batch_size = input_ids.shape[0]
1723
+ else:
1724
+ batch_size = inputs_embeds.shape[0]
1725
+
1726
+ if self.config.pad_token_id is None and batch_size != 1:
1727
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1728
+ if self.config.pad_token_id is None:
1729
+ sequence_lengths = -1
1730
+ else:
1731
+ if input_ids is not None:
1732
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1733
+ sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1734
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
1735
+ sequence_lengths = sequence_lengths.to(logits.device)
1736
+ else:
1737
+ sequence_lengths = -1
1738
+
1739
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1740
+
1741
+ loss = None
1742
+ if labels is not None:
1743
+ labels = labels.to(logits.device)
1744
+ if self.config.problem_type is None:
1745
+ if self.num_labels == 1:
1746
+ self.config.problem_type = "regression"
1747
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1748
+ self.config.problem_type = "single_label_classification"
1749
+ else:
1750
+ self.config.problem_type = "multi_label_classification"
1751
+
1752
+ if self.config.problem_type == "regression":
1753
+ loss_fct = MSELoss()
1754
+ if self.num_labels == 1:
1755
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1756
+ else:
1757
+ loss = loss_fct(pooled_logits, labels)
1758
+ elif self.config.problem_type == "single_label_classification":
1759
+ loss_fct = CrossEntropyLoss()
1760
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1761
+ elif self.config.problem_type == "multi_label_classification":
1762
+ loss_fct = BCEWithLogitsLoss()
1763
+ loss = loss_fct(pooled_logits, labels)
1764
+ if not return_dict:
1765
+ output = (pooled_logits,) + transformer_outputs[1:]
1766
+ return ((loss,) + output) if loss is not None else output
1767
+
1768
+ return SequenceClassifierOutputWithPast(
1769
+ loss=loss,
1770
+ logits=pooled_logits,
1771
+ past_key_values=transformer_outputs.past_key_values,
1772
+ hidden_states=transformer_outputs.hidden_states,
1773
+ attentions=transformer_outputs.attentions,
1774
+ )
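The pooling in `LlamaForSequenceClassification` above finds the last non-padding token per row by taking the argmax of `input_ids == pad_token_id` and subtracting one, with a modulo so rows without padding fall back to the final position. A tiny check with made-up ids and pad id 0:

import torch

pad_token_id = 0
input_ids = torch.tensor([[5, 6, 7, 0, 0],
                          [5, 6, 7, 8, 9]])  # second row has no padding

sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
print(sequence_lengths)  # tensor([2, 4]) -> index of the last real token per row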
modeling_qwen2.py ADDED
@@ -0,0 +1,1744 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """ PyTorch Qwen2 model."""
21
+ import inspect
22
+ import math
23
+ import warnings
24
+ from typing import List, Optional, Tuple, Union
25
+
26
+ import torch
27
+ import torch.nn.functional as F
28
+ import torch.utils.checkpoint
29
+ from torch import nn
30
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
31
+
32
+ from transformers.activations import ACT2FN
33
+ from transformers.cache_utils import Cache, DynamicCache
34
+ from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
35
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
36
+ from transformers.modeling_utils import PreTrainedModel
37
+ from transformers.utils import (
38
+ add_start_docstrings,
39
+ add_start_docstrings_to_model_forward,
40
+ is_flash_attn_2_available,
41
+ is_flash_attn_greater_or_equal_2_10,
42
+ logging,
43
+ replace_return_docstrings,
44
+ )
45
+ from .configuration_qwen2 import Qwen2Config
46
+
47
+
48
+ if is_flash_attn_2_available():
49
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
50
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
51
+
52
+ _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
53
+
54
+
55
+ logger = logging.get_logger(__name__)
56
+
57
+
58
+ _CHECKPOINT_FOR_DOC = "Qwen/Qwen2-7B-beta"
59
+ _CONFIG_FOR_DOC = "Qwen2Config"
60
+
61
+ QWEN2_PRETRAINED_MODEL_ARCHIVE_LIST = [
62
+ "Qwen/Qwen2-7B-beta",
63
+ # See all Qwen2 models at https://huggingface.co/models?filter=qwen2
64
+ ]
65
+
66
+
67
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
68
+ def _get_unpad_data(attention_mask):
69
+ seqlens_in_batch = (attention_mask > 0).sum(dim=-1, dtype=torch.int32)
70
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
71
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
72
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
73
+ return (
74
+ indices,
75
+ cu_seqlens,
76
+ max_seqlen_in_batch,
77
+ )
78
+
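`_get_unpad_data` turns a 2-D padding mask into the flat token indices, cumulative sequence lengths, and maximum length that `flash_attn_varlen_func` expects. A small numeric sketch of its outputs (illustrative values only):

import torch
import torch.nn.functional as F

attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])
seqlens = (attention_mask > 0).sum(dim=-1, dtype=torch.int32)  # tensor([3, 2])
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
# indices -> tensor([0, 1, 2, 4, 5]): positions of real tokens in the flattened batch
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
# cu_seqlens -> tensor([0, 3, 5]): per-sequence offsets handed to the varlen kernel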
79
+ def _get_unpad_data_packing(attention_mask, sub_sample_lengths):
80
+ seqlens_in_batch = []
81
+ for i, per_sub_sample_lengths in enumerate(sub_sample_lengths):
82
+ if (attention_mask[i] == 0).sum() == per_sub_sample_lengths[-1]:
83
+ per_sub_sample_lengths = per_sub_sample_lengths[:-1]
84
+ seqlens_in_batch.extend(per_sub_sample_lengths)
85
+ seqlens_in_batch = torch.tensor(seqlens_in_batch, device=attention_mask.device, dtype=torch.int32)
86
+
87
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
88
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
89
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
90
+ return (
91
+ indices,
92
+ cu_seqlens,
93
+ max_seqlen_in_batch,
94
+ )
95
+
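The `_packing` variant additionally splits each packed row into its constituent samples via `sub_sample_lengths`, so the cumulative lengths mark sample boundaries rather than row boundaries and flash attention cannot attend across packed samples. A hedged sketch of the intended input format (the per-row list of sample lengths is an assumption inferred from the code above):

import torch
import torch.nn.functional as F

# One packed row of 8 tokens holding two samples of lengths 5 and 3, no padding.
attention_mask = torch.ones(1, 8, dtype=torch.long)
sub_sample_lengths = [[5, 3]]  # assumed format: one list of sample lengths per row

seqlens = torch.tensor([l for row in sub_sample_lengths for l in row], dtype=torch.int32)
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
# cu_seqlens -> tensor([0, 5, 8]): tokens 0-4 and 5-7 are treated as separate sequences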
96
+ # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Qwen2
97
+ class Qwen2RMSNorm(nn.Module):
98
+ def __init__(self, hidden_size, eps=1e-6):
99
+ """
100
+ Qwen2RMSNorm is equivalent to T5LayerNorm
101
+ """
102
+ super().__init__()
103
+ self.weight = nn.Parameter(torch.ones(hidden_size))
104
+ self.variance_epsilon = eps
105
+
106
+ def forward(self, hidden_states):
107
+ input_dtype = hidden_states.dtype
108
+ hidden_states = hidden_states.to(torch.float32)
109
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
110
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
111
+ return self.weight * hidden_states.to(input_dtype)
112
+
113
+
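Qwen2RMSNorm rescales each hidden vector by the reciprocal of its root mean square (no mean subtraction, unlike LayerNorm) and then applies a learned per-channel weight, i.e. y = w * x / sqrt(mean(x^2) + eps). A quick numerical check of the normalization step:

import torch

x = torch.randn(2, 4, 8)   # (batch, seq, hidden)
eps = 1e-6
y = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
print(y.pow(2).mean(-1))   # ~1.0 for every position: unit RMS before the learned weight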
114
+ # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Qwen2
115
+ class Qwen2RotaryEmbedding(nn.Module):
116
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
117
+ super().__init__()
118
+
119
+ self.dim = dim
120
+ self.max_position_embeddings = max_position_embeddings
121
+ self.base = base
122
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
123
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
124
+
125
+ # Build here to make `torch.jit.trace` work.
126
+ self._set_cos_sin_cache(
127
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
128
+ )
129
+
130
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
131
+ self.max_seq_len_cached = seq_len
132
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
133
+
134
+ freqs = torch.outer(t, self.inv_freq)
135
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
136
+ emb = torch.cat((freqs, freqs), dim=-1)
137
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
138
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
139
+
140
+ def forward(self, x, seq_len=None):
141
+ # x: [bs, num_attention_heads, seq_len, head_size]
142
+ if seq_len > self.max_seq_len_cached:
143
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
144
+
145
+ return (
146
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
147
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
148
+ )
149
+
150
+
151
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
152
+ def rotate_half(x):
153
+ """Rotates half the hidden dims of the input."""
154
+ x1 = x[..., : x.shape[-1] // 2]
155
+ x2 = x[..., x.shape[-1] // 2 :]
156
+ return torch.cat((-x2, x1), dim=-1)
157
+
158
+
159
+ # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
160
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
161
+ """Applies Rotary Position Embedding to the query and key tensors.
162
+
163
+ Args:
164
+ q (`torch.Tensor`): The query tensor.
165
+ k (`torch.Tensor`): The key tensor.
166
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
167
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
168
+ position_ids (`torch.Tensor`):
169
+ The position indices of the tokens corresponding to the query and key tensors. For example, this can be
170
+ used to pass offsetted position ids when working with a KV-cache.
171
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
172
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
173
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
174
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
175
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
176
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
177
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
178
+ Returns:
179
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
180
+ """
181
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
182
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
183
+ q_embed = (q * cos) + (rotate_half(q) * sin)
184
+ k_embed = (k * cos) + (rotate_half(k) * sin)
185
+ return q_embed, k_embed
186
+
187
+
188
+ # Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2
189
+ class Qwen2MLP(nn.Module):
190
+ def __init__(self, config):
191
+ super().__init__()
192
+ self.config = config
193
+ self.hidden_size = config.hidden_size
194
+ self.intermediate_size = config.intermediate_size
195
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
196
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
197
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
198
+ self.act_fn = ACT2FN[config.hidden_act]
199
+
200
+ def forward(self, x):
201
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
202
+
203
+
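Qwen2MLP is a gated (SwiGLU-style) feed-forward block: the SiLU-activated `gate_proj` branch multiplies the `up_proj` branch element-wise before `down_proj` maps back to the hidden size. A minimal functional equivalent with hypothetical sizes:

import torch
import torch.nn.functional as F

hidden, inter = 16, 64
x = torch.randn(2, 5, hidden)
w_gate, w_up = torch.randn(inter, hidden), torch.randn(inter, hidden)
w_down = torch.randn(hidden, inter)

y = F.linear(F.silu(F.linear(x, w_gate)) * F.linear(x, w_up), w_down)
print(y.shape)  # torch.Size([2, 5, 16])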
204
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv
205
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
206
+ """
207
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
208
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
209
+ """
210
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
211
+ if n_rep == 1:
212
+ return hidden_states
213
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
214
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
215
+
216
+
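As the docstring notes, `repeat_kv` matches `torch.repeat_interleave(x, dim=1, repeats=n_rep)`: it expands the grouped key/value heads so they line up with the query heads. A quick equivalence check using the function defined above:

import torch

kv = torch.randn(2, 4, 16, 64)  # (batch, num_kv_heads, seq, head_dim)
n_rep = 3                        # num_attention_heads // num_key_value_heads

out = repeat_kv(kv, n_rep)       # -> (2, 12, 16, 64)
ref = torch.repeat_interleave(kv, repeats=n_rep, dim=1)
assert torch.equal(out, ref)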
217
+ class Qwen2Attention(nn.Module):
218
+ """
219
+ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
220
+ and "Generating Long Sequences with Sparse Transformers".
221
+ """
222
+
223
+ def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None):
224
+ super().__init__()
225
+ self.config = config
226
+ self.layer_idx = layer_idx
227
+ if layer_idx is None:
228
+ logger.warning_once(
229
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
230
+ "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
231
+ "when creating this class."
232
+ )
233
+
234
+ self.hidden_size = config.hidden_size
235
+ self.num_heads = config.num_attention_heads
236
+ self.head_dim = self.hidden_size // self.num_heads
237
+ self.num_key_value_heads = config.num_key_value_heads
238
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
239
+ self.max_position_embeddings = config.max_position_embeddings
240
+ self.rope_theta = config.rope_theta
241
+ self.is_causal = True
242
+ self.attention_dropout = config.attention_dropout
243
+
244
+ if (self.head_dim * self.num_heads) != self.hidden_size:
245
+ raise ValueError(
246
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
247
+ f" and `num_heads`: {self.num_heads})."
248
+ )
249
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
250
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
251
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
252
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
253
+
254
+ self.rotary_emb = Qwen2RotaryEmbedding(
255
+ self.head_dim,
256
+ max_position_embeddings=self.max_position_embeddings,
257
+ base=self.rope_theta,
258
+ )
259
+
260
+ def forward(
261
+ self,
262
+ hidden_states: torch.Tensor,
263
+ attention_mask: Optional[torch.Tensor] = None,
264
+ position_ids: Optional[torch.LongTensor] = None,
265
+ past_key_value: Optional[Cache] = None,
266
+ output_attentions: bool = False,
267
+ use_cache: bool = False,
268
+ **kwargs,
269
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
270
+ if "padding_mask" in kwargs:
271
+ warnings.warn(
272
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
273
+ )
274
+ bsz, q_len, _ = hidden_states.size()
275
+
276
+ query_states = self.q_proj(hidden_states)
277
+ key_states = self.k_proj(hidden_states)
278
+ value_states = self.v_proj(hidden_states)
279
+
280
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
281
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
282
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
283
+
284
+ kv_seq_len = key_states.shape[-2]
285
+ if past_key_value is not None:
286
+ if self.layer_idx is None:
287
+ raise ValueError(
288
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
289
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
290
+ "with a layer index."
291
+ )
292
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
293
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
294
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
295
+
296
+ if past_key_value is not None:
297
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
298
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
299
+
300
+ # repeat k/v heads if n_kv_heads < n_heads
301
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
302
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
303
+
304
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
305
+
306
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
307
+ raise ValueError(
308
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
309
+ f" {attn_weights.size()}"
310
+ )
311
+
312
+ if attention_mask is not None:
313
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
314
+ raise ValueError(
315
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
316
+ )
317
+
318
+ attn_weights = attn_weights + attention_mask
319
+
320
+ # upcast attention to fp32
321
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
322
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
323
+ attn_output = torch.matmul(attn_weights, value_states)
324
+
325
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
326
+ raise ValueError(
327
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
328
+ f" {attn_output.size()}"
329
+ )
330
+
331
+ attn_output = attn_output.transpose(1, 2).contiguous()
332
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
333
+
334
+ attn_output = self.o_proj(attn_output)
335
+
336
+ if not output_attentions:
337
+ attn_weights = None
338
+
339
+ return attn_output, attn_weights, past_key_value
340
+
341
+
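The eager path above is standard scaled dot-product attention with grouped-query key/value heads expanded by `repeat_kv` and an additive causal mask before the fp32 softmax. A compact shape walk-through with toy sizes (not the model's real dimensions):

import math
import torch

bsz, heads, kv_heads, q_len, head_dim = 1, 8, 2, 4, 16
q = torch.randn(bsz, heads, q_len, head_dim)
k = repeat_kv(torch.randn(bsz, kv_heads, q_len, head_dim), heads // kv_heads)
v = repeat_kv(torch.randn(bsz, kv_heads, q_len, head_dim), heads // kv_heads)

scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)     # (1, 8, 4, 4)
causal = torch.triu(torch.full((q_len, q_len), float("-inf")), diagonal=1)
probs = torch.softmax(scores + causal, dim=-1, dtype=torch.float32).to(q.dtype)
out = torch.matmul(probs, v)                                          # (1, 8, 4, 16)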
342
+ class Qwen2FlashAttention2(Qwen2Attention):
343
+ """
344
+ Qwen2 flash attention module, following Qwen2 attention module. This module inherits from `Qwen2Attention`
345
+ as the weights of the module stays untouched. The only required change would be on the forward pass
346
+ where it needs to correctly call the public API of flash attention and deal with padding tokens
347
+ in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom
348
+ config.max_window_layers layers.
349
+ """
350
+
351
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
352
+ def __init__(self, *args, **kwargs):
353
+ super().__init__(*args, **kwargs)
354
+
355
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
356
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
357
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
358
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
359
+
360
+ def forward(
361
+ self,
362
+ hidden_states: torch.Tensor,
363
+ attention_mask: Optional[torch.Tensor] = None,
364
+ position_ids: Optional[torch.LongTensor] = None,
365
+ past_key_value: Optional[Cache] = None,
366
+ output_attentions: bool = False,
367
+ use_cache: bool = False,
368
+ **kwargs,
369
+ ):
370
+ if "padding_mask" in kwargs:
371
+ warnings.warn(
372
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
373
+ )
374
+
375
+ # overwrite attention_mask with padding_mask
376
+ attention_mask = kwargs.pop("padding_mask")
377
+ bsz, q_len, _ = hidden_states.size()
378
+
379
+ query_states = self.q_proj(hidden_states)
380
+ key_states = self.k_proj(hidden_states)
381
+ value_states = self.v_proj(hidden_states)
382
+
383
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
384
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
385
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
386
+
387
+ kv_seq_len = key_states.shape[-2]
388
+ if past_key_value is not None:
389
+ if self.layer_idx is None:
390
+ raise ValueError(
391
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
392
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
393
+ "with a layer index."
394
+ )
395
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
396
+
397
+ # Because the input can be padded, the absolute sequence length depends on the max position id.
398
+ rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
399
+ cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
400
+
401
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
402
+
403
+ use_sliding_windows = (
404
+ _flash_supports_window_size
405
+ and getattr(self.config, "sliding_window", None) is not None
406
+ and kv_seq_len > self.config.sliding_window
407
+ and self.config.use_sliding_window
408
+ )
409
+
410
+ if not _flash_supports_window_size:
411
+ logger.warning_once(
412
+ "The current flash attention version does not support sliding window attention, for a more memory efficient implementation"
413
+ " make sure to upgrade flash-attn library."
414
+ )
415
+
416
+ if past_key_value is not None:
417
+ # Activate slicing cache only if the config has a value `sliding_windows` attribute
418
+ cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
419
+ if (
420
+ getattr(self.config, "sliding_window", None) is not None
421
+ and kv_seq_len > self.config.sliding_window
422
+ and cache_has_contents
423
+ ):
424
+ slicing_tokens = 1 - self.config.sliding_window
425
+
426
+ past_key = past_key_value[self.layer_idx][0]
427
+ past_value = past_key_value[self.layer_idx][1]
428
+
429
+ past_key = past_key[:, :, slicing_tokens:, :].contiguous()
430
+ past_value = past_value[:, :, slicing_tokens:, :].contiguous()
431
+
432
+ if past_key.shape[-2] != self.config.sliding_window - 1:
433
+ raise ValueError(
434
+ f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
435
+ f" {past_key.shape}"
436
+ )
437
+
438
+ if attention_mask is not None:
439
+ attention_mask = attention_mask[:, slicing_tokens:]
440
+ attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
441
+
442
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
443
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
444
+
445
+ # repeat k/v heads if n_kv_heads < n_heads
446
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
447
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
448
+ dropout_rate = 0.0 if not self.training else self.attention_dropout
449
+
450
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
451
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
452
+ # cast them back in float16 just to be sure everything works as expected.
453
+ input_dtype = query_states.dtype
454
+ if input_dtype == torch.float32:
455
+ if torch.is_autocast_enabled():
456
+ target_dtype = torch.get_autocast_gpu_dtype()
457
+ # Handle the case where the model is quantized
458
+ elif hasattr(self.config, "_pre_quantization_dtype"):
459
+ target_dtype = self.config._pre_quantization_dtype
460
+ else:
461
+ target_dtype = self.q_proj.weight.dtype
462
+
463
+ logger.warning_once(
464
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
465
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
466
+ f" {target_dtype}."
467
+ )
468
+
469
+ query_states = query_states.to(target_dtype)
470
+ key_states = key_states.to(target_dtype)
471
+ value_states = value_states.to(target_dtype)
472
+
473
+ # Reshape to the expected shape for Flash Attention
474
+ query_states = query_states.transpose(1, 2)
475
+ key_states = key_states.transpose(1, 2)
476
+ value_states = value_states.transpose(1, 2)
477
+
478
+ attn_output = self._flash_attention_forward(
479
+ query_states,
480
+ key_states,
481
+ value_states,
482
+ attention_mask,
483
+ q_len,
484
+ dropout=dropout_rate,
485
+ use_sliding_windows=use_sliding_windows,
486
+ )
487
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
488
+ attn_output = self.o_proj(attn_output)
489
+
490
+ if not output_attentions:
491
+ attn_weights = None
492
+
493
+ return attn_output, attn_weights, past_key_value
494
+
495
+ def _flash_attention_forward(
496
+ self,
497
+ query_states,
498
+ key_states,
499
+ value_states,
500
+ attention_mask,
501
+ query_length,
502
+ dropout=0.0,
503
+ softmax_scale=None,
504
+ use_sliding_windows=False,
505
+ ):
506
+ """
507
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
508
+ first unpad the input, then computes the attention scores and pad the final attention scores.
509
+
510
+ Args:
511
+ query_states (`torch.Tensor`):
512
+ Input query states to be passed to Flash Attention API
513
+ key_states (`torch.Tensor`):
514
+ Input key states to be passed to Flash Attention API
515
+ value_states (`torch.Tensor`):
516
+ Input value states to be passed to Flash Attention API
517
+ attention_mask (`torch.Tensor`):
518
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
519
+ position of padding tokens and 1 for the position of non-padding tokens.
520
+ dropout (`int`, *optional*):
521
+ Attention dropout
522
+ softmax_scale (`float`, *optional*):
523
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
524
+ use_sliding_windows (`bool`, *optional*):
525
+ Whether to activate sliding window attention.
526
+ """
527
+ if not self._flash_attn_uses_top_left_mask:
528
+ causal = self.is_causal
529
+ else:
530
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
531
+ causal = self.is_causal and query_length != 1
532
+
533
+ # Decide whether to use SWA or not by layer index.
534
+ if use_sliding_windows and self.layer_idx >= self.config.max_window_layers:
535
+ use_sliding_windows = False
536
+
537
+ # Contains at least one padding token in the sequence
538
+ if attention_mask is not None:
539
+ batch_size = query_states.shape[0]
540
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
541
+ query_states, key_states, value_states, attention_mask, query_length
542
+ )
543
+
544
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
545
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
546
+
547
+ if not use_sliding_windows:
548
+ attn_output_unpad = flash_attn_varlen_func(
549
+ query_states,
550
+ key_states,
551
+ value_states,
552
+ cu_seqlens_q=cu_seqlens_q,
553
+ cu_seqlens_k=cu_seqlens_k,
554
+ max_seqlen_q=max_seqlen_in_batch_q,
555
+ max_seqlen_k=max_seqlen_in_batch_k,
556
+ dropout_p=dropout,
557
+ softmax_scale=softmax_scale,
558
+ causal=causal,
559
+ )
560
+ else:
561
+ attn_output_unpad = flash_attn_varlen_func(
562
+ query_states,
563
+ key_states,
564
+ value_states,
565
+ cu_seqlens_q=cu_seqlens_q,
566
+ cu_seqlens_k=cu_seqlens_k,
567
+ max_seqlen_q=max_seqlen_in_batch_q,
568
+ max_seqlen_k=max_seqlen_in_batch_k,
569
+ dropout_p=dropout,
570
+ softmax_scale=softmax_scale,
571
+ causal=causal,
572
+ window_size=(self.config.sliding_window, self.config.sliding_window),
573
+ )
574
+
575
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
576
+ else:
577
+ if not use_sliding_windows:
578
+ attn_output = flash_attn_func(
579
+ query_states,
580
+ key_states,
581
+ value_states,
582
+ dropout,
583
+ softmax_scale=softmax_scale,
584
+ causal=causal,
585
+ )
586
+ else:
587
+ attn_output = flash_attn_func(
588
+ query_states,
589
+ key_states,
590
+ value_states,
591
+ dropout,
592
+ softmax_scale=softmax_scale,
593
+ causal=causal,
594
+ window_size=(self.config.sliding_window, self.config.sliding_window),
595
+ )
596
+
597
+ return attn_output
598
+
599
+ # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._upad_input
600
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
601
+ batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
602
+
603
+ # On the first iteration we need to properly re-create the padding mask
604
+ # by slicing it on the proper place
605
+ if kv_seq_len != attention_mask.shape[-1]:
606
+ attention_mask_num_tokens = attention_mask.shape[-1]
607
+ attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]
608
+
609
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
610
+
611
+ key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
612
+ value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
613
+
614
+ if query_length == kv_seq_len:
615
+ query_layer = index_first_axis(
616
+ query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
617
+ )
618
+ cu_seqlens_q = cu_seqlens_k
619
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
620
+ indices_q = indices_k
621
+ elif query_length == 1:
622
+ max_seqlen_in_batch_q = 1
623
+ cu_seqlens_q = torch.arange(
624
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
625
+ ) # There is a memcpy here, that is very bad.
626
+ indices_q = cu_seqlens_q[:-1]
627
+ query_layer = query_layer.squeeze(1)
628
+ else:
629
+ # The -q_len: slice assumes left padding.
630
+ attention_mask = attention_mask[:, -query_length:]
631
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
632
+
633
+ return (
634
+ query_layer,
635
+ key_layer,
636
+ value_layer,
637
+ indices_q,
638
+ (cu_seqlens_q, cu_seqlens_k),
639
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
640
+ )
641
+
+
+ class Qwen2FlashAttention2_packing(Qwen2Attention):
642
+ """
643
+ Qwen2 flash attention module, following Qwen2 attention module. This module inherits from `Qwen2Attention`
644
+ as the weights of the module stays untouched. The only required change would be on the forward pass
645
+ where it needs to correctly call the public API of flash attention and deal with padding tokens
646
+ in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom
647
+ config.max_window_layers layers.
648
+ """
649
+
650
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
651
+ def __init__(self, *args, **kwargs):
652
+ super().__init__(*args, **kwargs)
653
+
654
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
655
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
656
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
657
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
658
+
659
+ def forward(
660
+ self,
661
+ hidden_states: torch.Tensor,
662
+ attention_mask: Optional[torch.Tensor] = None,
663
+ position_ids: Optional[torch.LongTensor] = None,
664
+ past_key_value: Optional[Cache] = None,
665
+ output_attentions: bool = False,
666
+ use_cache: bool = False,
667
+ sub_sample_lengths = None,
668
+ **kwargs,
669
+ ):
670
+ if "padding_mask" in kwargs:
671
+ warnings.warn(
672
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
673
+ )
674
+
675
+ # overwrite attention_mask with padding_mask
676
+ attention_mask = kwargs.pop("padding_mask")
677
+ bsz, q_len, _ = hidden_states.size()
678
+
679
+ query_states = self.q_proj(hidden_states)
680
+ key_states = self.k_proj(hidden_states)
681
+ value_states = self.v_proj(hidden_states)
682
+
683
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
684
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
685
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
686
+
687
+ kv_seq_len = key_states.shape[-2]
688
+ if past_key_value is not None:
689
+ if self.layer_idx is None:
690
+ raise ValueError(
691
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
692
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
693
+ "with a layer index."
694
+ )
695
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
696
+
697
+ # Because the input can be padded, the absolute sequence length depends on the max position id.
698
+ rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
699
+ cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
700
+
701
+ if sub_sample_lengths is not None:
702
+ packing_position_ids = []
703
+ for b in range(bsz):
704
+ each_sub_sample_lengths = sub_sample_lengths[b]
705
+ packing_position_ids.append(torch.cat([torch.arange(each) for each in each_sub_sample_lengths]))
706
+ packing_position_ids = torch.stack(packing_position_ids)
707
+ packing_position_ids = packing_position_ids.to(query_states.device)
708
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, packing_position_ids)
709
+ else:
710
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
711
+
712
+ use_sliding_windows = (
713
+ _flash_supports_window_size
714
+ and getattr(self.config, "sliding_window", None) is not None
715
+ and kv_seq_len > self.config.sliding_window
716
+ and self.config.use_sliding_window
717
+ )
718
+
719
+ if not _flash_supports_window_size:
720
+ logger.warning_once(
721
+ "The current flash attention version does not support sliding window attention, for a more memory efficient implementation"
722
+ " make sure to upgrade flash-attn library."
723
+ )
724
+
725
+ if past_key_value is not None:
726
+ # Activate slicing cache only if the config has a value `sliding_windows` attribute
727
+ cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
728
+ if (
729
+ getattr(self.config, "sliding_window", None) is not None
730
+ and kv_seq_len > self.config.sliding_window
731
+ and cache_has_contents
732
+ ):
733
+ slicing_tokens = 1 - self.config.sliding_window
734
+
735
+ past_key = past_key_value[self.layer_idx][0]
736
+ past_value = past_key_value[self.layer_idx][1]
737
+
738
+ past_key = past_key[:, :, slicing_tokens:, :].contiguous()
739
+ past_value = past_value[:, :, slicing_tokens:, :].contiguous()
740
+
741
+ if past_key.shape[-2] != self.config.sliding_window - 1:
742
+ raise ValueError(
743
+ f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
744
+ f" {past_key.shape}"
745
+ )
746
+
747
+ if attention_mask is not None:
748
+ attention_mask = attention_mask[:, slicing_tokens:]
749
+ attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
750
+
751
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
752
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
753
+
754
+ # repeat k/v heads if n_kv_heads < n_heads
755
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
756
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
757
+ dropout_rate = 0.0 if not self.training else self.attention_dropout
758
+
759
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
760
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
761
+ # cast them back in float16 just to be sure everything works as expected.
762
+ input_dtype = query_states.dtype
763
+ if input_dtype == torch.float32:
764
+ if torch.is_autocast_enabled():
765
+ target_dtype = torch.get_autocast_gpu_dtype()
766
+ # Handle the case where the model is quantized
767
+ elif hasattr(self.config, "_pre_quantization_dtype"):
768
+ target_dtype = self.config._pre_quantization_dtype
769
+ else:
770
+ target_dtype = self.q_proj.weight.dtype
771
+
772
+ logger.warning_once(
773
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
774
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
775
+ f" {target_dtype}."
776
+ )
777
+
778
+ query_states = query_states.to(target_dtype)
779
+ key_states = key_states.to(target_dtype)
780
+ value_states = value_states.to(target_dtype)
781
+
782
+ # Reshape to the expected shape for Flash Attention
783
+ query_states = query_states.transpose(1, 2)
784
+ key_states = key_states.transpose(1, 2)
785
+ value_states = value_states.transpose(1, 2)
786
+
787
+ attn_output = self._flash_attention_forward(
788
+ query_states,
789
+ key_states,
790
+ value_states,
791
+ attention_mask,
792
+ q_len,
793
+ dropout=dropout_rate,
794
+ use_sliding_windows=use_sliding_windows,
795
+ sub_sample_lengths=sub_sample_lengths
796
+ )
797
+
798
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
799
+ attn_output = self.o_proj(attn_output)
800
+
801
+ if not output_attentions:
802
+ attn_weights = None
803
+
804
+ return attn_output, attn_weights, past_key_value
805
+
806
+ def _flash_attention_forward(
807
+ self,
808
+ query_states,
809
+ key_states,
810
+ value_states,
811
+ attention_mask,
812
+ query_length,
813
+ dropout=0.0,
814
+ softmax_scale=None,
815
+ use_sliding_windows=False,
816
+ sub_sample_lengths=None,
817
+ ):
818
+ """
819
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
820
+ first unpad the input, then computes the attention scores and pad the final attention scores.
821
+
822
+ Args:
823
+ query_states (`torch.Tensor`):
824
+ Input query states to be passed to Flash Attention API
825
+ key_states (`torch.Tensor`):
826
+ Input key states to be passed to Flash Attention API
827
+ value_states (`torch.Tensor`):
828
+ Input value states to be passed to Flash Attention API
829
+ attention_mask (`torch.Tensor`):
830
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
831
+ position of padding tokens and 1 for the position of non-padding tokens.
832
+ dropout (`int`, *optional*):
833
+ Attention dropout
834
+ softmax_scale (`float`, *optional*):
835
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
836
+ use_sliding_windows (`bool`, *optional*):
837
+ Whether to activate sliding window attention.
838
+ """
839
+ if not self._flash_attn_uses_top_left_mask:
840
+ causal = self.is_causal
841
+ else:
842
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
843
+ causal = self.is_causal and query_length != 1
844
+
845
+ # Decide whether to use sliding window attention (SWA) based on the layer index.
846
+ if use_sliding_windows and self.layer_idx >= self.config.max_window_layers:
847
+ use_sliding_windows = False
848
+
849
+ # Contains at least one padding token in the sequence
850
+
851
+ if attention_mask is not None:
852
+ batch_size = query_states.shape[0]
853
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._unpad_input_packing(
854
+ query_states, key_states, value_states, attention_mask, query_length, sub_sample_lengths
855
+ )
856
+
857
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
858
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
859
+
860
+ if not use_sliding_windows:
861
+ attn_output_unpad = flash_attn_varlen_func(
862
+ query_states,
863
+ key_states,
864
+ value_states,
865
+ cu_seqlens_q=cu_seqlens_q,
866
+ cu_seqlens_k=cu_seqlens_k,
867
+ max_seqlen_q=max_seqlen_in_batch_q,
868
+ max_seqlen_k=max_seqlen_in_batch_k,
869
+ dropout_p=dropout,
870
+ softmax_scale=softmax_scale,
871
+ causal=causal,
872
+ )
873
+ else:
874
+ attn_output_unpad = flash_attn_varlen_func(
875
+ query_states,
876
+ key_states,
877
+ value_states,
878
+ cu_seqlens_q=cu_seqlens_q,
879
+ cu_seqlens_k=cu_seqlens_k,
880
+ max_seqlen_q=max_seqlen_in_batch_q,
881
+ max_seqlen_k=max_seqlen_in_batch_k,
882
+ dropout_p=dropout,
883
+ softmax_scale=softmax_scale,
884
+ causal=causal,
885
+ window_size=(self.config.sliding_window, self.config.sliding_window),
886
+ )
887
+
888
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
889
+ else:
890
+ if not use_sliding_windows:
891
+ attn_output = flash_attn_func(
892
+ query_states,
893
+ key_states,
894
+ value_states,
895
+ dropout,
896
+ softmax_scale=softmax_scale,
897
+ causal=causal,
898
+ )
899
+ else:
900
+ attn_output = flash_attn_func(
901
+ query_states,
902
+ key_states,
903
+ value_states,
904
+ dropout,
905
+ softmax_scale=softmax_scale,
906
+ causal=causal,
907
+ window_size=(self.config.sliding_window, self.config.sliding_window),
908
+ )
909
+
910
+ return attn_output
911
+
912
+ # Adapted from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._upad_input
913
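+ # Packing-aware variant of `_upad_input`: `sub_sample_lengths` describes the samples packed into each row,
+ # so the returned cu_seqlens split attention at packed sub-sample boundaries.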
+ def _unpad_input_packing(self, query_layer, key_layer, value_layer, attention_mask, query_length, sub_sample_lengths):
914
+ batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
915
+
916
+ # On the first iteration we need to properly re-create the padding mask
917
+ # by slicing it at the proper place
918
+ if kv_seq_len != attention_mask.shape[-1]:
919
+ attention_mask_num_tokens = attention_mask.shape[-1]
920
+ attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]
921
+
922
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data_packing(attention_mask, sub_sample_lengths)
923
+
924
+ key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
925
+ value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
926
+
927
+ if query_length == kv_seq_len:
928
+ query_layer = index_first_axis(
929
+ query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
930
+ )
931
+ cu_seqlens_q = cu_seqlens_k
932
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
933
+ indices_q = indices_k
934
+ elif query_length == 1:
935
+ max_seqlen_in_batch_q = 1
936
+ cu_seqlens_q = torch.arange(
937
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
938
+ ) # There is a memcpy here, that is very bad.
939
+ indices_q = cu_seqlens_q[:-1]
940
+ query_layer = query_layer.squeeze(1)
941
+ else:
942
+ # The -q_len: slice assumes left padding.
943
+ attention_mask = attention_mask[:, -query_length:]
944
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
945
+
946
+ return (
947
+ query_layer,
948
+ key_layer,
949
+ value_layer,
950
+ indices_q,
951
+ (cu_seqlens_q, cu_seqlens_k),
952
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
953
+ )
954
+
955
+
956
+ # Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Qwen2
957
+ class Qwen2SdpaAttention(Qwen2Attention):
958
+ """
959
+ Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
960
+ `Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
961
+ SDPA API.
962
+ """
963
+
964
+ # Adapted from Qwen2Attention.forward
965
+ def forward(
966
+ self,
967
+ hidden_states: torch.Tensor,
968
+ attention_mask: Optional[torch.Tensor] = None,
969
+ position_ids: Optional[torch.LongTensor] = None,
970
+ past_key_value: Optional[Cache] = None,
971
+ output_attentions: bool = False,
972
+ use_cache: bool = False,
973
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
974
+ if output_attentions:
975
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
976
+ logger.warning_once(
977
+ "Qwen2Model is using Qwen2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
978
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
979
+ )
980
+ return super().forward(
981
+ hidden_states=hidden_states,
982
+ attention_mask=attention_mask,
983
+ position_ids=position_ids,
984
+ past_key_value=past_key_value,
985
+ output_attentions=output_attentions,
986
+ use_cache=use_cache,
987
+ )
988
+
989
+ bsz, q_len, _ = hidden_states.size()
990
+
991
+ query_states = self.q_proj(hidden_states)
992
+ key_states = self.k_proj(hidden_states)
993
+ value_states = self.v_proj(hidden_states)
994
+
995
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
996
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
997
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
998
+
999
+ kv_seq_len = key_states.shape[-2]
1000
+ if past_key_value is not None:
1001
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
1002
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
1003
+
1004
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
1005
+
1006
+ if past_key_value is not None:
1007
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
1008
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
1009
+
1010
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
1011
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
1012
+
1013
+ if attention_mask is not None:
1014
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
1015
+ raise ValueError(
1016
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
1017
+ )
1018
+
1019
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
1020
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
1021
+ if query_states.device.type == "cuda" and attention_mask is not None:
1022
+ query_states = query_states.contiguous()
1023
+ key_states = key_states.contiguous()
1024
+ value_states = value_states.contiguous()
1025
+
1026
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
1027
+ query_states,
1028
+ key_states,
1029
+ value_states,
1030
+ attn_mask=attention_mask,
1031
+ dropout_p=self.attention_dropout if self.training else 0.0,
1032
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
1033
+ is_causal=self.is_causal and attention_mask is None and q_len > 1,
1034
+ )
1035
+
1036
+ attn_output = attn_output.transpose(1, 2).contiguous()
1037
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
1038
+
1039
+ attn_output = self.o_proj(attn_output)
1040
+
1041
+ return attn_output, None, past_key_value
1042
+
1043
+
1044
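+ # Maps `config.attn_implementation` to its attention class; `flash_attention_2_packing`
+ # adds support for packed (concatenated) training samples.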
+ QWEN2_ATTENTION_CLASSES = {
1045
+ "eager": Qwen2Attention,
1046
+ "flash_attention_2": Qwen2FlashAttention2,
1047
+ "sdpa": Qwen2SdpaAttention,
1048
+ "flash_attention_2_packing": Qwen2FlashAttention2_packing,
1049
+ }
1050
+
1051
+
1052
+ class Qwen2DecoderLayer(nn.Module):
1053
+ def __init__(self, config: Qwen2Config, layer_idx: int):
1054
+ super().__init__()
1055
+ self.hidden_size = config.hidden_size
1056
+
1057
+ if config.use_sliding_window and config.attn_implementation != "flash_attention_2":
1058
+ logger.warning_once(
1059
+ f"Sliding Window Attention is enabled but not implemented for `{config.attn_implementation}`; "
1060
+ "unexpected results may be encountered."
1061
+ )
1062
+
1063
+ self.self_attn = QWEN2_ATTENTION_CLASSES[config.attn_implementation](config, layer_idx)
1064
+
1065
+ self.mlp = Qwen2MLP(config)
1066
+ self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1067
+ self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1068
+
1069
+ def forward(
1070
+ self,
1071
+ hidden_states: torch.Tensor,
1072
+ attention_mask: Optional[torch.Tensor] = None,
1073
+ position_ids: Optional[torch.LongTensor] = None,
1074
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
1075
+ sub_sample_lengths=None,
1076
+ output_attentions: Optional[bool] = False,
1077
+ use_cache: Optional[bool] = False,
1078
+ **kwargs,
1079
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
1080
+ if "padding_mask" in kwargs:
1081
+ warnings.warn(
1082
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. "
1083
+ "Please make sure to use `attention_mask` instead."
1084
+ )
1085
+ """
1086
+ Args:
1087
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
1088
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
1089
+ `(batch, sequence_length)` where padding elements are indicated by 0.
1090
+ output_attentions (`bool`, *optional*):
1091
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1092
+ returned tensors for more detail.
1093
+ use_cache (`bool`, *optional*):
1094
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
1095
+ (see `past_key_values`).
1096
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
1097
+ """
1098
+
1099
+ residual = hidden_states
1100
+
1101
+ hidden_states = self.input_layernorm(hidden_states)
1102
+
1103
+ # Self Attention
1104
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
1105
+ hidden_states=hidden_states,
1106
+ attention_mask=attention_mask,
1107
+ position_ids=position_ids,
1108
+ past_key_value=past_key_value,
1109
+ output_attentions=output_attentions,
1110
+ use_cache=use_cache,
1111
+ sub_sample_lengths=sub_sample_lengths,
1112
+ )
1113
+ hidden_states = residual + hidden_states
1114
+
1115
+ # Fully Connected
1116
+ residual = hidden_states
1117
+ hidden_states = self.post_attention_layernorm(hidden_states)
1118
+ hidden_states = self.mlp(hidden_states)
1119
+ hidden_states = residual + hidden_states
1120
+
1121
+ outputs = (hidden_states,)
1122
+
1123
+ if output_attentions:
1124
+ outputs += (self_attn_weights,)
1125
+
1126
+ if use_cache:
1127
+ outputs += (present_key_value,)
1128
+
1129
+ return outputs
1130
+
1131
+
1132
+ QWEN2_START_DOCSTRING = r"""
1133
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1134
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
1135
+ etc.)
1136
+
1137
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
1138
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
1139
+ and behavior.
1140
+
1141
+ Parameters:
1142
+ config ([`Qwen2Config`]):
1143
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
1144
+ load the weights associated with the model, only the configuration. Check out the
1145
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
1146
+ """
1147
+
1148
+
1149
+ @add_start_docstrings(
1150
+ "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.",
1151
+ QWEN2_START_DOCSTRING,
1152
+ )
1153
+ class Qwen2PreTrainedModel(PreTrainedModel):
1154
+ config_class = Qwen2Config
1155
+ base_model_prefix = "model"
1156
+ supports_gradient_checkpointing = True
1157
+ _no_split_modules = ["Qwen2DecoderLayer"]
1158
+ _skip_keys_device_placement = "past_key_values"
1159
+ _supports_flash_attn_2 = True
1160
+ _supports_sdpa = True
1161
+ _supports_cache_class = True
1162
+
1163
+ def _init_weights(self, module):
1164
+ std = self.config.initializer_range
1165
+ if isinstance(module, nn.Linear):
1166
+ module.weight.data.normal_(mean=0.0, std=std)
1167
+ if module.bias is not None:
1168
+ module.bias.data.zero_()
1169
+ elif isinstance(module, nn.Embedding):
1170
+ module.weight.data.normal_(mean=0.0, std=std)
1171
+ if module.padding_idx is not None:
1172
+ module.weight.data[module.padding_idx].zero_()
1173
+
1174
+
1175
+ QWEN2_INPUTS_DOCSTRING = r"""
1176
+ Args:
1177
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1178
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
1179
+ it.
1180
+
1181
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1182
+ [`PreTrainedTokenizer.__call__`] for details.
1183
+
1184
+ [What are input IDs?](../glossary#input-ids)
1185
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1186
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1187
+
1188
+ - 1 for tokens that are **not masked**,
1189
+ - 0 for tokens that are **masked**.
1190
+
1191
+ [What are attention masks?](../glossary#attention-mask)
1192
+
1193
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1194
+ [`PreTrainedTokenizer.__call__`] for details.
1195
+
1196
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
1197
+ `past_key_values`).
1198
+
1199
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
1200
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
1201
+ information on the default strategy.
1202
+
1203
+ - 1 indicates the head is **not masked**,
1204
+ - 0 indicates the head is **masked**.
1205
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1206
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
1207
+ config.n_positions - 1]`.
1208
+
1209
+ [What are position IDs?](../glossary#position-ids)
1210
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
1211
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
1212
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
1213
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
1214
+
1215
+ Two formats are allowed:
1216
+ - a [`~cache_utils.Cache`] instance;
1217
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
1218
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
1219
+ cache format.
1220
+
1221
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
1222
+ legacy cache format will be returned.
1223
+
1224
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
1225
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
1226
+ of shape `(batch_size, sequence_length)`.
1227
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1228
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
1229
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
1230
+ model's internal embedding lookup matrix.
1231
+ use_cache (`bool`, *optional*):
1232
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1233
+ `past_key_values`).
1234
+ output_attentions (`bool`, *optional*):
1235
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1236
+ tensors for more detail.
1237
+ output_hidden_states (`bool`, *optional*):
1238
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1239
+ more detail.
1240
+ return_dict (`bool`, *optional*):
1241
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1242
+ """
1243
+
1244
+
1245
+ @add_start_docstrings(
1246
+ "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.",
1247
+ QWEN2_START_DOCSTRING,
1248
+ )
1249
+ class Qwen2Model(Qwen2PreTrainedModel):
1250
+ """
1251
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`]
1252
+
1253
+ Args:
1254
+ config: Qwen2Config
1255
+ """
1256
+
1257
+ def __init__(self, config: Qwen2Config):
1258
+ super().__init__(config)
1259
+ self.padding_idx = config.pad_token_id
1260
+ self.vocab_size = config.vocab_size
1261
+
1262
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1263
+ self.layers = nn.ModuleList(
1264
+ [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
1265
+ )
1266
+ self.attn_implementation = config.attn_implementation
1267
+ self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1268
+
1269
+ self.gradient_checkpointing = False
1270
+ # Initialize weights and apply final processing
1271
+ self.post_init()
1272
+
1273
+ def get_input_embeddings(self):
1274
+ return self.embed_tokens
1275
+
1276
+ def set_input_embeddings(self, value):
1277
+ self.embed_tokens = value
1278
+
1279
+ @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
1280
+ def forward(
1281
+ self,
1282
+ input_ids: torch.LongTensor = None,
1283
+ attention_mask: Optional[torch.Tensor] = None,
1284
+ position_ids: Optional[torch.LongTensor] = None,
1285
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1286
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1287
+ use_cache: Optional[bool] = None,
1288
+ output_attentions: Optional[bool] = None,
1289
+ output_hidden_states: Optional[bool] = None,
1290
+ return_dict: Optional[bool] = None,
1291
+ sub_sample_lengths=None,
1292
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
1293
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1294
+ output_hidden_states = (
1295
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1296
+ )
1297
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1298
+
1299
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1300
+
1301
+ # retrieve input_ids and inputs_embeds
1302
+ if input_ids is not None and inputs_embeds is not None:
1303
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
1304
+ elif input_ids is not None:
1305
+ batch_size, seq_length = input_ids.shape
1306
+ elif inputs_embeds is not None:
1307
+ batch_size, seq_length, _ = inputs_embeds.shape
1308
+ else:
1309
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
1310
+
1311
+ if self.gradient_checkpointing and self.training:
1312
+ if use_cache:
1313
+ logger.warning_once(
1314
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1315
+ )
1316
+ use_cache = False
1317
+
1318
+ past_key_values_length = 0
1319
+
1320
+ if use_cache:
1321
+ use_legacy_cache = not isinstance(past_key_values, Cache)
1322
+ if use_legacy_cache:
1323
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1324
+ past_key_values_length = past_key_values.get_usable_length(seq_length)
1325
+
1326
+ if position_ids is None:
1327
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1328
+ position_ids = torch.arange(
1329
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
1330
+ )
1331
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
1332
+ else:
1333
+ position_ids = position_ids.view(-1, seq_length).long()
1334
+
1335
+ if inputs_embeds is None:
1336
+ inputs_embeds = self.embed_tokens(input_ids)
1337
+
1338
+ if attention_mask is not None and self.attn_implementation == "flash_attention_2" and use_cache:
1339
+ is_padding_right = attention_mask[:, -1].sum().item() != batch_size
1340
+ if is_padding_right:
1341
+ raise ValueError(
1342
+ "You are attempting to perform batched generation with padding_side='right'"
1343
+ ", which may lead to unexpected behaviour with the Flash Attention version of Qwen2. Make sure to"
1344
+ " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
1345
+ )
1346
+
1347
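+ # Mask handling depends on the attention backend: the flash-attention paths take the 2D
+ # (or packed integer) mask as-is, while the SDPA and eager paths expand it to a 4D causal mask below.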
+ if self.attn_implementation == "flash_attention_2" or self.config.attn_implementation == "flash_attention_2_packing":
1348
+ # 2d mask is passed through the layers
1349
+ if attention_mask is not None:
1350
+ if attention_mask.dtype == torch.long:
1351
+ # Long-dtype (integer) masks are passed through unchanged for the packing attention path.
1352
+ pass
1353
+ else:
1354
+ attention_mask = attention_mask if (0 in attention_mask) else None
1355
+
1356
+ elif self.attn_implementation == "sdpa" and not output_attentions:
1357
+ # output_attentions=True can not be supported when using SDPA, and we fall back on
1358
+ # the manual implementation that requires a 4D causal mask in all cases.
1359
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
1360
+ attention_mask,
1361
+ (batch_size, seq_length),
1362
+ inputs_embeds,
1363
+ past_key_values_length,
1364
+ )
1365
+ else:
1366
+ # 4d mask is passed through the layers
1367
+ attention_mask = _prepare_4d_causal_attention_mask(
1368
+ attention_mask,
1369
+ (batch_size, seq_length),
1370
+ inputs_embeds,
1371
+ past_key_values_length,
1372
+ sliding_window=self.config.sliding_window,
1373
+ )
1374
+
1375
+ hidden_states = inputs_embeds
1376
+
1377
+ # decoder layers
1378
+ all_hidden_states = () if output_hidden_states else None
1379
+ all_self_attns = () if output_attentions else None
1380
+ next_decoder_cache = None
1381
+
1382
+ for decoder_layer in self.layers:
1383
+ if output_hidden_states:
1384
+ all_hidden_states += (hidden_states,)
1385
+ if self.gradient_checkpointing and self.training:
1386
+ layer_outputs = self._gradient_checkpointing_func(
1387
+ decoder_layer.__call__,
1388
+ hidden_states,
1389
+ attention_mask,
1390
+ position_ids,
1391
+ past_key_values,
1392
+ sub_sample_lengths,
1393
+ output_attentions,
1394
+ use_cache,
1395
+ )
1396
+ else:
1397
+ layer_outputs = decoder_layer(
1398
+ hidden_states,
1399
+ attention_mask=attention_mask,
1400
+ position_ids=position_ids,
1401
+ past_key_value=past_key_values,
1402
+ sub_sample_lengths=sub_sample_lengths,
1403
+ output_attentions=output_attentions,
1404
+ use_cache=use_cache,
1405
+ )
1406
+
1407
+ hidden_states = layer_outputs[0]
1408
+
1409
+ if use_cache:
1410
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1411
+
1412
+ if output_attentions:
1413
+ all_self_attns += (layer_outputs[1],)
1414
+
1415
+ hidden_states = self.norm(hidden_states)
1416
+
1417
+ # add hidden states from the last decoder layer
1418
+ if output_hidden_states:
1419
+ all_hidden_states += (hidden_states,)
1420
+
1421
+ next_cache = None
1422
+ if use_cache:
1423
+ next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
1424
+
1425
+ if not return_dict:
1426
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1427
+ return BaseModelOutputWithPast(
1428
+ last_hidden_state=hidden_states,
1429
+ past_key_values=next_cache,
1430
+ hidden_states=all_hidden_states,
1431
+ attentions=all_self_attns,
1432
+ )
1433
+
1434
+
1435
+ class Qwen2ForCausalLM(Qwen2PreTrainedModel):
1436
+ _tied_weights_keys = ["lm_head.weight"]
1437
+
1438
+ def __init__(self, config):
1439
+ super().__init__(config)
1440
+ self.model = Qwen2Model(config)
1441
+ self.vocab_size = config.vocab_size
1442
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1443
+
1444
+ # Initialize weights and apply final processing
1445
+ self.post_init()
1446
+ self.support_packing = True
1447
+
1448
+ def get_input_embeddings(self):
1449
+ return self.model.embed_tokens
1450
+
1451
+ def set_input_embeddings(self, value):
1452
+ self.model.embed_tokens = value
1453
+
1454
+ def get_output_embeddings(self):
1455
+ return self.lm_head
1456
+
1457
+ def set_output_embeddings(self, new_embeddings):
1458
+ self.lm_head = new_embeddings
1459
+
1460
+ def set_decoder(self, decoder):
1461
+ self.model = decoder
1462
+
1463
+ def get_decoder(self):
1464
+ return self.model
1465
+
1466
+ @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
1467
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1468
+ def forward(
1469
+ self,
1470
+ input_ids: torch.LongTensor = None,
1471
+ attention_mask: Optional[torch.Tensor] = None,
1472
+ position_ids: Optional[torch.LongTensor] = None,
1473
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1474
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1475
+ labels: Optional[torch.LongTensor] = None,
1476
+ use_cache: Optional[bool] = None,
1477
+ output_attentions: Optional[bool] = None,
1478
+ output_hidden_states: Optional[bool] = None,
1479
+ return_dict: Optional[bool] = None,
1480
+ sub_sample_lengths=None,
1481
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1482
+ r"""
1483
+ Args:
1484
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1485
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1486
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1487
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1488
+
1489
+ Returns:
1490
+
1491
+ Example:
1492
+
1493
+ ```python
1494
+ >>> from transformers import AutoTokenizer, Qwen2ForCausalLM
1495
+
1496
+ >>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
1497
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
1498
+
1499
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1500
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1501
+
1502
+ >>> # Generate
1503
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1504
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1505
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1506
+ ```"""
1507
+
1508
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1509
+ output_hidden_states = (
1510
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1511
+ )
1512
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1513
+
1514
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1515
+ outputs = self.model(
1516
+ input_ids=input_ids,
1517
+ attention_mask=attention_mask,
1518
+ position_ids=position_ids,
1519
+ past_key_values=past_key_values,
1520
+ inputs_embeds=inputs_embeds,
1521
+ use_cache=use_cache,
1522
+ output_attentions=output_attentions,
1523
+ output_hidden_states=output_hidden_states,
1524
+ return_dict=return_dict,
1525
+ sub_sample_lengths=sub_sample_lengths
1526
+ )
1527
+
1528
+ hidden_states = outputs[0]
1529
+ logits = self.lm_head(hidden_states)
1530
+ logits = logits.float()
1531
+
1532
+ loss = None
1533
+ if labels is not None:
1534
+ # Shift so that tokens < n predict n
1535
+ shift_logits = logits[..., :-1, :].contiguous()
1536
+ shift_labels = labels[..., 1:].contiguous()
1537
+ # Flatten the tokens
1538
+ loss_fct = CrossEntropyLoss()
1539
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1540
+ shift_labels = shift_labels.view(-1)
1541
+ # Enable model parallelism
1542
+ shift_labels = shift_labels.to(shift_logits.device)
1543
+ loss = loss_fct(shift_logits, shift_labels)
1544
+
1545
+ if not return_dict:
1546
+ output = (logits,) + outputs[1:]
1547
+ return (loss,) + output if loss is not None else output
1548
+
1549
+ return CausalLMOutputWithPast(
1550
+ loss=loss,
1551
+ logits=logits,
1552
+ past_key_values=outputs.past_key_values,
1553
+ hidden_states=outputs.hidden_states,
1554
+ attentions=outputs.attentions,
1555
+ )
1556
+
1557
+ def prepare_inputs_for_generation(
1558
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1559
+ ):
1560
+ # Omit tokens covered by past_key_values
1561
+ if past_key_values is not None:
1562
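+ # `Cache` objects expose their lengths directly; legacy tuple caches infer them from the key tensor shape.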
+ if isinstance(past_key_values, Cache):
1563
+ cache_length = past_key_values.get_seq_length()
1564
+ past_length = past_key_values.seen_tokens
1565
+ max_cache_length = past_key_values.get_max_length()
1566
+ else:
1567
+ cache_length = past_length = past_key_values[0][0].shape[2]
1568
+ max_cache_length = None
1569
+
1570
+ # Keep only the unprocessed tokens:
1571
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1572
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1573
+ # input)
1574
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1575
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1576
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1577
+ # input_ids based on the past_length.
1578
+ elif past_length < input_ids.shape[1]:
1579
+ input_ids = input_ids[:, past_length:]
1580
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1581
+
1582
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1583
+ if (
1584
+ max_cache_length is not None
1585
+ and attention_mask is not None
1586
+ and cache_length + input_ids.shape[1] > max_cache_length
1587
+ ):
1588
+ attention_mask = attention_mask[:, -max_cache_length:]
1589
+
1590
+ position_ids = kwargs.get("position_ids", None)
1591
+ if attention_mask is not None and position_ids is None:
1592
+ # create position_ids on the fly for batch generation
1593
+ position_ids = attention_mask.long().cumsum(-1) - 1
1594
+ position_ids.masked_fill_(attention_mask == 0, 1)
1595
+ if past_key_values:
1596
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1597
+
1598
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1599
+ if inputs_embeds is not None and past_key_values is None:
1600
+ model_inputs = {"inputs_embeds": inputs_embeds}
1601
+ else:
1602
+ model_inputs = {"input_ids": input_ids}
1603
+
1604
+ model_inputs.update(
1605
+ {
1606
+ "position_ids": position_ids,
1607
+ "past_key_values": past_key_values,
1608
+ "use_cache": kwargs.get("use_cache"),
1609
+ "attention_mask": attention_mask,
1610
+ }
1611
+ )
1612
+ return model_inputs
1613
+
1614
+ @staticmethod
1615
+ def _reorder_cache(past_key_values, beam_idx):
1616
+ reordered_past = ()
1617
+ for layer_past in past_key_values:
1618
+ reordered_past += (
1619
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1620
+ )
1621
+ return reordered_past
1622
+
1623
+
1624
+ @add_start_docstrings(
1625
+ """
1626
+ The Qwen2 Model transformer with a sequence classification head on top (linear layer).
1627
+
1628
+ [`Qwen2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1629
+ (e.g. GPT-2) do.
1630
+
1631
+ Since it does classification on the last token, it requires to know the position of the last token. If a
1632
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1633
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1634
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1635
+ each row of the batch).
1636
+ """,
1637
+ QWEN2_START_DOCSTRING,
1638
+ )
1639
+ class Qwen2ForSequenceClassification(Qwen2PreTrainedModel):
1640
+ def __init__(self, config):
1641
+ super().__init__(config)
1642
+ self.num_labels = config.num_labels
1643
+ self.model = Qwen2Model(config)
1644
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1645
+
1646
+ # Initialize weights and apply final processing
1647
+ self.post_init()
1648
+
1649
+ def get_input_embeddings(self):
1650
+ return self.model.embed_tokens
1651
+
1652
+ def set_input_embeddings(self, value):
1653
+ self.model.embed_tokens = value
1654
+
1655
+ @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
1656
+ def forward(
1657
+ self,
1658
+ input_ids: torch.LongTensor = None,
1659
+ attention_mask: Optional[torch.Tensor] = None,
1660
+ position_ids: Optional[torch.LongTensor] = None,
1661
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1662
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1663
+ labels: Optional[torch.LongTensor] = None,
1664
+ use_cache: Optional[bool] = None,
1665
+ output_attentions: Optional[bool] = None,
1666
+ output_hidden_states: Optional[bool] = None,
1667
+ return_dict: Optional[bool] = None,
1668
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1669
+ r"""
1670
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1671
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1672
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1673
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1674
+ """
1675
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1676
+
1677
+ transformer_outputs = self.model(
1678
+ input_ids,
1679
+ attention_mask=attention_mask,
1680
+ position_ids=position_ids,
1681
+ past_key_values=past_key_values,
1682
+ inputs_embeds=inputs_embeds,
1683
+ use_cache=use_cache,
1684
+ output_attentions=output_attentions,
1685
+ output_hidden_states=output_hidden_states,
1686
+ return_dict=return_dict,
1687
+ )
1688
+ hidden_states = transformer_outputs[0]
1689
+ logits = self.score(hidden_states)
1690
+
1691
+ if input_ids is not None:
1692
+ batch_size = input_ids.shape[0]
1693
+ else:
1694
+ batch_size = inputs_embeds.shape[0]
1695
+
1696
+ if self.config.pad_token_id is None and batch_size != 1:
1697
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1698
+ if self.config.pad_token_id is None:
1699
+ sequence_lengths = -1
1700
+ else:
1701
+ if input_ids is not None:
1702
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1703
+ sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1704
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
1705
+ sequence_lengths = sequence_lengths.to(logits.device)
1706
+ else:
1707
+ sequence_lengths = -1
1708
+
1709
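+ # Pool by taking the logits at the last non-padding token of each sequence.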
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1710
+
1711
+ loss = None
1712
+ if labels is not None:
1713
+ labels = labels.to(logits.device)
1714
+ if self.config.problem_type is None:
1715
+ if self.num_labels == 1:
1716
+ self.config.problem_type = "regression"
1717
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1718
+ self.config.problem_type = "single_label_classification"
1719
+ else:
1720
+ self.config.problem_type = "multi_label_classification"
1721
+
1722
+ if self.config.problem_type == "regression":
1723
+ loss_fct = MSELoss()
1724
+ if self.num_labels == 1:
1725
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1726
+ else:
1727
+ loss = loss_fct(pooled_logits, labels)
1728
+ elif self.config.problem_type == "single_label_classification":
1729
+ loss_fct = CrossEntropyLoss()
1730
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1731
+ elif self.config.problem_type == "multi_label_classification":
1732
+ loss_fct = BCEWithLogitsLoss()
1733
+ loss = loss_fct(pooled_logits, labels)
1734
+ if not return_dict:
1735
+ output = (pooled_logits,) + transformer_outputs[1:]
1736
+ return ((loss,) + output) if loss is not None else output
1737
+
1738
+ return SequenceClassifierOutputWithPast(
1739
+ loss=loss,
1740
+ logits=pooled_logits,
1741
+ past_key_values=transformer_outputs.past_key_values,
1742
+ hidden_states=transformer_outputs.hidden_states,
1743
+ attentions=transformer_outputs.attentions,
1744
+ )
modeling_siglip.py ADDED
@@ -0,0 +1,1241 @@
1
+ # --------------------------------------------------------
2
+ # Eagle2
3
+ # Copyright (c) 2025 NVIDIA
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Support flash-attention in SigLIP
6
+ # --------------------------------------------------------
7
+
8
+
9
+ # coding=utf-8
10
+ # Copyright 2024 Google AI and The HuggingFace Team. All rights reserved.
11
+ #
12
+ # Licensed under the Apache License, Version 2.0 (the "License");
13
+ # you may not use this file except in compliance with the License.
14
+ # You may obtain a copy of the License at
15
+ #
16
+ # http://www.apache.org/licenses/LICENSE-2.0
17
+ #
18
+ # Unless required by applicable law or agreed to in writing, software
19
+ # distributed under the License is distributed on an "AS IS" BASIS,
20
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21
+ # See the License for the specific language governing permissions and
22
+ # limitations under the License.
23
+ """ PyTorch Siglip model."""
24
+
25
+
26
+ import math
27
+ import warnings
28
+ from dataclasses import dataclass
29
+ from typing import Any, Optional, Tuple, Union
30
+ from einops import rearrange
31
+ import numpy as np
32
+ import torch
33
+ import torch.utils.checkpoint
34
+ from torch import nn
35
+ from torch.nn.init import _calculate_fan_in_and_fan_out
36
+
37
+ from transformers.activations import ACT2FN
38
+ from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
39
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
40
+ from transformers.modeling_utils import PreTrainedModel
41
+ from transformers.utils import (
42
+ ModelOutput,
43
+ add_start_docstrings,
44
+ add_start_docstrings_to_model_forward,
45
+ logging,
46
+ replace_return_docstrings,
47
+ )
48
+ from .configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig
49
+
50
+ try:
51
+ from .flash_attention import FlashAttention
52
+ has_flash_attn = True
53
+ except Exception:
54
+ print('FlashAttention is not installed.')
55
+ has_flash_attn = False
56
+
57
+ logger = logging.get_logger(__name__)
58
+
59
+ _CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224"
60
+
61
+ SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [
62
+ "google/siglip-base-patch16-224",
63
+ # See all SigLIP models at https://huggingface.co/models?filter=siglip
64
+ ]
65
+
66
+
67
+ def _trunc_normal_(tensor, mean, std, a, b):
68
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
69
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
70
+ def norm_cdf(x):
71
+ # Computes standard normal cumulative distribution function
72
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
73
+
74
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
75
+ warnings.warn(
76
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
77
+ "The distribution of values may be incorrect.",
78
+ stacklevel=2,
79
+ )
80
+
81
+ # Values are generated by using a truncated uniform distribution and
82
+ # then using the inverse CDF for the normal distribution.
83
+ # Get upper and lower cdf values
84
+ l = norm_cdf((a - mean) / std)
85
+ u = norm_cdf((b - mean) / std)
86
+
87
+ # Uniformly fill tensor with values from [l, u], then translate to
88
+ # [2l-1, 2u-1].
89
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
90
+
91
+ # Use inverse cdf transform for normal distribution to get truncated
92
+ # standard normal
93
+ tensor.erfinv_()
94
+
95
+ # Transform to proper mean, std
96
+ tensor.mul_(std * math.sqrt(2.0))
97
+ tensor.add_(mean)
98
+
99
+ # Clamp to ensure it's in the proper range
100
+ tensor.clamp_(min=a, max=b)
101
+
102
+
103
+ def trunc_normal_tf_(
104
+ tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0
105
+ ) -> torch.Tensor:
106
+ """Fills the input Tensor with values drawn from a truncated
107
+ normal distribution. The values are effectively drawn from the
108
+ normal distribution :math:`\\mathcal{N}(\\text{mean}, \\text{std}^2)`
109
+ with values outside :math:`[a, b]` redrawn until they are within
110
+ the bounds. The method used for generating the random values works
111
+ best when :math:`a \\leq \\text{mean} \\leq b`.
112
+
113
+ NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
114
+ bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
115
+ and the result is subsequently scaled and shifted by the mean and std args.
116
+
117
+ Args:
118
+ tensor: an n-dimensional `torch.Tensor`
119
+ mean: the mean of the normal distribution
120
+ std: the standard deviation of the normal distribution
121
+ a: the minimum cutoff value
122
+ b: the maximum cutoff value
123
+ """
124
+ with torch.no_grad():
125
+ _trunc_normal_(tensor, 0, 1.0, a, b)
126
+ tensor.mul_(std).add_(mean)
127
+
128
+
129
+ def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
130
+ fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
131
+ if mode == "fan_in":
132
+ denom = fan_in
133
+ elif mode == "fan_out":
134
+ denom = fan_out
135
+ elif mode == "fan_avg":
136
+ denom = (fan_in + fan_out) / 2
137
+
138
+ variance = scale / denom
139
+
140
+ if distribution == "truncated_normal":
141
+ # constant is stddev of standard normal truncated to (-2, 2)
142
+ trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
143
+ elif distribution == "normal":
144
+ with torch.no_grad():
145
+ tensor.normal_(std=math.sqrt(variance))
146
+ elif distribution == "uniform":
147
+ bound = math.sqrt(3 * variance)
148
+ with torch.no_grad():
149
+ tensor.uniform_(-bound, bound)
150
+ else:
151
+ raise ValueError(f"invalid distribution {distribution}")
152
+
153
+
154
+ def lecun_normal_(tensor):
155
+ variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
156
+
157
+
158
+ def default_flax_embed_init(tensor):
159
+ variance_scaling_(tensor, mode="fan_in", distribution="normal")
160
+
161
+
162
+ @dataclass
163
+ # Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Siglip
164
+ class SiglipVisionModelOutput(ModelOutput):
165
+ """
166
+ Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
167
+
168
+ Args:
169
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
170
+ The image embeddings obtained by applying the projection layer to the pooler_output.
171
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
172
+ Sequence of hidden-states at the output of the last layer of the model.
173
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
174
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
175
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
176
+
177
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
178
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
179
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
180
+ sequence_length)`.
181
+
182
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
183
+ heads.
184
+ """
185
+
186
+ image_embeds: Optional[torch.FloatTensor] = None
187
+ last_hidden_state: torch.FloatTensor = None
188
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
189
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
190
+
191
+
192
+ @dataclass
193
+ # Copied from transformers.models.clip.modeling_clip.CLIPTextModelOutput with CLIP->Siglip
194
+ class SiglipTextModelOutput(ModelOutput):
195
+ """
196
+ Base class for text model's outputs that also contains a pooling of the last hidden states.
197
+
198
+ Args:
199
+ text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
200
+ The text embeddings obtained by applying the projection layer to the pooler_output.
201
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
202
+ Sequence of hidden-states at the output of the last layer of the model.
203
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
204
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
205
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
206
+
207
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
208
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
209
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
210
+ sequence_length)`.
211
+
212
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
213
+ heads.
214
+ """
215
+
216
+ text_embeds: Optional[torch.FloatTensor] = None
217
+ last_hidden_state: torch.FloatTensor = None
218
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
219
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
220
+
221
+
222
+ @dataclass
223
+ # Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->Siglip
224
+ class SiglipOutput(ModelOutput):
225
+ """
226
+ Args:
227
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
228
+ Contrastive loss for image-text similarity.
229
+ logits_per_image:(`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
230
+ The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
231
+ similarity scores.
232
+ logits_per_text:(`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
233
+ The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
234
+ similarity scores.
235
+ text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
236
+ The text embeddings obtained by applying the projection layer to the pooled output of [`SiglipTextModel`].
237
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
238
+ The image embeddings obtained by applying the projection layer to the pooled output of [`SiglipVisionModel`].
239
+ text_model_output(`BaseModelOutputWithPooling`):
240
+ The output of the [`SiglipTextModel`].
241
+ vision_model_output(`BaseModelOutputWithPooling`):
242
+ The output of the [`SiglipVisionModel`].
243
+ """
244
+
245
+ loss: Optional[torch.FloatTensor] = None
246
+ logits_per_image: torch.FloatTensor = None
247
+ logits_per_text: torch.FloatTensor = None
248
+ text_embeds: torch.FloatTensor = None
249
+ image_embeds: torch.FloatTensor = None
250
+ text_model_output: BaseModelOutputWithPooling = None
251
+ vision_model_output: BaseModelOutputWithPooling = None
252
+
253
+ def to_tuple(self) -> Tuple[Any]:
254
+ return tuple(
255
+ self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
256
+ for k in self.keys()
257
+ )
258
+
259
+
260
+ class SiglipVisionEmbeddings(nn.Module):
261
+ def __init__(self, config: SiglipVisionConfig):
262
+ super().__init__()
263
+ self.config = config
264
+ self.embed_dim = config.hidden_size
265
+ self.image_size = config.image_size
266
+ self.patch_size = config.patch_size
267
+
268
+ self.patch_embedding = nn.Conv2d(
269
+ in_channels=config.num_channels,
270
+ out_channels=self.embed_dim,
271
+ kernel_size=self.patch_size,
272
+ stride=self.patch_size,
273
+ padding="valid",
274
+ )
275
+
276
+ self.num_patches = (self.image_size // self.patch_size) ** 2
277
+ self.num_positions = self.num_patches
278
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
279
+ self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
280
+
281
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
282
+ patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
283
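+ # Flatten the spatial grid into a patch sequence: [*, width, grid, grid] -> [*, grid*grid, width]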
+ embeddings = patch_embeds.flatten(2).transpose(1, 2)
284
+
285
+ embeddings = embeddings + self.position_embedding(self.position_ids)
286
+ return embeddings
287
+
288
+
289
+ # Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->Siglip
290
+ class SiglipTextEmbeddings(nn.Module):
291
+ def __init__(self, config: SiglipTextConfig):
292
+ super().__init__()
293
+ embed_dim = config.hidden_size
294
+
295
+ self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
296
+ self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
297
+
298
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
299
+ self.register_buffer(
300
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
301
+ )
302
+
303
+ def forward(
304
+ self,
305
+ input_ids: Optional[torch.LongTensor] = None,
306
+ position_ids: Optional[torch.LongTensor] = None,
307
+ inputs_embeds: Optional[torch.FloatTensor] = None,
308
+ ) -> torch.Tensor:
309
+ seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
310
+
311
+ if position_ids is None:
312
+ position_ids = self.position_ids[:, :seq_length]
313
+
314
+ if inputs_embeds is None:
315
+ inputs_embeds = self.token_embedding(input_ids)
316
+
317
+ position_embeddings = self.position_embedding(position_ids)
318
+ embeddings = inputs_embeds + position_embeddings
319
+
320
+ return embeddings
321
+
322
+
323
+ class SiglipAttention(nn.Module):
324
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
325
+
326
+ # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
327
+ def __init__(self, config):
328
+ super().__init__()
329
+ self.config = config
330
+ self.embed_dim = config.hidden_size
331
+ self.num_heads = config.num_attention_heads
332
+ self.head_dim = self.embed_dim // self.num_heads
333
+ if self.head_dim * self.num_heads != self.embed_dim:
334
+ raise ValueError(
335
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
336
+ f" {self.num_heads})."
337
+ )
338
+ self.scale = self.head_dim**-0.5
339
+ self.dropout = config.attention_dropout
340
+
341
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
342
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
343
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
344
+ # self.use_flash_attn = config.use_flash_attn and has_flash_attn
345
+ self.use_flash_attn = has_flash_attn
346
+ if self.use_flash_attn:
347
+ self.inner_attn = FlashAttention(attention_dropout=config.attention_dropout)
348
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
349
+
350
+ def _flash_attn(self,
351
+ hidden_states: torch.Tensor,
352
+ attention_mask: Optional[torch.Tensor] = None,
353
+ output_attentions: Optional[bool] = False,
354
+ key_padding_mask=None,
355
+ need_weights=False
356
+ ):
357
+
358
+ batch_size, q_len, _ = hidden_states.size()
359
+
360
+ query_states = self.q_proj(hidden_states)
361
+ key_states = self.k_proj(hidden_states)
362
+ value_states = self.v_proj(hidden_states)
363
+
364
+ query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim)
365
+ key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim)
366
+ value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim)
367
+
368
+ qkv = torch.stack([query_states, key_states, value_states], dim=2)
369
+ context, attn_weights = self.inner_attn(
370
+ qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=False
371
+ )
372
+ attn_output = self.out_proj(rearrange(context, 'b s h d -> b s (h d)'))
373
+
374
+ return attn_output, attn_weights
375
+
376
+ def forward(
377
+ self,
378
+ hidden_states: torch.Tensor,
379
+ attention_mask: Optional[torch.Tensor] = None,
380
+ output_attentions: Optional[bool] = False,
381
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
382
+ """Input shape: Batch x Time x Channel"""
383
+ if self.use_flash_attn:
384
+ return self._flash_attn(hidden_states)
385
+ else:
386
+ return self._vanilla_attn(hidden_states, attention_mask, output_attentions)
387
+
388
+ def _vanilla_attn(self, hidden_states, attention_mask=None, output_attentions=False):
389
+ batch_size, q_len, _ = hidden_states.size()
390
+
391
+ query_states = self.q_proj(hidden_states)
392
+ key_states = self.k_proj(hidden_states)
393
+ value_states = self.v_proj(hidden_states)
394
+
395
+ query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
396
+ key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
397
+ value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
398
+
399
+ k_v_seq_len = key_states.shape[-2]
400
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
401
+
402
+ if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
403
+ raise ValueError(
404
+ f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
405
+ f" {attn_weights.size()}"
406
+ )
407
+
408
+ if attention_mask is not None:
409
+ if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
410
+ raise ValueError(
411
+ f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
412
+ )
413
+ attn_weights = attn_weights + attention_mask
414
+
415
+ # upcast attention to fp32
416
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
417
+ attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
418
+ attn_output = torch.matmul(attn_weights, value_states)
419
+
420
+ if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim):
421
+ raise ValueError(
422
+ f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
423
+ f" {attn_output.size()}"
424
+ )
425
+
426
+ attn_output = attn_output.transpose(1, 2).contiguous()
427
+ attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
428
+
429
+ attn_output = self.out_proj(attn_output)
430
+
431
+ return attn_output, attn_weights
432
+
433
+
434
+ # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip
435
+ class SiglipMLP(nn.Module):
436
+ def __init__(self, config):
437
+ super().__init__()
438
+ self.config = config
439
+ self.activation_fn = ACT2FN[config.hidden_act]
440
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
441
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
442
+
443
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
444
+ hidden_states = self.fc1(hidden_states)
445
+ hidden_states = self.activation_fn(hidden_states)
446
+ hidden_states = self.fc2(hidden_states)
447
+ return hidden_states
448
+
449
+
450
+ # Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->Siglip
451
+ class SiglipEncoderLayer(nn.Module):
452
+ def __init__(self, config: SiglipConfig):
453
+ super().__init__()
454
+ self.embed_dim = config.hidden_size
455
+ self.self_attn = SiglipAttention(config)
456
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
457
+ self.mlp = SiglipMLP(config)
458
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
459
+
460
+ # Ignore copy
461
+ def forward(
462
+ self,
463
+ hidden_states: torch.Tensor,
464
+ attention_mask: torch.Tensor,
465
+ output_attentions: Optional[bool] = False,
466
+ ) -> Tuple[torch.FloatTensor]:
467
+ """
468
+ Args:
469
+ hidden_states (`torch.FloatTensor`):
470
+ Input to the layer of shape `(batch, seq_len, embed_dim)`.
471
+ attention_mask (`torch.FloatTensor`):
472
+ Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
473
+ output_attentions (`bool`, *optional*, defaults to `False`):
474
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
475
+ returned tensors for more detail.
476
+ """
477
+ residual = hidden_states
478
+
479
+ hidden_states = self.layer_norm1(hidden_states)
480
+ hidden_states, attn_weights = self.self_attn(
481
+ hidden_states=hidden_states,
482
+ attention_mask=attention_mask,
483
+ output_attentions=output_attentions,
484
+ )
485
+ hidden_states = residual + hidden_states
486
+
487
+ residual = hidden_states
488
+ hidden_states = self.layer_norm2(hidden_states)
489
+ hidden_states = self.mlp(hidden_states)
490
+ hidden_states = residual + hidden_states
491
+
492
+ outputs = (hidden_states,)
493
+
494
+ if output_attentions:
495
+ outputs += (attn_weights,)
496
+
497
+ return outputs
498
+
499
+
500
+ class SiglipPreTrainedModel(PreTrainedModel):
501
+ """
502
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
503
+ models.
504
+ """
505
+
506
+ config_class = SiglipConfig
507
+ base_model_prefix = "siglip"
508
+ supports_gradient_checkpointing = True
509
+
510
+ def _init_weights(self, module):
511
+ """Initialize the weights"""
512
+ if isinstance(module, SiglipVisionEmbeddings):
513
+ width = (
514
+ self.config.vision_config.hidden_size
515
+ if isinstance(self.config, SiglipConfig)
516
+ else self.config.hidden_size
517
+ )
518
+ nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width))
519
+ elif isinstance(module, nn.Embedding):
520
+ default_flax_embed_init(module.weight)
521
+ elif isinstance(module, SiglipAttention):
522
+ nn.init.xavier_uniform_(module.q_proj.weight)
523
+ nn.init.xavier_uniform_(module.k_proj.weight)
524
+ nn.init.xavier_uniform_(module.v_proj.weight)
525
+ nn.init.xavier_uniform_(module.out_proj.weight)
526
+ nn.init.zeros_(module.q_proj.bias)
527
+ nn.init.zeros_(module.k_proj.bias)
528
+ nn.init.zeros_(module.v_proj.bias)
529
+ nn.init.zeros_(module.out_proj.bias)
530
+ elif isinstance(module, SiglipMLP):
531
+ nn.init.xavier_uniform_(module.fc1.weight)
532
+ nn.init.xavier_uniform_(module.fc2.weight)
533
+ nn.init.normal_(module.fc1.bias, std=1e-6)
534
+ nn.init.normal_(module.fc2.bias, std=1e-6)
535
+ elif isinstance(module, SiglipMultiheadAttentionPoolingHead):
536
+ nn.init.xavier_uniform_(module.probe.data)
537
+ nn.init.xavier_uniform_(module.attention.in_proj_weight.data)
538
+ nn.init.zeros_(module.attention.in_proj_bias.data)
539
+ elif isinstance(module, SiglipModel):
540
+ logit_scale_init = torch.log(torch.tensor(1.0))
541
+ module.logit_scale.data.fill_(logit_scale_init)
542
+ module.logit_bias.data.zero_()
543
+ elif isinstance(module, (nn.Linear, nn.Conv2d)):
544
+ lecun_normal_(module.weight)
545
+ if module.bias is not None:
546
+ nn.init.zeros_(module.bias)
547
+ elif isinstance(module, nn.LayerNorm):
548
+ module.bias.data.zero_()
549
+ module.weight.data.fill_(1.0)
550
+
551
+
552
+ SIGLIP_START_DOCSTRING = r"""
553
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
554
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
555
+ etc.)
556
+
557
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
558
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
559
+ and behavior.
560
+
561
+ Parameters:
562
+ config ([`SiglipConfig`]): Model configuration class with all the parameters of the model.
563
+ Initializing with a config file does not load the weights associated with the model, only the
564
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
565
+ """
566
+
567
+ SIGLIP_TEXT_INPUTS_DOCSTRING = r"""
568
+ Args:
569
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
570
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
571
+ it.
572
+
573
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
574
+ [`PreTrainedTokenizer.__call__`] for details.
575
+
576
+ [What are input IDs?](../glossary#input-ids)
577
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
578
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
579
+
580
+ - 1 for tokens that are **not masked**,
581
+ - 0 for tokens that are **masked**.
582
+
583
+ [What are attention masks?](../glossary#attention-mask)
584
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
585
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
586
+ config.max_position_embeddings - 1]`.
587
+
588
+ [What are position IDs?](../glossary#position-ids)
589
+ output_attentions (`bool`, *optional*):
590
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
591
+ tensors for more detail.
592
+ output_hidden_states (`bool`, *optional*):
593
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
594
+ more detail.
595
+ return_dict (`bool`, *optional*):
596
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
597
+ """
598
+
599
+ SIGLIP_VISION_INPUTS_DOCSTRING = r"""
600
+ Args:
601
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
602
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
603
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
604
+ output_attentions (`bool`, *optional*):
605
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
606
+ tensors for more detail.
607
+ output_hidden_states (`bool`, *optional*):
608
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
609
+ more detail.
610
+ return_dict (`bool`, *optional*):
611
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
612
+ """
613
+
614
+ SIGLIP_INPUTS_DOCSTRING = r"""
615
+ Args:
616
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
617
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
618
+ it.
619
+
620
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
621
+ [`PreTrainedTokenizer.__call__`] for details.
622
+
623
+ [What are input IDs?](../glossary#input-ids)
624
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
625
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
626
+
627
+ - 1 for tokens that are **not masked**,
628
+ - 0 for tokens that are **masked**.
629
+
630
+ [What are attention masks?](../glossary#attention-mask)
631
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
632
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
633
+ config.max_position_embeddings - 1]`.
634
+
635
+ [What are position IDs?](../glossary#position-ids)
636
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
637
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
638
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
639
+ return_loss (`bool`, *optional*):
640
+ Whether or not to return the contrastive loss.
641
+ output_attentions (`bool`, *optional*):
642
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
643
+ tensors for more detail.
644
+ output_hidden_states (`bool`, *optional*):
645
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
646
+ more detail.
647
+ return_dict (`bool`, *optional*):
648
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
649
+ """
650
+
651
+
652
+ # Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->Siglip
653
+ class SiglipEncoder(nn.Module):
654
+ """
655
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
656
+ [`SiglipEncoderLayer`].
657
+
658
+ Args:
659
+ config: SiglipConfig
660
+ """
661
+
662
+ def __init__(self, config: SiglipConfig):
663
+ super().__init__()
664
+ self.config = config
665
+ self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
666
+ self.gradient_checkpointing = False
667
+
668
+ # Ignore copy
669
+ def forward(
670
+ self,
671
+ inputs_embeds,
672
+ attention_mask: Optional[torch.Tensor] = None,
673
+ output_attentions: Optional[bool] = None,
674
+ output_hidden_states: Optional[bool] = None,
675
+ return_dict: Optional[bool] = None,
676
+ ) -> Union[Tuple, BaseModelOutput]:
677
+ r"""
678
+ Args:
679
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
680
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
681
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
682
+ than the model's internal embedding lookup matrix.
683
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
684
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
685
+
686
+ - 1 for tokens that are **not masked**,
687
+ - 0 for tokens that are **masked**.
688
+
689
+ [What are attention masks?](../glossary#attention-mask)
690
+ output_attentions (`bool`, *optional*):
691
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
692
+ returned tensors for more detail.
693
+ output_hidden_states (`bool`, *optional*):
694
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
695
+ for more detail.
696
+ return_dict (`bool`, *optional*):
697
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
698
+ """
699
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
700
+ output_hidden_states = (
701
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
702
+ )
703
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
704
+
705
+ encoder_states = () if output_hidden_states else None
706
+ all_attentions = () if output_attentions else None
707
+
708
+ hidden_states = inputs_embeds
709
+ for encoder_layer in self.layers:
710
+ if output_hidden_states:
711
+ encoder_states = encoder_states + (hidden_states,)
712
+ if self.gradient_checkpointing and self.training:
713
+ layer_outputs = self._gradient_checkpointing_func(
714
+ encoder_layer.__call__,
715
+ hidden_states,
716
+ attention_mask,
717
+ output_attentions,
718
+ )
719
+ else:
720
+ layer_outputs = encoder_layer(
721
+ hidden_states,
722
+ attention_mask,
723
+ output_attentions=output_attentions,
724
+ )
725
+
726
+ hidden_states = layer_outputs[0]
727
+
728
+ if output_attentions:
729
+ all_attentions = all_attentions + (layer_outputs[1],)
730
+
731
+ if output_hidden_states:
732
+ encoder_states = encoder_states + (hidden_states,)
733
+
734
+ if not return_dict:
735
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
736
+ return BaseModelOutput(
737
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
738
+ )
739
+
740
+
741
+ class SiglipTextTransformer(nn.Module):
742
+ def __init__(self, config: SiglipTextConfig):
743
+ super().__init__()
744
+ self.config = config
745
+ embed_dim = config.hidden_size
746
+ self.embeddings = SiglipTextEmbeddings(config)
747
+ self.encoder = SiglipEncoder(config)
748
+ self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
749
+
750
+ self.head = nn.Linear(embed_dim, embed_dim)
751
+
752
+ @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
753
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig)
754
+ def forward(
755
+ self,
756
+ input_ids: Optional[torch.Tensor] = None,
757
+ attention_mask: Optional[torch.Tensor] = None,
758
+ position_ids: Optional[torch.Tensor] = None,
759
+ output_attentions: Optional[bool] = None,
760
+ output_hidden_states: Optional[bool] = None,
761
+ return_dict: Optional[bool] = None,
762
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
763
+ r"""
764
+ Returns:
765
+
766
+ """
767
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
768
+ output_hidden_states = (
769
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
770
+ )
771
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
772
+
773
+ if input_ids is None:
774
+ raise ValueError("You have to specify input_ids")
775
+
776
+ input_shape = input_ids.size()
777
+ input_ids = input_ids.view(-1, input_shape[-1])
778
+
779
+ hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
780
+
781
+ # note: SigLIP's text model does not use a causal mask, unlike the original CLIP model.
782
+ # expand attention_mask
783
+ if attention_mask is not None:
784
+ # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
785
+ attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
786
+
787
+ encoder_outputs = self.encoder(
788
+ inputs_embeds=hidden_states,
789
+ attention_mask=attention_mask,
790
+ output_attentions=output_attentions,
791
+ output_hidden_states=output_hidden_states,
792
+ return_dict=return_dict,
793
+ )
794
+
795
+ last_hidden_state = encoder_outputs[0]
796
+ last_hidden_state = self.final_layer_norm(last_hidden_state)
797
+
798
+ # Assuming "sticky" EOS tokenization, last token is always EOS.
799
+ pooled_output = last_hidden_state[:, -1, :]
800
+ pooled_output = self.head(pooled_output)
801
+
802
+ if not return_dict:
803
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
804
+
805
+ return BaseModelOutputWithPooling(
806
+ last_hidden_state=last_hidden_state,
807
+ pooler_output=pooled_output,
808
+ hidden_states=encoder_outputs.hidden_states,
809
+ attentions=encoder_outputs.attentions,
810
+ )
811
+
812
+
813
+ @add_start_docstrings(
814
+ """The text model from SigLIP without any head or projection on top.""",
815
+ SIGLIP_START_DOCSTRING,
816
+ )
817
+ class SiglipTextModel(SiglipPreTrainedModel):
818
+ config_class = SiglipTextConfig
819
+
820
+ _no_split_modules = ["SiglipTextEmbeddings", "SiglipEncoderLayer"]
821
+
822
+ def __init__(self, config: SiglipTextConfig):
823
+ super().__init__(config)
824
+ self.text_model = SiglipTextTransformer(config)
825
+ # Initialize weights and apply final processing
826
+ self.post_init()
827
+
828
+ def get_input_embeddings(self) -> nn.Module:
829
+ return self.text_model.embeddings.token_embedding
830
+
831
+ def set_input_embeddings(self, value):
832
+ self.text_model.embeddings.token_embedding = value
833
+
834
+ @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
835
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig)
836
+ def forward(
837
+ self,
838
+ input_ids: Optional[torch.Tensor] = None,
839
+ attention_mask: Optional[torch.Tensor] = None,
840
+ position_ids: Optional[torch.Tensor] = None,
841
+ output_attentions: Optional[bool] = None,
842
+ output_hidden_states: Optional[bool] = None,
843
+ return_dict: Optional[bool] = None,
844
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
845
+ r"""
846
+ Returns:
847
+
848
+ Examples:
849
+
850
+ ```python
851
+ >>> from transformers import AutoTokenizer, SiglipTextModel
852
+
853
+ >>> model = SiglipTextModel.from_pretrained("google/siglip-base-patch16-224")
854
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")
855
+
856
+ >>> # important: make sure to set padding="max_length" as that's how the model was trained
857
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
858
+
859
+ >>> outputs = model(**inputs)
860
+ >>> last_hidden_state = outputs.last_hidden_state
861
+ >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
862
+ ```"""
863
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
864
+
865
+ return self.text_model(
866
+ input_ids=input_ids,
867
+ attention_mask=attention_mask,
868
+ position_ids=position_ids,
869
+ output_attentions=output_attentions,
870
+ output_hidden_states=output_hidden_states,
871
+ return_dict=return_dict,
872
+ )
873
+
874
+
875
+ class SiglipVisionTransformer(nn.Module):
876
+ def __init__(self, config: SiglipVisionConfig):
877
+ super().__init__()
878
+ self.config = config
879
+ embed_dim = config.hidden_size
880
+
881
+ self.embeddings = SiglipVisionEmbeddings(config)
882
+ self.encoder = SiglipEncoder(config)
883
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
884
+ self.head = SiglipMultiheadAttentionPoolingHead(config)
885
+
886
+ @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
887
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig)
888
+ def forward(
889
+ self,
890
+ pixel_values,
891
+ output_attentions: Optional[bool] = None,
892
+ output_hidden_states: Optional[bool] = None,
893
+ return_dict: Optional[bool] = None,
894
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
895
+ r"""
896
+ Returns:
897
+
898
+ """
899
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
900
+ output_hidden_states = (
901
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
902
+ )
903
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
904
+
905
+ hidden_states = self.embeddings(pixel_values)
906
+
907
+ encoder_outputs = self.encoder(
908
+ inputs_embeds=hidden_states,
909
+ output_attentions=output_attentions,
910
+ output_hidden_states=output_hidden_states,
911
+ return_dict=return_dict,
912
+ )
913
+
914
+ last_hidden_state = encoder_outputs[0]
915
+ last_hidden_state = self.post_layernorm(last_hidden_state)
916
+
917
+ pooled_output = self.head(last_hidden_state)
918
+
919
+ if not return_dict:
920
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
921
+
922
+ return BaseModelOutputWithPooling(
923
+ last_hidden_state=last_hidden_state,
924
+ pooler_output=pooled_output,
925
+ hidden_states=encoder_outputs.hidden_states,
926
+ attentions=encoder_outputs.attentions,
927
+ )
928
+
929
+
930
+ class SiglipMultiheadAttentionPoolingHead(nn.Module):
931
+ """Multihead Attention Pooling."""
932
+
933
+ def __init__(self, config: SiglipVisionConfig):
934
+ super().__init__()
935
+
936
+ self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
937
+ self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
938
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
939
+ self.mlp = SiglipMLP(config)
940
+
941
+ def forward(self, hidden_state):
942
+ batch_size = hidden_state.shape[0]
943
+ probe = self.probe.repeat(batch_size, 1, 1)
944
+
945
+ hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
946
+
947
+ residual = hidden_state
948
+ hidden_state = self.layernorm(hidden_state)
949
+ hidden_state = residual + self.mlp(hidden_state)
950
+
951
+ return hidden_state[:, 0]
952
+
953
+
954
+ @add_start_docstrings(
955
+ """The vision model from SigLIP without any head or projection on top.""",
956
+ SIGLIP_START_DOCSTRING,
957
+ )
958
+ class SiglipVisionModel(SiglipPreTrainedModel):
959
+ config_class = SiglipVisionConfig
960
+ main_input_name = "pixel_values"
961
+ _no_split_modules = [
962
+ "SiglipEncoderLayer",
963
+ "SiglipVisionEmbeddings",
964
+ "SiglipMultiheadAttentionPoolingHead",
965
+ ]
966
+
967
+ def __init__(self, config: SiglipVisionConfig):
968
+ super().__init__(config)
969
+
970
+ self.vision_model = SiglipVisionTransformer(config)
971
+
972
+ # Initialize weights and apply final processing
973
+ self.post_init()
974
+
975
+ def get_input_embeddings(self) -> nn.Module:
976
+ return self.vision_model.embeddings.patch_embedding
977
+
978
+ @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
979
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig)
980
+ def forward(
981
+ self,
982
+ pixel_values,
983
+ output_attentions: Optional[bool] = None,
984
+ output_hidden_states: Optional[bool] = None,
985
+ return_dict: Optional[bool] = None,
986
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
987
+ r"""
988
+ Returns:
989
+
990
+ Examples:
991
+
992
+ ```python
993
+ >>> from PIL import Image
994
+ >>> import requests
995
+ >>> from transformers import AutoProcessor, SiglipVisionModel
996
+
997
+ >>> model = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224")
998
+ >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
999
+
1000
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1001
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1002
+
1003
+ >>> inputs = processor(images=image, return_tensors="pt")
1004
+
1005
+ >>> outputs = model(**inputs)
1006
+ >>> last_hidden_state = outputs.last_hidden_state
1007
+ >>> pooled_output = outputs.pooler_output # pooled features
1008
+ ```"""
1009
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1010
+
1011
+ return self.vision_model(
1012
+ pixel_values=pixel_values,
1013
+ output_attentions=output_attentions,
1014
+ output_hidden_states=output_hidden_states,
1015
+ return_dict=return_dict,
1016
+ )
1017
+
1018
+
1019
+ @add_start_docstrings(SIGLIP_START_DOCSTRING)
1020
+ class SiglipModel(SiglipPreTrainedModel):
1021
+ config_class = SiglipConfig
1022
+
1023
+ def __init__(self, config: SiglipConfig):
1024
+ super().__init__(config)
1025
+
1026
+ if not isinstance(config.text_config, SiglipTextConfig):
1027
+ raise ValueError(
1028
+ "config.text_config is expected to be of type SiglipTextConfig but is of type"
1029
+ f" {type(config.text_config)}."
1030
+ )
1031
+
1032
+ if not isinstance(config.vision_config, SiglipVisionConfig):
1033
+ raise ValueError(
1034
+ "config.vision_config is expected to be of type SiglipVisionConfig but is of type"
1035
+ f" {type(config.vision_config)}."
1036
+ )
1037
+
1038
+ text_config = config.text_config
1039
+ vision_config = config.vision_config
1040
+
1041
+ self.text_model = SiglipTextTransformer(text_config)
1042
+ self.vision_model = SiglipVisionTransformer(vision_config)
1043
+
1044
+ self.logit_scale = nn.Parameter(torch.randn(1))
1045
+ self.logit_bias = nn.Parameter(torch.randn(1))
1046
+
1047
+ # Initialize weights and apply final processing
1048
+ self.post_init()
1049
+
1050
+ @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
1051
+ def get_text_features(
1052
+ self,
1053
+ input_ids: Optional[torch.Tensor] = None,
1054
+ attention_mask: Optional[torch.Tensor] = None,
1055
+ position_ids: Optional[torch.Tensor] = None,
1056
+ output_attentions: Optional[bool] = None,
1057
+ output_hidden_states: Optional[bool] = None,
1058
+ return_dict: Optional[bool] = None,
1059
+ ) -> torch.FloatTensor:
1060
+ r"""
1061
+ Returns:
1062
+ text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
1063
+ applying the projection layer to the pooled output of [`SiglipTextModel`].
1064
+
1065
+ Examples:
1066
+
1067
+ ```python
1068
+ >>> from transformers import AutoTokenizer, AutoModel
1069
+ >>> import torch
1070
+
1071
+ >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
1072
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")
1073
+
1074
+ >>> # important: make sure to set padding="max_length" as that's how the model was trained
1075
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
1076
+ >>> with torch.no_grad():
1077
+ ... text_features = model.get_text_features(**inputs)
1078
+ ```"""
1079
+ # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
1080
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1081
+ output_hidden_states = (
1082
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1083
+ )
1084
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1085
+
1086
+ text_outputs = self.text_model(
1087
+ input_ids=input_ids,
1088
+ attention_mask=attention_mask,
1089
+ position_ids=position_ids,
1090
+ output_attentions=output_attentions,
1091
+ output_hidden_states=output_hidden_states,
1092
+ return_dict=return_dict,
1093
+ )
1094
+
1095
+ pooled_output = text_outputs[1]
1096
+
1097
+ return pooled_output
1098
+
1099
+ @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
1100
+ def get_image_features(
1101
+ self,
1102
+ pixel_values: Optional[torch.FloatTensor] = None,
1103
+ output_attentions: Optional[bool] = None,
1104
+ output_hidden_states: Optional[bool] = None,
1105
+ return_dict: Optional[bool] = None,
1106
+ ) -> torch.FloatTensor:
1107
+ r"""
1108
+ Returns:
1109
+ image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
1110
+ applying the projection layer to the pooled output of [`SiglipVisionModel`].
1111
+
1112
+ Examples:
1113
+
1114
+ ```python
1115
+ >>> from PIL import Image
1116
+ >>> import requests
1117
+ >>> from transformers import AutoProcessor, AutoModel
1118
+ >>> import torch
1119
+
1120
+ >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
1121
+ >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
1122
+
1123
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1124
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1125
+
1126
+ >>> inputs = processor(images=image, return_tensors="pt")
1127
+
1128
+ >>> with torch.no_grad():
1129
+ ... image_features = model.get_image_features(**inputs)
1130
+ ```"""
1131
+ # Use SiglipModel's config for some fields (if specified) instead of those of vision & text components.
1132
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1133
+ output_hidden_states = (
1134
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1135
+ )
1136
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1137
+
1138
+ vision_outputs = self.vision_model(
1139
+ pixel_values=pixel_values,
1140
+ output_attentions=output_attentions,
1141
+ output_hidden_states=output_hidden_states,
1142
+ return_dict=return_dict,
1143
+ )
1144
+
1145
+ pooled_output = vision_outputs[1]
1146
+
1147
+ return pooled_output
1148
+
1149
+ @add_start_docstrings_to_model_forward(SIGLIP_INPUTS_DOCSTRING)
1150
+ @replace_return_docstrings(output_type=SiglipOutput, config_class=SiglipConfig)
1151
+ def forward(
1152
+ self,
1153
+ input_ids: Optional[torch.LongTensor] = None,
1154
+ pixel_values: Optional[torch.FloatTensor] = None,
1155
+ attention_mask: Optional[torch.Tensor] = None,
1156
+ position_ids: Optional[torch.LongTensor] = None,
1157
+ return_loss: Optional[bool] = None,
1158
+ output_attentions: Optional[bool] = None,
1159
+ output_hidden_states: Optional[bool] = None,
1160
+ return_dict: Optional[bool] = None,
1161
+ ) -> Union[Tuple, SiglipOutput]:
1162
+ r"""
1163
+ Returns:
1164
+
1165
+ Examples:
1166
+
1167
+ ```python
1168
+ >>> from PIL import Image
1169
+ >>> import requests
1170
+ >>> from transformers import AutoProcessor, AutoModel
1171
+ >>> import torch
1172
+
1173
+ >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
1174
+ >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
1175
+
1176
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1177
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1178
+
1179
+ >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"]
1180
+ >>> # important: we pass `padding=max_length` since the model was trained with this
1181
+ >>> inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt")
1182
+
1183
+ >>> with torch.no_grad():
1184
+ ... outputs = model(**inputs)
1185
+
1186
+ >>> logits_per_image = outputs.logits_per_image
1187
+ >>> probs = torch.sigmoid(logits_per_image) # these are the probabilities
1188
+ >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")
1189
+ 31.9% that image 0 is 'a photo of 2 cats'
1190
+ ```"""
1191
+ # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
1192
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1193
+ output_hidden_states = (
1194
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1195
+ )
1196
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1197
+
1198
+ vision_outputs = self.vision_model(
1199
+ pixel_values=pixel_values,
1200
+ output_attentions=output_attentions,
1201
+ output_hidden_states=output_hidden_states,
1202
+ return_dict=return_dict,
1203
+ )
1204
+
1205
+ text_outputs = self.text_model(
1206
+ input_ids=input_ids,
1207
+ attention_mask=attention_mask,
1208
+ position_ids=position_ids,
1209
+ output_attentions=output_attentions,
1210
+ output_hidden_states=output_hidden_states,
1211
+ return_dict=return_dict,
1212
+ )
1213
+
1214
+ image_embeds = vision_outputs[1]
1215
+ text_embeds = text_outputs[1]
1216
+
1217
+ # normalized features
1218
+ image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
1219
+ text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
1220
+
1221
+ # cosine similarity as logits
1222
+ logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale.exp() + self.logit_bias
1223
+ logits_per_image = logits_per_text.t()
1224
+
1225
+ loss = None
1226
+ if return_loss:
1227
+ raise NotImplementedError("SigLIP loss to be implemented")
1228
+
1229
+ if not return_dict:
1230
+ output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
1231
+ return ((loss,) + output) if loss is not None else output
1232
+
1233
+ return SiglipOutput(
1234
+ loss=loss,
1235
+ logits_per_image=logits_per_image,
1236
+ logits_per_text=logits_per_text,
1237
+ text_embeds=text_embeds,
1238
+ image_embeds=image_embeds,
1239
+ text_model_output=text_outputs,
1240
+ vision_model_output=vision_outputs,
1241
+ )
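
Note on the `return_loss` branch above: this commit raises `NotImplementedError` instead of computing the SigLIP loss. For reference, here is a minimal sketch of the pairwise sigmoid loss from the SigLIP paper, computed from the `logits_per_text` that the forward pass already returns; the helper name is ours and the snippet is illustrative, not part of the checked-in code.

```python
# Minimal sketch (not part of this commit): pairwise sigmoid loss on SiglipModel logits.
import torch
import torch.nn.functional as F

def siglip_sigmoid_loss(logits_per_text: torch.Tensor) -> torch.Tensor:
    # logits_per_text: (num_texts, num_images); assumes a paired batch, so the
    # diagonal holds the matching text/image pairs.
    n = logits_per_text.size(0)
    labels = 2 * torch.eye(n, device=logits_per_text.device) - 1  # +1 on-diagonal, -1 elsewhere
    # -log sigmoid(label * logit), summed over all pairs and averaged over the batch.
    return -F.logsigmoid(labels * logits_per_text).sum() / n

# outputs = model(**inputs)                      # as in the docstring example above
# loss = siglip_sigmoid_loss(outputs.logits_per_text)
```
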
monitor.txt ADDED
The diff for this file is too large to render. See raw diff
 
multi_backbone_channel_concatenation_encoder.py ADDED
@@ -0,0 +1,266 @@
1
+ # --------------------------------------------------------
2
+ # Eagle2
3
+ # Copyright (c) 2025 NVIDIA
4
+ # Licensed under The Apache License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+
7
+ import torch, os
8
+ import torch.nn as nn
9
+ from torch.utils.checkpoint import checkpoint
10
+
11
+ from .siglip_vision_tower import SiglipVisionTower
12
+
13
+ import torch.nn.functional as F
14
+ from torch.nn.init import trunc_normal_
15
+ from copy import deepcopy
16
+ import random
17
+ import math
18
+
19
+ class MultiBackboneChannelConcatenationVisionTower(nn.Module):
20
+ def __init__(self,
21
+ vision_tower,
22
+ args,
23
+ grid_size=32,
24
+ convnext_img_size=1024,
25
+ normalize_type=None, raw_config=None):
26
+
27
+ super().__init__()
28
+
29
+ self.is_loaded = False
30
+ self.grid_size = grid_size
31
+ self.num_tokens = self.grid_size ** 2
32
+ self.normalize_type = args.normalize_type
33
+ self.moe_version_type = args.moe_version_type
34
+ self.raw_config = raw_config
35
+ print("moe_version_type: ", self.moe_version_type)
36
+ assert self.moe_version_type in [None, 'all_tiling', 'seq_concat', 'feat_concat', 'convnext_512_siglip_448'], f"Unknown self.moe_version_type: {self.moe_version_type}"
37
+
38
+ vision_tower_name_list = vision_tower.split(";")
39
+ self.input_image_size = 1024
40
+ self.convnext_img_size = convnext_img_size
41
+ self.load_vision_towers(vision_tower_name_list, args)
42
+
43
+
44
+ def load_vision_towers(self, vision_tower_name_list, args):
45
+ self.vision_towers = nn.ModuleList()
46
+
47
+ freeze_backbone_list = args.freeze_backbones # note this is a str
48
+ if freeze_backbone_list is not None and len(freeze_backbone_list) > 0:
49
+ print("The frozen backbones: ", freeze_backbone_list)
50
+ else:
51
+ # make it a blank str
52
+ freeze_backbone_list = ""
53
+
54
+ for name in vision_tower_name_list:
55
+
56
+ ## ConvNeXt
57
+ if name == 'convnext-1024':
58
+ convnext_args = deepcopy(args)
59
+
60
+ convnext_args.freeze_vision = False
61
+ if 'convnext-1024' in freeze_backbone_list:
62
+ convnext_args.freeze_vision = True
63
+
64
+ from .convnext_encoder import ConvNextVisionTower
65
+ convnext_args.input_image_size = self.convnext_img_size
66
+ convnext_vision_tower = args.vision_tower_convnext_path
67
+ convnext_vision_tower = ConvNextVisionTower(convnext_vision_tower,
68
+ convnext_args, delay_load=args.delay_load, normalize_type=self.normalize_type)
69
+ convnext_vision_tower.load_model()
70
+ self.vision_towers.append(convnext_vision_tower)
71
+
72
+ ## PaliSigLIP
73
+ elif name == 'palisiglip':
74
+ palisiglip_args = deepcopy(args)
75
+ palisiglip_args.input_image_size = 448
76
+
77
+ palisiglip_args.freeze_vision = False
78
+ if 'palisiglip' in freeze_backbone_list:
79
+ palisiglip_args.freeze_vision = True
80
+
81
+ palisiglip_vision_tower = SiglipVisionTower(args.vision_tower_siglip_path, palisiglip_args, delay_load=args.delay_load, raw_config=self.raw_config)
82
+
83
+ palisiglip_vision_tower.load_model()
84
+ self.vision_towers.append(palisiglip_vision_tower)
85
+
86
+ # Set the image processor
87
+ self.image_processor = None
88
+ self.is_loaded = True
89
+
90
+ def load_model(self):
91
+ assert self.is_loaded, "All the vision encoders should be loaded during initialization!"
92
+
93
+ def forward(self, x):
94
+ # x is a Tensor when moe_version_type is None, 'all_tiling' or 'convnext_512_siglip_448';
95
+ # otherwise it is a dict with 'pixel_values' and 'num_patches'
96
+ if self.moe_version_type in [None, 'all_tiling']:
97
+ # The default pipeline
98
+ features = []
99
+ image_input_size = x.shape[2]
100
+ assert x.shape[2] == x.shape[3], f"Image should be a square but size ({x.shape[2]} x {x.shape[3]})"
101
+ for vision_tower in self.vision_towers:
102
+
103
+ if vision_tower.input_image_size != image_input_size:
104
+ resized_x = F.interpolate(x.float(),
105
+ size=(vision_tower.input_image_size, vision_tower.input_image_size),
106
+ mode='bilinear',
107
+ align_corners=True).to(dtype=x.dtype)
108
+ else:
109
+ resized_x = x
110
+
111
+ feature = vision_tower(resized_x)
112
+
113
+ if len(feature.shape) == 3: # b, n, c
114
+ b, n, c = feature.shape
115
+ if n == self.num_tokens:
116
+ features.append(feature)
117
+ continue
118
+ w = h = int(n**0.5)
119
+ feature = feature.transpose(1,2).reshape(b, c, h, w)
120
+ else:
121
+ b, c, h, w = feature.shape
122
+
123
+ if w != self.grid_size:
124
+ feature = F.interpolate(feature.float(), size=(self.grid_size, self.grid_size), mode='bilinear', align_corners=True).to(dtype=x.dtype)
125
+ features.append(feature.flatten(2,3).transpose(1,2))
126
+
127
+ features = torch.cat(features, dim=-1)
128
+ elif self.moe_version_type == 'convnext_512_siglip_448':
129
+ features = {}
130
+ image_input_size = x.shape[2]
131
+ assert x.shape[2] == x.shape[3], f"Image should be a square but size ({x.shape[2]} x {x.shape[3]})"
132
+ for vision_tower in self.vision_towers:
133
+
134
+ if vision_tower.input_image_size != image_input_size:
135
+ resized_x = F.interpolate(x.float(),
136
+ size=(vision_tower.input_image_size, vision_tower.input_image_size),
137
+ mode='bilinear',
138
+ align_corners=True).to(dtype=x.dtype)
139
+ else:
140
+ resized_x = x
141
+
142
+ feature = vision_tower(resized_x)
143
+
144
+ # if len(feature.shape) == 3: # b, n, c
145
+ # b, n, c = feature.shape
146
+ # if n == self.num_tokens:
147
+ # features.append(feature)
148
+ # continue
149
+ # w = h = int(n**0.5)
150
+ # feature = feature.transpose(1,2).reshape(b, c, h, w)
151
+ # else:
152
+ # b, c, h, w = feature.shape
153
+ features[vision_tower.name] = feature
154
+
155
+ else:
156
+ assert isinstance(x, dict), "x is expected to be a dict but {}".format(type(x))
157
+ pixel_values = x['pixel_values']
158
+ num_patches = x['num_patches'] # number of image patches per sample, counted from the image padding tokens in the text
159
+
160
+ # calculated the real image patches
161
+ if self.moe_version_type == 'seq_concat':
162
+ image_in_num_patches = [i-1 for i in num_patches]
163
+ else:
164
+ image_in_num_patches = [i for i in num_patches]
165
+
166
+
167
+ assert sum(image_in_num_patches) == pixel_values.size(0), "sum(image_in_num_patches) ({}) != pixel_values.size(0) ({})".format(sum(image_in_num_patches), pixel_values.size(0))
168
+
169
+ # find the thumbnail image id
170
+ thumbnail_image_id = torch.cumsum(torch.tensor(image_in_num_patches).to(pixel_values.device), 0) - 1
171
+ image_no_tiling = pixel_values[thumbnail_image_id]
172
+
173
+ # By default, we use the 1st vision_tower for the tiled patches and the others for the no-tiling (thumbnail) image
174
+ features = []
175
+ for layer_id, vision_tower in enumerate(self.vision_towers):
176
+ if layer_id == 0:
177
+ x = pixel_values
178
+ else:
179
+ x = image_no_tiling
180
+
181
+ if vision_tower.input_image_size != self.input_image_size:
182
+ resized_x = F.interpolate(x.float(),
183
+ size=(vision_tower.input_image_size, vision_tower.input_image_size),
184
+ mode='bilinear',
185
+ align_corners=True).to(dtype=x.dtype)
186
+ else:
187
+ resized_x = x
188
+
189
+ feature = vision_tower(resized_x)
190
+ if len(feature.shape) == 3: # b, n, c
191
+ b, n, c = feature.shape
192
+ if n == self.num_tokens:
193
+ features.append(feature)
194
+ continue
195
+
196
+ w = h = int(n**0.5)
197
+ feature = feature.transpose(1,2).reshape(b, c, h, w)
198
+ else:
199
+ b, c, h, w = feature.shape
200
+
201
+ if w != self.grid_size:
202
+ feature = F.interpolate(feature.float(), size=(self.grid_size, self.grid_size), mode='bilinear', align_corners=True).to(dtype=x.dtype)
203
+ features.append(feature.flatten(2,3).transpose(1,2))
204
+
205
+ clip_embeds = features[0]
206
+ if len(features) <= 1:
207
+ no_tiling_embeds = None
208
+ else:
209
+ no_tiling_embeds = torch.cat(features[1:], dim=-1)
210
+
211
+ if self.moe_version_type == 'feat_concat':
212
+ # concatenate the thumbnail image features together
213
+ clip_thumbnail_embeds = clip_embeds[thumbnail_image_id]
214
+ if no_tiling_embeds is not None:
215
+ no_tiling_embeds = torch.cat([clip_thumbnail_embeds, no_tiling_embeds], dim=-1)
216
+ else:
217
+ no_tiling_embeds = clip_thumbnail_embeds
218
+
219
+ # keep only the extra (tiled) patch features, excluding the thumbnails
220
+ clip_embeds_mask = ~torch.isin(torch.arange(clip_embeds.shape[0]).to(clip_embeds.device), thumbnail_image_id)
221
+ clip_embeds = clip_embeds[clip_embeds_mask]
222
+
223
+
224
+ features = {
225
+ 'clip_embeds': clip_embeds,
226
+ 'no_tiling_embeds': no_tiling_embeds,
227
+ 'num_patches': num_patches
228
+ }
229
+
230
+ # features is a Tensor in the default pipeline, otherwise a dict
231
+
232
+ return features
233
+
234
+ @property
235
+ def dummy_feature(self):
236
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
237
+
238
+ @property
239
+ def dtype(self):
240
+ return next(self.clip_vision_tower.parameters()).dtype
241
+
242
+ @property
243
+ def device(self):
244
+ return next(self.clip_vision_tower.parameters()).device
245
+
246
+ @property
247
+ def config(self):
248
+ raise NotImplementedError
249
+ pass
250
+
251
+ @property
252
+ def hidden_size(self):
253
+ if self.moe_version_type == 'convnext_512_siglip_448':
254
+ res = {}
255
+ for vision_tower in self.vision_towers:
256
+ res[vision_tower.name] = vision_tower.hidden_size
257
+ return res
258
+ else:
259
+ return sum([_.hidden_size for _ in self.vision_towers])
260
+
261
+ @property
262
+ def num_patches(self):
263
+ return self.num_tokens
264
+
265
+
266
+
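
The default pipeline above (when `moe_version_type` is `None` or `'all_tiling'`) resizes every tower's output to a shared `grid_size x grid_size` token grid and concatenates along the channel dimension. Below is a self-contained sketch of just that step with two dummy feature tensors; the batch size, token counts and channel widths are illustrative, not values taken from this commit.

```python
# Sketch of the resize-to-grid + channel-concat logic used in the default pipeline above.
import torch
import torch.nn.functional as F

grid_size = 32
b = 2
feat_a = torch.randn(b, 1024, 1152)    # (b, n, c) tokens; n already equals grid_size**2
feat_b = torch.randn(b, 1536, 24, 24)  # (b, c, h, w) map that needs resizing

features = []
for feature in (feat_a, feat_b):
    if feature.dim() == 3:                       # token sequence -> spatial map if needed
        b_, n, c = feature.shape
        if n == grid_size ** 2:
            features.append(feature)
            continue
        w = h = int(n ** 0.5)
        feature = feature.transpose(1, 2).reshape(b_, c, h, w)
    if feature.shape[-1] != grid_size:           # resize to the shared grid
        feature = F.interpolate(feature.float(), size=(grid_size, grid_size),
                                mode='bilinear', align_corners=True)
    features.append(feature.flatten(2, 3).transpose(1, 2))

features = torch.cat(features, dim=-1)           # (b, grid_size**2, 1152 + 1536)
print(features.shape)                            # torch.Size([2, 1024, 2688])
```
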
multi_backbone_channel_concatentation_model.py ADDED
@@ -0,0 +1,96 @@
1
+ # --------------------------------------------------------
2
+ # Eagle2
3
+ # Copyright (c) 2025 NVIDIA
4
+ # Licensed under The Apache License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+
7
+
8
+ import torch.nn as nn
9
+
10
+ from transformers.modeling_outputs import BaseModelOutputWithPooling
11
+ from typing import Optional, Tuple, Union
12
+
13
+ from .multi_backbone_channel_concatenation_encoder import MultiBackboneChannelConcatenationVisionTower
14
+ from .configuration_multi_backbone_channel_concatentation_model import MultiBackboneChannelConcatenationVisionModelConfig
15
+
16
+
17
+ class MultiBackboneChannelConcatenationVisionModel(nn.Module):
18
+
19
+ """
20
+ A vision model wrapper that concatenates channels from multiple backbones.
21
+
22
+ Args:
23
+ config (MultiBackboneChannelConcatenationVisionModelConfig): The configuration for the model.
24
+
25
+ Attributes:
26
+ vision_model (MultiBackboneChannelConcatenationVisionTower): The vision tower that performs the channel concatenation.
27
+
28
+ Notes:
29
+ **The class is not inherited from the PreTrainedModel in transformers**
30
+
31
+ """
32
+
33
+ config_class = MultiBackboneChannelConcatenationVisionModelConfig
34
+ main_input_name = "pixel_values"
35
+
36
+ def __init__(self, config: MultiBackboneChannelConcatenationVisionModelConfig, raw_config):
37
+ super().__init__()
38
+
39
+ self.vision_model = MultiBackboneChannelConcatenationVisionTower(
40
+ vision_tower=config.vision_tower,
41
+ args=config,
42
+ grid_size=config.grid_size,
43
+ convnext_img_size=config.convnext_img_size,
44
+ normalize_type=config.normalize_type,
45
+ raw_config=raw_config
46
+ )
47
+
48
+
49
+ def get_input_embeddings(self):
50
+ # You might need to adjust this depending on how you want to handle input embeddings
51
+ return self.vision_model.vision_towers[0].get_input_embeddings()
52
+
53
+ def forward(
54
+ self,
55
+ pixel_values,
56
+ return_dict: Optional[bool] = True,
57
+ output_hidden_states: Optional[bool] = False,
58
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
59
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
60
+
61
+ assert return_dict is True, "We only support return_dict"
62
+ assert output_hidden_states is False, "We do not support output_hidden_states"
63
+
64
+ features = self.vision_model(pixel_values)
65
+
66
+ # We only supports features as model outputs
67
+ return BaseModelOutputWithPooling(
68
+ last_hidden_state=features,
69
+ pooler_output=None,
70
+ hidden_states=None,
71
+ attentions=None,
72
+ )
73
+
74
+ @property
75
+ def dummy_feature(self):
76
+ return self.vision_model.dummy_feature
77
+
78
+ @property
79
+ def dtype(self):
80
+ return self.vision_model.dtype
81
+
82
+ @property
83
+ def device(self):
84
+ return self.vision_model.device
85
+
86
+ @property
87
+ def config(self):
88
+ return self.vision_model.config
89
+
90
+ @property
91
+ def hidden_size(self):
92
+ return self.vision_model.hidden_size
93
+
94
+ @property
95
+ def num_patches(self):
96
+ return self.vision_model.num_patches
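
The wrapper above does not pool: it simply repackages the concatenated tower features as `last_hidden_state` and leaves everything else `None`. A tiny sketch of that output contract, with placeholder shapes rather than values from this commit:

```python
# Output contract sketch (placeholder shapes): last_hidden_state carries the raw features.
import torch
from transformers.modeling_outputs import BaseModelOutputWithPooling

features = torch.randn(2, 1024, 2688)  # (batch, grid_size**2, sum of tower hidden sizes)
out = BaseModelOutputWithPooling(last_hidden_state=features, pooler_output=None)
assert out.last_hidden_state.shape == (2, 1024, 2688)
```
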
siglip_vision_tower.py ADDED
@@ -0,0 +1,93 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.utils.checkpoint import checkpoint
4
+
5
+ from .modeling_siglip import SiglipVisionModel
6
+ from .configuration_siglip import SiglipVisionConfig
7
+
8
+ import math
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from typing import List, Optional
12
+ import os
13
+
14
+ class SiglipVisionTower(nn.Module):
15
+ # We use the same wrapper as the default clip encoder.
16
+ # See `clip_encoder.py` in the same folder
17
+ def __init__(self, vision_tower, args, delay_load=False, raw_config=None):
18
+ super().__init__()
19
+
20
+ self.is_loaded = False
21
+ self.freeze_vision = args.freeze_vision
22
+ self.input_image_size = args.input_image_size
23
+ self.vision_tower_name = vision_tower
24
+ self.select_layer = args.mm_vision_select_layer
25
+ self.name = 'siglip'
26
+ self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
27
+ self.delay_load = delay_load
28
+ self.raw_config = raw_config
29
+ if not delay_load:
30
+ self.load_model()
31
+ else:
32
+ if os.path.isfile(self.vision_tower_name):
33
+ self.cfg_only = SiglipVisionConfig.from_pretrained(self.vision_tower_name, local_files_only=True)
34
+ else:
35
+ self.cfg_only = SiglipVisionConfig(**self.raw_config.vision_config.siglip_vision_config)
36
+
37
+
38
+ def load_model(self):
39
+ if self.is_loaded:
40
+ print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
41
+ return
42
+
43
+ # self.image_processor = SiglipImageProcessor(size=1024)
44
+ # self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name, local_files_only=True, torch_dtype=torch.bfloat16)
45
+ if self.delay_load:
46
+ # cfg = SiglipVisionConfig.from_pretrained(self.vision_tower_name, local_files_only=True)
47
+ self.vision_tower = SiglipVisionModel(self.cfg_only)
48
+ else:
49
+ self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name, local_files_only=True)
50
+
51
+ if self.freeze_vision:
52
+ self.vision_tower.requires_grad_(False)
53
+
54
+ self.vision_tower.vision_model.encoder.gradient_checkpointing = True
55
+ self.is_loaded = True
56
+
57
+ def forward(self, images):
58
+ return self.vision_tower(
59
+ pixel_values=images,
60
+ output_hidden_states=False,
61
+ return_dict=True).last_hidden_state
62
+
63
+
64
+ @property
65
+ def dummy_feature(self):
66
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
67
+
68
+ @property
69
+ def dtype(self):
70
+ return self.vision_tower.dtype
71
+
72
+ @property
73
+ def device(self):
74
+ return self.vision_tower.device
75
+
76
+ @property
77
+ def config(self):
78
+ if self.is_loaded:
79
+ return self.vision_tower.config
80
+ else:
81
+ return self.cfg_only
82
+
83
+ @property
84
+ def hidden_size(self):
85
+ return self.config.hidden_size
86
+
87
+ @property
88
+ def num_patches_per_side(self):
89
+ return self.config.image_size // self.config.patch_size
90
+
91
+ @property
92
+ def num_patches(self):
93
+ return (self.config.image_size // self.config.patch_size) ** 2
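
As a quick sanity check of the `num_patches` arithmetic above, with an illustrative `image_size=448` and `patch_size=14` (placeholder values, not read from this commit's config) the tower yields a 32x32 token grid:

```python
# Placeholder values used only to illustrate the property arithmetic above.
image_size, patch_size = 448, 14
num_patches_per_side = image_size // patch_size   # 32
num_patches = num_patches_per_side ** 2           # 1024 patch tokens per image
print(num_patches_per_side, num_patches)          # 32 1024
```
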
special_tokens_map.json ADDED
@@ -0,0 +1,23 @@
1
+ {
2
+ "bos_token": {
3
+ "content": "<|begin_of_text|>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|eot_id|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "<|finetune_right_pad_id|>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ }
23
+ }
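
Once the checkpoint is downloaded, the mapping above can be confirmed by loading the tokenizer; the `"./"` path below is a placeholder for the local repository directory.

```python
# Quick check of the special-token mapping (placeholder path).
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("./", trust_remote_code=True)
print(tok.bos_token)  # <|begin_of_text|>
print(tok.eos_token)  # <|eot_id|>
print(tok.pad_token)  # <|finetune_right_pad_id|>
```
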
tokenization_qwen2.py ADDED
@@ -0,0 +1,345 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Qwen team, Alibaba Group and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes for Qwen2."""
16
+
17
+ import json
18
+ import os
19
+ import unicodedata
20
+ from functools import lru_cache
21
+ from typing import Optional, Tuple
22
+
23
+ import regex as re
24
+
25
+ from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
26
+ from transformers.utils import logging
27
+
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
+ VOCAB_FILES_NAMES = {
32
+ "vocab_file": "vocab.json",
33
+ "merges_file": "merges.txt",
34
+ }
35
+
36
+ PRETRAINED_VOCAB_FILES_MAP = {
37
+ "vocab_file": {"qwen/qwen-tokenizer": "https://huggingface.co/qwen/qwen-tokenizer/resolve/main/vocab.json"},
38
+ "merges_file": {"qwen/qwen-tokenizer": "https://huggingface.co/qwen/qwen-tokenizer/resolve/main/merges.txt"},
39
+ }
40
+
41
+ MAX_MODEL_INPUT_SIZES = {"qwen/qwen-tokenizer": 32768}
42
+
43
+ PRETOKENIZE_REGEX = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
44
+
45
+
46
+ @lru_cache()
47
+ # Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode
48
+ def bytes_to_unicode():
49
+ """
50
+ Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control
51
+ characters that the bpe code barfs on.
52
+
53
+ The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
54
+ if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
55
+ decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
56
+ tables between utf-8 bytes and unicode strings.
57
+ """
58
+ bs = (
59
+ list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
60
+ )
61
+ cs = bs[:]
62
+ n = 0
63
+ for b in range(2**8):
64
+ if b not in bs:
65
+ bs.append(b)
66
+ cs.append(2**8 + n)
67
+ n += 1
68
+ cs = [chr(n) for n in cs]
69
+ return dict(zip(bs, cs))
70
+
71
+
72
+ # Copied from transformers.models.gpt2.tokenization_gpt2.get_pairs
73
+ def get_pairs(word):
74
+ """
75
+ Return set of symbol pairs in a word.
76
+
77
+ Word is represented as tuple of symbols (symbols being variable-length strings).
78
+ """
79
+ pairs = set()
80
+ prev_char = word[0]
81
+ for char in word[1:]:
82
+ pairs.add((prev_char, char))
83
+ prev_char = char
84
+ return pairs
85
+
86
+
87
+ class Qwen2Tokenizer(PreTrainedTokenizer):
88
+ """
89
+ Construct a Qwen2 tokenizer. Based on byte-level Byte-Pair-Encoding.
90
+
91
+ Same as with GPT2Tokenizer, this tokenizer has been trained to treat spaces like parts of the tokens, so a word will
92
+ be encoded differently depending on whether it is at the beginning of the sentence (without space) or not:
93
+
94
+ ```python
95
+ >>> from transformers import Qwen2Tokenizer
96
+
97
+ >>> tokenizer = Qwen2Tokenizer.from_pretrained("Qwen/Qwen-tokenizer")
98
+ >>> tokenizer("Hello world")["input_ids"]
99
+ [9707, 1879]
100
+
101
+ >>> tokenizer(" Hello world")["input_ids"]
102
+ [21927, 1879]
103
+ ```
104
+ This is expected.
105
+
106
+ You should not use GPT2Tokenizer in its place, because the two classes use different pretokenization rules.
107
+
108
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
109
+ this superclass for more information regarding those methods.
110
+
111
+ Args:
112
+ vocab_file (`str`):
113
+ Path to the vocabulary file.
114
+ merges_file (`str`):
115
+ Path to the merges file.
116
+ errors (`str`, *optional*, defaults to `"replace"`):
117
+ Paradigm to follow when decoding bytes to UTF-8. See
118
+ [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
119
+ unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
120
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
121
+ token instead.
122
+ bos_token (`str`, *optional*):
123
+ The beginning of sequence token. Not applicable for this tokenizer.
124
+ eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
125
+ The end of sequence token.
126
+ pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
127
+ The token used for padding, for example when batching sequences of different lengths.
128
+ clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
129
+ Whether or not the model should cleanup the spaces that were added when splitting the input text during the
130
+ tokenization process. Not applicable to this tokenizer, since tokenization does not add spaces.
131
+ split_special_tokens (`bool`, *optional*, defaults to `False`):
132
+ Whether or not the special tokens should be split during the tokenization process. The default behavior is
133
+ to not split special tokens. This means that if `<|endoftext|>` is the `eos_token`, then `tokenizer.tokenize("<|endoftext|>")` =
134
+ `['<|endoftext|>']`. Otherwise, if `split_special_tokens=True`, then `tokenizer.tokenize("<|endoftext|>")` will give `['<',
135
+ '|', 'endo', 'ft', 'ext', '|', '>']`. This argument is only supported for `slow` tokenizers for the moment.
136
+ """
137
+
138
+ vocab_files_names = VOCAB_FILES_NAMES
139
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
140
+ max_model_input_sizes = MAX_MODEL_INPUT_SIZES
141
+ model_input_names = ["input_ids", "attention_mask"]
142
+
143
+ def __init__(
144
+ self,
145
+ vocab_file,
146
+ merges_file,
147
+ errors="replace",
148
+ unk_token="<|endoftext|>",
149
+ bos_token=None,
150
+ eos_token="<|endoftext|>",
151
+ pad_token="<|endoftext|>",
152
+ clean_up_tokenization_spaces=False,
153
+ split_special_tokens=False,
154
+ **kwargs,
155
+ ):
156
+ # Qwen vocab does not contain control tokens; added tokens need to be special
157
+ bos_token = (
158
+ AddedToken(bos_token, lstrip=False, rstrip=False, special=True, normalized=False)
159
+ if isinstance(bos_token, str)
160
+ else bos_token
161
+ )
162
+ eos_token = (
163
+ AddedToken(eos_token, lstrip=False, rstrip=False, special=True, normalized=False)
164
+ if isinstance(eos_token, str)
165
+ else eos_token
166
+ )
167
+ unk_token = (
168
+ AddedToken(unk_token, lstrip=False, rstrip=False, special=True, normalized=False)
169
+ if isinstance(unk_token, str)
170
+ else unk_token
171
+ )
172
+ pad_token = (
173
+ AddedToken(pad_token, lstrip=False, rstrip=False, special=True, normalized=False)
174
+ if isinstance(pad_token, str)
175
+ else pad_token
176
+ )
177
+
178
+ with open(vocab_file, encoding="utf-8") as vocab_handle:
179
+ self.encoder = json.load(vocab_handle)
180
+ self.decoder = {v: k for k, v in self.encoder.items()}
181
+ self.errors = errors # how to handle errors in decoding
182
+ self.byte_encoder = bytes_to_unicode()
183
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
184
+ bpe_merges = []
185
+ with open(merges_file, encoding="utf-8") as merges_handle:
186
+ for line in merges_handle:
187
+ line = line.strip()
188
+ if not line or line.startswith("#"):
189
+ continue
190
+ bpe_merges.append(tuple(line.split()))
191
+ self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
192
+ # NOTE: the cache can grow without bound and will get really large for long running processes
193
+ # (esp. for texts in languages that do not use spaces between words, e.g. Chinese); technically
194
+ # not a memory leak but appears as one.
195
+ # GPT2Tokenizer has the same problem, so let's be consistent.
196
+ self.cache = {}
197
+
198
+ self.pat = re.compile(PRETOKENIZE_REGEX)
199
+
200
+ if kwargs.get("add_prefix_space", False):
201
+ logger.warning_once(
202
+ f"{self.__class__.__name__} does not support `add_prefix_space`, setting it to True has no effect."
203
+ )
204
+
205
+ super().__init__(
206
+ errors=errors,
207
+ bos_token=bos_token,
208
+ eos_token=eos_token,
209
+ pad_token=pad_token,
210
+ unk_token=unk_token,
211
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
212
+ split_special_tokens=split_special_tokens,
213
+ **kwargs,
214
+ )
215
+
216
+ @property
217
+ def vocab_size(self) -> int:
218
+ return len(self.encoder)
219
+
220
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.get_vocab
221
+ def get_vocab(self):
222
+ return dict(self.encoder, **self.added_tokens_encoder)
223
+
224
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.bpe
225
+ def bpe(self, token):
226
+ if token in self.cache:
227
+ return self.cache[token]
228
+ word = tuple(token)
229
+ pairs = get_pairs(word)
230
+
231
+ if not pairs:
232
+ return token
233
+
234
+ while True:
235
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
236
+ if bigram not in self.bpe_ranks:
237
+ break
238
+ first, second = bigram
239
+ new_word = []
240
+ i = 0
241
+ while i < len(word):
242
+ try:
243
+ j = word.index(first, i)
244
+ except ValueError:
245
+ new_word.extend(word[i:])
246
+ break
247
+ else:
248
+ new_word.extend(word[i:j])
249
+ i = j
250
+
251
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
252
+ new_word.append(first + second)
253
+ i += 2
254
+ else:
255
+ new_word.append(word[i])
256
+ i += 1
257
+ new_word = tuple(new_word)
258
+ word = new_word
259
+ if len(word) == 1:
260
+ break
261
+ else:
262
+ pairs = get_pairs(word)
263
+ word = " ".join(word)
264
+ self.cache[token] = word
265
+ return word
266
+
267
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._tokenize
268
+ def _tokenize(self, text):
269
+ """Tokenize a string."""
270
+ bpe_tokens = []
271
+ for token in re.findall(self.pat, text):
272
+ token = "".join(
273
+ self.byte_encoder[b] for b in token.encode("utf-8")
274
+ ) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
275
+ bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
276
+ return bpe_tokens
277
+
278
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_token_to_id
279
+ def _convert_token_to_id(self, token):
280
+ """Converts a token (str) into an id using the vocab."""
281
+ return self.encoder.get(token, self.encoder.get(self.unk_token))
282
+
283
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_id_to_token
284
+ def _convert_id_to_token(self, index):
285
+ """Converts an index (integer) into a token (str) using the vocab."""
286
+ return self.decoder.get(index)
287
+
288
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.convert_tokens_to_string
289
+ def convert_tokens_to_string(self, tokens):
290
+ """Converts a sequence of tokens (string) into a single string."""
291
+ text = "".join(tokens)
292
+ text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
293
+ return text
294
+
295
+ def decode(
296
+ self,
297
+ token_ids,
298
+ skip_special_tokens: bool = False,
299
+ clean_up_tokenization_spaces: Optional[bool] = False,
300
+ spaces_between_special_tokens: bool = False,
301
+ **kwargs,
302
+ ) -> str:
303
+ # `spaces_between_special_tokens` defaults to True for _decode in slow tokenizers
304
+ # and cannot be configured elsewhere, but it should default to False for Qwen2Tokenizer
305
+ return super().decode(
306
+ token_ids,
307
+ skip_special_tokens=skip_special_tokens,
308
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
309
+ spaces_between_special_tokens=spaces_between_special_tokens,
310
+ **kwargs,
311
+ )
312
+
313
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.save_vocabulary
314
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
315
+ if not os.path.isdir(save_directory):
316
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
317
+ return
318
+ vocab_file = os.path.join(
319
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
320
+ )
321
+ merge_file = os.path.join(
322
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
323
+ )
324
+
325
+ with open(vocab_file, "w", encoding="utf-8") as f:
326
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
327
+
328
+ index = 0
329
+ with open(merge_file, "w", encoding="utf-8") as writer:
330
+ writer.write("#version: 0.2\n")
331
+ for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
332
+ if index != token_index:
333
+ logger.warning(
334
+ f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
335
+ " Please check that the tokenizer is not corrupted!"
336
+ )
337
+ index = token_index
338
+ writer.write(" ".join(bpe_tokens) + "\n")
339
+ index += 1
340
+
341
+ return vocab_file, merge_file
342
+
343
+ def prepare_for_tokenization(self, text, **kwargs):
344
+ text = unicodedata.normalize("NFC", text)
345
+ return (text, kwargs)
tokenization_qwen2_fast.py ADDED
@@ -0,0 +1,143 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Qwen team, Alibaba Group and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes for Qwen2."""
16
+
17
+ from typing import Optional, Tuple
18
+
19
+ from transformers.tokenization_utils import AddedToken
20
+ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
21
+ from transformers.utils import logging
22
+ from .tokenization_qwen2 import Qwen2Tokenizer
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+ VOCAB_FILES_NAMES = {
28
+ "vocab_file": "vocab.json",
29
+ "merges_file": "merges.txt",
30
+ "tokenizer_file": "tokenizer.json",
31
+ }
32
+
33
+ PRETRAINED_VOCAB_FILES_MAP = {
34
+ "vocab_file": {"qwen/qwen-tokenizer": "https://huggingface.co/qwen/qwen-tokenizer/resolve/main/vocab.json"},
35
+ "merges_file": {"qwen/qwen-tokenizer": "https://huggingface.co/qwen/qwen-tokenizer/resolve/main/merges.txt"},
36
+ "tokenizer_file": {
37
+ "qwen/qwen-tokenizer": "https://huggingface.co/qwen/qwen-tokenizer/resolve/main/tokenizer.json"
38
+ },
39
+ }
40
+
41
+ MAX_MODEL_INPUT_SIZES = {"qwen/qwen-tokenizer": 32768}
42
+
43
+
44
+ class Qwen2TokenizerFast(PreTrainedTokenizerFast):
45
+ """
46
+ Construct a "fast" Qwen2 tokenizer (backed by HuggingFace's *tokenizers* library). Based on byte-level
47
+ Byte-Pair-Encoding.
48
+
49
+ Same as with GPT2Tokenizer, this tokenizer has been trained to treat spaces like parts of the tokens, so a word will
50
+ be encoded differently depending on whether it is at the beginning of the sentence (without space) or not:
51
+
52
+ ```python
53
+ >>> from transformers import Qwen2TokenizerFast
54
+
55
+ >>> tokenizer = Qwen2TokenizerFast.from_pretrained("Qwen/Qwen-tokenizer")
56
+ >>> tokenizer("Hello world")["input_ids"]
57
+ [9707, 1879]
58
+
59
+ >>> tokenizer(" Hello world")["input_ids"]
60
+ [21927, 1879]
61
+ ```
62
+ This is expected.
63
+
64
+ This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
65
+ refer to this superclass for more information regarding those methods.
66
+
67
+ Args:
68
+ vocab_file (`str`, *optional*):
69
+ Path to the vocabulary file.
70
+ merges_file (`str`, *optional*):
71
+ Path to the merges file.
72
+ tokenizer_file (`str`, *optional*):
73
+ Path to [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that
74
+ contains everything needed to load the tokenizer.
75
+ unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
76
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
77
+ token instead. Not applicable to this tokenizer.
78
+ bos_token (`str`, *optional*):
79
+ The beginning of sequence token. Not applicable for this tokenizer.
80
+ eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
81
+ The end of sequence token.
82
+ pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
83
+ The token used for padding, for example when batching sequences of different lengths.
84
+ """
85
+
86
+ vocab_files_names = VOCAB_FILES_NAMES
87
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
88
+ max_model_input_sizes = MAX_MODEL_INPUT_SIZES
89
+ model_input_names = ["input_ids", "attention_mask"]
90
+ slow_tokenizer_class = Qwen2Tokenizer
91
+
92
+ def __init__(
93
+ self,
94
+ vocab_file=None,
95
+ merges_file=None,
96
+ tokenizer_file=None,
97
+ unk_token="<|endoftext|>",
98
+ bos_token=None,
99
+ eos_token="<|endoftext|>",
100
+ pad_token="<|endoftext|>",
101
+ **kwargs,
102
+ ):
103
+ # We need to at least pass vocab_file and merges_file to the base class
104
+ # in case a slow tokenizer needs to be initialized; the rest can be
105
+ # configured through files.
106
+ # Following GPT2TokenizerFast, we also add unk_token, bos_token, and eos_token.
107
+
108
+ bos_token = (
109
+ AddedToken(bos_token, lstrip=False, rstrip=False, special=True, normalized=False)
110
+ if isinstance(bos_token, str)
111
+ else bos_token
112
+ )
113
+ eos_token = (
114
+ AddedToken(eos_token, lstrip=False, rstrip=False, special=True, normalized=False)
115
+ if isinstance(eos_token, str)
116
+ else eos_token
117
+ )
118
+ unk_token = (
119
+ AddedToken(unk_token, lstrip=False, rstrip=False, special=True, normalized=False)
120
+ if isinstance(unk_token, str)
121
+ else unk_token
122
+ )
123
+ pad_token = (
124
+ AddedToken(pad_token, lstrip=False, rstrip=False, special=True, normalized=False)
125
+ if isinstance(pad_token, str)
126
+ else pad_token
127
+ )
128
+
129
+ super().__init__(
130
+ vocab_file,
131
+ merges_file,
132
+ tokenizer_file=tokenizer_file,
133
+ unk_token=unk_token,
134
+ bos_token=bos_token,
135
+ eos_token=eos_token,
136
+ pad_token=pad_token,
137
+ **kwargs,
138
+ )
139
+
140
+ # Copied from transformers.models.gpt2.tokenization_gpt2_fast.GPT2TokenizerFast.save_vocabulary
141
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
142
+ files = self._tokenizer.model.save(save_directory, name=filename_prefix)
143
+ return tuple(files)
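Both tokenization modules above are shipped as repo-local code. Below is a minimal sketch of driving them directly; the file names follow the defaults in `VOCAB_FILES_NAMES`, and whether `vocab.json`/`merges.txt` are actually present in this checkpoint is not visible in this part of the diff, so treat the paths as assumptions.

```python
# Sketch only: assumes vocab.json, merges.txt and tokenizer.json sit next to these scripts.
from tokenization_qwen2 import Qwen2Tokenizer
from tokenization_qwen2_fast import Qwen2TokenizerFast

slow = Qwen2Tokenizer(vocab_file="vocab.json", merges_file="merges.txt")
fast = Qwen2TokenizerFast(tokenizer_file="tokenizer.json")

# Byte-level BPE keeps the leading space as part of the token, so these two
# inputs tokenize differently, mirroring the docstring examples above.
print(slow.tokenize("Hello world"))
print(slow.tokenize(" Hello world"))

# Round-trip through the fast tokenizer.
print(fast.decode(fast("Hello world")["input_ids"]))
```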
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,2152 @@
1
+ {
2
+ "add_eos_token": false,
3
+ "added_tokens_decoder": {
4
+ "128000": {
5
+ "content": "<|begin_of_text|>",
6
+ "lstrip": false,
7
+ "normalized": false,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "128001": {
13
+ "content": "<|end_of_text|>",
14
+ "lstrip": false,
15
+ "normalized": false,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ },
20
+ "128002": {
21
+ "content": "<|reserved_special_token_0|>",
22
+ "lstrip": false,
23
+ "normalized": false,
24
+ "rstrip": false,
25
+ "single_word": false,
26
+ "special": true
27
+ },
28
+ "128003": {
29
+ "content": "<|reserved_special_token_1|>",
30
+ "lstrip": false,
31
+ "normalized": false,
32
+ "rstrip": false,
33
+ "single_word": false,
34
+ "special": true
35
+ },
36
+ "128004": {
37
+ "content": "<|finetune_right_pad_id|>",
38
+ "lstrip": false,
39
+ "normalized": false,
40
+ "rstrip": false,
41
+ "single_word": false,
42
+ "special": true
43
+ },
44
+ "128005": {
45
+ "content": "<|reserved_special_token_2|>",
46
+ "lstrip": false,
47
+ "normalized": false,
48
+ "rstrip": false,
49
+ "single_word": false,
50
+ "special": true
51
+ },
52
+ "128006": {
53
+ "content": "<|start_header_id|>",
54
+ "lstrip": false,
55
+ "normalized": false,
56
+ "rstrip": false,
57
+ "single_word": false,
58
+ "special": true
59
+ },
60
+ "128007": {
61
+ "content": "<|end_header_id|>",
62
+ "lstrip": false,
63
+ "normalized": false,
64
+ "rstrip": false,
65
+ "single_word": false,
66
+ "special": true
67
+ },
68
+ "128008": {
69
+ "content": "<|eom_id|>",
70
+ "lstrip": false,
71
+ "normalized": false,
72
+ "rstrip": false,
73
+ "single_word": false,
74
+ "special": true
75
+ },
76
+ "128009": {
77
+ "content": "<|eot_id|>",
78
+ "lstrip": false,
79
+ "normalized": false,
80
+ "rstrip": false,
81
+ "single_word": false,
82
+ "special": true
83
+ },
84
+ "128010": {
85
+ "content": "<|python_tag|>",
86
+ "lstrip": false,
87
+ "normalized": false,
88
+ "rstrip": false,
89
+ "single_word": false,
90
+ "special": true
91
+ },
92
+ "128011": {
93
+ "content": "<|reserved_special_token_3|>",
94
+ "lstrip": false,
95
+ "normalized": false,
96
+ "rstrip": false,
97
+ "single_word": false,
98
+ "special": true
99
+ },
100
+ "128012": {
101
+ "content": "<|reserved_special_token_4|>",
102
+ "lstrip": false,
103
+ "normalized": false,
104
+ "rstrip": false,
105
+ "single_word": false,
106
+ "special": true
107
+ },
108
+ "128013": {
109
+ "content": "<|reserved_special_token_5|>",
110
+ "lstrip": false,
111
+ "normalized": false,
112
+ "rstrip": false,
113
+ "single_word": false,
114
+ "special": true
115
+ },
116
+ "128014": {
117
+ "content": "<|reserved_special_token_6|>",
118
+ "lstrip": false,
119
+ "normalized": false,
120
+ "rstrip": false,
121
+ "single_word": false,
122
+ "special": true
123
+ },
124
+ "128015": {
125
+ "content": "<|reserved_special_token_7|>",
126
+ "lstrip": false,
127
+ "normalized": false,
128
+ "rstrip": false,
129
+ "single_word": false,
130
+ "special": true
131
+ },
132
+ "128016": {
133
+ "content": "<|reserved_special_token_8|>",
134
+ "lstrip": false,
135
+ "normalized": false,
136
+ "rstrip": false,
137
+ "single_word": false,
138
+ "special": true
139
+ },
140
+ "128017": {
141
+ "content": "<|reserved_special_token_9|>",
142
+ "lstrip": false,
143
+ "normalized": false,
144
+ "rstrip": false,
145
+ "single_word": false,
146
+ "special": true
147
+ },
148
+ "128018": {
149
+ "content": "<|reserved_special_token_10|>",
150
+ "lstrip": false,
151
+ "normalized": false,
152
+ "rstrip": false,
153
+ "single_word": false,
154
+ "special": true
155
+ },
156
+ "128019": {
157
+ "content": "<|reserved_special_token_11|>",
158
+ "lstrip": false,
159
+ "normalized": false,
160
+ "rstrip": false,
161
+ "single_word": false,
162
+ "special": true
163
+ },
164
+ "128020": {
165
+ "content": "<|reserved_special_token_12|>",
166
+ "lstrip": false,
167
+ "normalized": false,
168
+ "rstrip": false,
169
+ "single_word": false,
170
+ "special": true
171
+ },
172
+ "128021": {
173
+ "content": "<|reserved_special_token_13|>",
174
+ "lstrip": false,
175
+ "normalized": false,
176
+ "rstrip": false,
177
+ "single_word": false,
178
+ "special": true
179
+ },
180
+ "128022": {
181
+ "content": "<|reserved_special_token_14|>",
182
+ "lstrip": false,
183
+ "normalized": false,
184
+ "rstrip": false,
185
+ "single_word": false,
186
+ "special": true
187
+ },
188
+ "128023": {
189
+ "content": "<|reserved_special_token_15|>",
190
+ "lstrip": false,
191
+ "normalized": false,
192
+ "rstrip": false,
193
+ "single_word": false,
194
+ "special": true
195
+ },
196
+ "128024": {
197
+ "content": "<|reserved_special_token_16|>",
198
+ "lstrip": false,
199
+ "normalized": false,
200
+ "rstrip": false,
201
+ "single_word": false,
202
+ "special": true
203
+ },
204
+ "128025": {
205
+ "content": "<|reserved_special_token_17|>",
206
+ "lstrip": false,
207
+ "normalized": false,
208
+ "rstrip": false,
209
+ "single_word": false,
210
+ "special": true
211
+ },
212
+ "128026": {
213
+ "content": "<|reserved_special_token_18|>",
214
+ "lstrip": false,
215
+ "normalized": false,
216
+ "rstrip": false,
217
+ "single_word": false,
218
+ "special": true
219
+ },
220
+ "128027": {
221
+ "content": "<|reserved_special_token_19|>",
222
+ "lstrip": false,
223
+ "normalized": false,
224
+ "rstrip": false,
225
+ "single_word": false,
226
+ "special": true
227
+ },
228
+ "128028": {
229
+ "content": "<|reserved_special_token_20|>",
230
+ "lstrip": false,
231
+ "normalized": false,
232
+ "rstrip": false,
233
+ "single_word": false,
234
+ "special": true
235
+ },
236
+ "128029": {
237
+ "content": "<|reserved_special_token_21|>",
238
+ "lstrip": false,
239
+ "normalized": false,
240
+ "rstrip": false,
241
+ "single_word": false,
242
+ "special": true
243
+ },
244
+ "128030": {
245
+ "content": "<|reserved_special_token_22|>",
246
+ "lstrip": false,
247
+ "normalized": false,
248
+ "rstrip": false,
249
+ "single_word": false,
250
+ "special": true
251
+ },
252
+ "128031": {
253
+ "content": "<|reserved_special_token_23|>",
254
+ "lstrip": false,
255
+ "normalized": false,
256
+ "rstrip": false,
257
+ "single_word": false,
258
+ "special": true
259
+ },
260
+ "128032": {
261
+ "content": "<|reserved_special_token_24|>",
262
+ "lstrip": false,
263
+ "normalized": false,
264
+ "rstrip": false,
265
+ "single_word": false,
266
+ "special": true
267
+ },
268
+ "128033": {
269
+ "content": "<|reserved_special_token_25|>",
270
+ "lstrip": false,
271
+ "normalized": false,
272
+ "rstrip": false,
273
+ "single_word": false,
274
+ "special": true
275
+ },
276
+ "128034": {
277
+ "content": "<|reserved_special_token_26|>",
278
+ "lstrip": false,
279
+ "normalized": false,
280
+ "rstrip": false,
281
+ "single_word": false,
282
+ "special": true
283
+ },
284
+ "128035": {
285
+ "content": "<|reserved_special_token_27|>",
286
+ "lstrip": false,
287
+ "normalized": false,
288
+ "rstrip": false,
289
+ "single_word": false,
290
+ "special": true
291
+ },
292
+ "128036": {
293
+ "content": "<|reserved_special_token_28|>",
294
+ "lstrip": false,
295
+ "normalized": false,
296
+ "rstrip": false,
297
+ "single_word": false,
298
+ "special": true
299
+ },
300
+ "128037": {
301
+ "content": "<|reserved_special_token_29|>",
302
+ "lstrip": false,
303
+ "normalized": false,
304
+ "rstrip": false,
305
+ "single_word": false,
306
+ "special": true
307
+ },
308
+ "128038": {
309
+ "content": "<|reserved_special_token_30|>",
310
+ "lstrip": false,
311
+ "normalized": false,
312
+ "rstrip": false,
313
+ "single_word": false,
314
+ "special": true
315
+ },
316
+ "128039": {
317
+ "content": "<|reserved_special_token_31|>",
318
+ "lstrip": false,
319
+ "normalized": false,
320
+ "rstrip": false,
321
+ "single_word": false,
322
+ "special": true
323
+ },
324
+ "128040": {
325
+ "content": "<|reserved_special_token_32|>",
326
+ "lstrip": false,
327
+ "normalized": false,
328
+ "rstrip": false,
329
+ "single_word": false,
330
+ "special": true
331
+ },
332
+ "128041": {
333
+ "content": "<|reserved_special_token_33|>",
334
+ "lstrip": false,
335
+ "normalized": false,
336
+ "rstrip": false,
337
+ "single_word": false,
338
+ "special": true
339
+ },
340
+ "128042": {
341
+ "content": "<|reserved_special_token_34|>",
342
+ "lstrip": false,
343
+ "normalized": false,
344
+ "rstrip": false,
345
+ "single_word": false,
346
+ "special": true
347
+ },
348
+ "128043": {
349
+ "content": "<|reserved_special_token_35|>",
350
+ "lstrip": false,
351
+ "normalized": false,
352
+ "rstrip": false,
353
+ "single_word": false,
354
+ "special": true
355
+ },
356
+ "128044": {
357
+ "content": "<|reserved_special_token_36|>",
358
+ "lstrip": false,
359
+ "normalized": false,
360
+ "rstrip": false,
361
+ "single_word": false,
362
+ "special": true
363
+ },
364
+ "128045": {
365
+ "content": "<|reserved_special_token_37|>",
366
+ "lstrip": false,
367
+ "normalized": false,
368
+ "rstrip": false,
369
+ "single_word": false,
370
+ "special": true
371
+ },
372
+ "128046": {
373
+ "content": "<|reserved_special_token_38|>",
374
+ "lstrip": false,
375
+ "normalized": false,
376
+ "rstrip": false,
377
+ "single_word": false,
378
+ "special": true
379
+ },
380
+ "128047": {
381
+ "content": "<|reserved_special_token_39|>",
382
+ "lstrip": false,
383
+ "normalized": false,
384
+ "rstrip": false,
385
+ "single_word": false,
386
+ "special": true
387
+ },
388
+ "128048": {
389
+ "content": "<|reserved_special_token_40|>",
390
+ "lstrip": false,
391
+ "normalized": false,
392
+ "rstrip": false,
393
+ "single_word": false,
394
+ "special": true
395
+ },
396
+ "128049": {
397
+ "content": "<|reserved_special_token_41|>",
398
+ "lstrip": false,
399
+ "normalized": false,
400
+ "rstrip": false,
401
+ "single_word": false,
402
+ "special": true
403
+ },
404
+ "128050": {
405
+ "content": "<|reserved_special_token_42|>",
406
+ "lstrip": false,
407
+ "normalized": false,
408
+ "rstrip": false,
409
+ "single_word": false,
410
+ "special": true
411
+ },
412
+ "128051": {
413
+ "content": "<|reserved_special_token_43|>",
414
+ "lstrip": false,
415
+ "normalized": false,
416
+ "rstrip": false,
417
+ "single_word": false,
418
+ "special": true
419
+ },
420
+ "128052": {
421
+ "content": "<|reserved_special_token_44|>",
422
+ "lstrip": false,
423
+ "normalized": false,
424
+ "rstrip": false,
425
+ "single_word": false,
426
+ "special": true
427
+ },
428
+ "128053": {
429
+ "content": "<|reserved_special_token_45|>",
430
+ "lstrip": false,
431
+ "normalized": false,
432
+ "rstrip": false,
433
+ "single_word": false,
434
+ "special": true
435
+ },
436
+ "128054": {
437
+ "content": "<|reserved_special_token_46|>",
438
+ "lstrip": false,
439
+ "normalized": false,
440
+ "rstrip": false,
441
+ "single_word": false,
442
+ "special": true
443
+ },
444
+ "128055": {
445
+ "content": "<|reserved_special_token_47|>",
446
+ "lstrip": false,
447
+ "normalized": false,
448
+ "rstrip": false,
449
+ "single_word": false,
450
+ "special": true
451
+ },
452
+ "128056": {
453
+ "content": "<|reserved_special_token_48|>",
454
+ "lstrip": false,
455
+ "normalized": false,
456
+ "rstrip": false,
457
+ "single_word": false,
458
+ "special": true
459
+ },
460
+ "128057": {
461
+ "content": "<|reserved_special_token_49|>",
462
+ "lstrip": false,
463
+ "normalized": false,
464
+ "rstrip": false,
465
+ "single_word": false,
466
+ "special": true
467
+ },
468
+ "128058": {
469
+ "content": "<|reserved_special_token_50|>",
470
+ "lstrip": false,
471
+ "normalized": false,
472
+ "rstrip": false,
473
+ "single_word": false,
474
+ "special": true
475
+ },
476
+ "128059": {
477
+ "content": "<|reserved_special_token_51|>",
478
+ "lstrip": false,
479
+ "normalized": false,
480
+ "rstrip": false,
481
+ "single_word": false,
482
+ "special": true
483
+ },
484
+ "128060": {
485
+ "content": "<|reserved_special_token_52|>",
486
+ "lstrip": false,
487
+ "normalized": false,
488
+ "rstrip": false,
489
+ "single_word": false,
490
+ "special": true
491
+ },
492
+ "128061": {
493
+ "content": "<|reserved_special_token_53|>",
494
+ "lstrip": false,
495
+ "normalized": false,
496
+ "rstrip": false,
497
+ "single_word": false,
498
+ "special": true
499
+ },
500
+ "128062": {
501
+ "content": "<|reserved_special_token_54|>",
502
+ "lstrip": false,
503
+ "normalized": false,
504
+ "rstrip": false,
505
+ "single_word": false,
506
+ "special": true
507
+ },
508
+ "128063": {
509
+ "content": "<|reserved_special_token_55|>",
510
+ "lstrip": false,
511
+ "normalized": false,
512
+ "rstrip": false,
513
+ "single_word": false,
514
+ "special": true
515
+ },
516
+ "128064": {
517
+ "content": "<|reserved_special_token_56|>",
518
+ "lstrip": false,
519
+ "normalized": false,
520
+ "rstrip": false,
521
+ "single_word": false,
522
+ "special": true
523
+ },
524
+ "128065": {
525
+ "content": "<|reserved_special_token_57|>",
526
+ "lstrip": false,
527
+ "normalized": false,
528
+ "rstrip": false,
529
+ "single_word": false,
530
+ "special": true
531
+ },
532
+ "128066": {
533
+ "content": "<|reserved_special_token_58|>",
534
+ "lstrip": false,
535
+ "normalized": false,
536
+ "rstrip": false,
537
+ "single_word": false,
538
+ "special": true
539
+ },
540
+ "128067": {
541
+ "content": "<|reserved_special_token_59|>",
542
+ "lstrip": false,
543
+ "normalized": false,
544
+ "rstrip": false,
545
+ "single_word": false,
546
+ "special": true
547
+ },
548
+ "128068": {
549
+ "content": "<|reserved_special_token_60|>",
550
+ "lstrip": false,
551
+ "normalized": false,
552
+ "rstrip": false,
553
+ "single_word": false,
554
+ "special": true
555
+ },
556
+ "128069": {
557
+ "content": "<|reserved_special_token_61|>",
558
+ "lstrip": false,
559
+ "normalized": false,
560
+ "rstrip": false,
561
+ "single_word": false,
562
+ "special": true
563
+ },
564
+ "128070": {
565
+ "content": "<|reserved_special_token_62|>",
566
+ "lstrip": false,
567
+ "normalized": false,
568
+ "rstrip": false,
569
+ "single_word": false,
570
+ "special": true
571
+ },
572
+ "128071": {
573
+ "content": "<|reserved_special_token_63|>",
574
+ "lstrip": false,
575
+ "normalized": false,
576
+ "rstrip": false,
577
+ "single_word": false,
578
+ "special": true
579
+ },
580
+ "128072": {
581
+ "content": "<|reserved_special_token_64|>",
582
+ "lstrip": false,
583
+ "normalized": false,
584
+ "rstrip": false,
585
+ "single_word": false,
586
+ "special": true
587
+ },
588
+ "128073": {
589
+ "content": "<|reserved_special_token_65|>",
590
+ "lstrip": false,
591
+ "normalized": false,
592
+ "rstrip": false,
593
+ "single_word": false,
594
+ "special": true
595
+ },
596
+ "128074": {
597
+ "content": "<|reserved_special_token_66|>",
598
+ "lstrip": false,
599
+ "normalized": false,
600
+ "rstrip": false,
601
+ "single_word": false,
602
+ "special": true
603
+ },
604
+ "128075": {
605
+ "content": "<|reserved_special_token_67|>",
606
+ "lstrip": false,
607
+ "normalized": false,
608
+ "rstrip": false,
609
+ "single_word": false,
610
+ "special": true
611
+ },
612
+ "128076": {
613
+ "content": "<|reserved_special_token_68|>",
614
+ "lstrip": false,
615
+ "normalized": false,
616
+ "rstrip": false,
617
+ "single_word": false,
618
+ "special": true
619
+ },
620
+ "128077": {
621
+ "content": "<|reserved_special_token_69|>",
622
+ "lstrip": false,
623
+ "normalized": false,
624
+ "rstrip": false,
625
+ "single_word": false,
626
+ "special": true
627
+ },
628
+ "128078": {
629
+ "content": "<|reserved_special_token_70|>",
630
+ "lstrip": false,
631
+ "normalized": false,
632
+ "rstrip": false,
633
+ "single_word": false,
634
+ "special": true
635
+ },
636
+ "128079": {
637
+ "content": "<|reserved_special_token_71|>",
638
+ "lstrip": false,
639
+ "normalized": false,
640
+ "rstrip": false,
641
+ "single_word": false,
642
+ "special": true
643
+ },
644
+ "128080": {
645
+ "content": "<|reserved_special_token_72|>",
646
+ "lstrip": false,
647
+ "normalized": false,
648
+ "rstrip": false,
649
+ "single_word": false,
650
+ "special": true
651
+ },
652
+ "128081": {
653
+ "content": "<|reserved_special_token_73|>",
654
+ "lstrip": false,
655
+ "normalized": false,
656
+ "rstrip": false,
657
+ "single_word": false,
658
+ "special": true
659
+ },
660
+ "128082": {
661
+ "content": "<|reserved_special_token_74|>",
662
+ "lstrip": false,
663
+ "normalized": false,
664
+ "rstrip": false,
665
+ "single_word": false,
666
+ "special": true
667
+ },
668
+ "128083": {
669
+ "content": "<|reserved_special_token_75|>",
670
+ "lstrip": false,
671
+ "normalized": false,
672
+ "rstrip": false,
673
+ "single_word": false,
674
+ "special": true
675
+ },
676
+ "128084": {
677
+ "content": "<|reserved_special_token_76|>",
678
+ "lstrip": false,
679
+ "normalized": false,
680
+ "rstrip": false,
681
+ "single_word": false,
682
+ "special": true
683
+ },
684
+ "128085": {
685
+ "content": "<|reserved_special_token_77|>",
686
+ "lstrip": false,
687
+ "normalized": false,
688
+ "rstrip": false,
689
+ "single_word": false,
690
+ "special": true
691
+ },
692
+ "128086": {
693
+ "content": "<|reserved_special_token_78|>",
694
+ "lstrip": false,
695
+ "normalized": false,
696
+ "rstrip": false,
697
+ "single_word": false,
698
+ "special": true
699
+ },
700
+ "128087": {
701
+ "content": "<|reserved_special_token_79|>",
702
+ "lstrip": false,
703
+ "normalized": false,
704
+ "rstrip": false,
705
+ "single_word": false,
706
+ "special": true
707
+ },
708
+ "128088": {
709
+ "content": "<|reserved_special_token_80|>",
710
+ "lstrip": false,
711
+ "normalized": false,
712
+ "rstrip": false,
713
+ "single_word": false,
714
+ "special": true
715
+ },
716
+ "128089": {
717
+ "content": "<|reserved_special_token_81|>",
718
+ "lstrip": false,
719
+ "normalized": false,
720
+ "rstrip": false,
721
+ "single_word": false,
722
+ "special": true
723
+ },
724
+ "128090": {
725
+ "content": "<|reserved_special_token_82|>",
726
+ "lstrip": false,
727
+ "normalized": false,
728
+ "rstrip": false,
729
+ "single_word": false,
730
+ "special": true
731
+ },
732
+ "128091": {
733
+ "content": "<|reserved_special_token_83|>",
734
+ "lstrip": false,
735
+ "normalized": false,
736
+ "rstrip": false,
737
+ "single_word": false,
738
+ "special": true
739
+ },
740
+ "128092": {
741
+ "content": "<|reserved_special_token_84|>",
742
+ "lstrip": false,
743
+ "normalized": false,
744
+ "rstrip": false,
745
+ "single_word": false,
746
+ "special": true
747
+ },
748
+ "128093": {
749
+ "content": "<|reserved_special_token_85|>",
750
+ "lstrip": false,
751
+ "normalized": false,
752
+ "rstrip": false,
753
+ "single_word": false,
754
+ "special": true
755
+ },
756
+ "128094": {
757
+ "content": "<|reserved_special_token_86|>",
758
+ "lstrip": false,
759
+ "normalized": false,
760
+ "rstrip": false,
761
+ "single_word": false,
762
+ "special": true
763
+ },
764
+ "128095": {
765
+ "content": "<|reserved_special_token_87|>",
766
+ "lstrip": false,
767
+ "normalized": false,
768
+ "rstrip": false,
769
+ "single_word": false,
770
+ "special": true
771
+ },
772
+ "128096": {
773
+ "content": "<|reserved_special_token_88|>",
774
+ "lstrip": false,
775
+ "normalized": false,
776
+ "rstrip": false,
777
+ "single_word": false,
778
+ "special": true
779
+ },
780
+ "128097": {
781
+ "content": "<|reserved_special_token_89|>",
782
+ "lstrip": false,
783
+ "normalized": false,
784
+ "rstrip": false,
785
+ "single_word": false,
786
+ "special": true
787
+ },
788
+ "128098": {
789
+ "content": "<|reserved_special_token_90|>",
790
+ "lstrip": false,
791
+ "normalized": false,
792
+ "rstrip": false,
793
+ "single_word": false,
794
+ "special": true
795
+ },
796
+ "128099": {
797
+ "content": "<|reserved_special_token_91|>",
798
+ "lstrip": false,
799
+ "normalized": false,
800
+ "rstrip": false,
801
+ "single_word": false,
802
+ "special": true
803
+ },
804
+ "128100": {
805
+ "content": "<|reserved_special_token_92|>",
806
+ "lstrip": false,
807
+ "normalized": false,
808
+ "rstrip": false,
809
+ "single_word": false,
810
+ "special": true
811
+ },
812
+ "128101": {
813
+ "content": "<|reserved_special_token_93|>",
814
+ "lstrip": false,
815
+ "normalized": false,
816
+ "rstrip": false,
817
+ "single_word": false,
818
+ "special": true
819
+ },
820
+ "128102": {
821
+ "content": "<|reserved_special_token_94|>",
822
+ "lstrip": false,
823
+ "normalized": false,
824
+ "rstrip": false,
825
+ "single_word": false,
826
+ "special": true
827
+ },
828
+ "128103": {
829
+ "content": "<|reserved_special_token_95|>",
830
+ "lstrip": false,
831
+ "normalized": false,
832
+ "rstrip": false,
833
+ "single_word": false,
834
+ "special": true
835
+ },
836
+ "128104": {
837
+ "content": "<|reserved_special_token_96|>",
838
+ "lstrip": false,
839
+ "normalized": false,
840
+ "rstrip": false,
841
+ "single_word": false,
842
+ "special": true
843
+ },
844
+ "128105": {
845
+ "content": "<|reserved_special_token_97|>",
846
+ "lstrip": false,
847
+ "normalized": false,
848
+ "rstrip": false,
849
+ "single_word": false,
850
+ "special": true
851
+ },
852
+ "128106": {
853
+ "content": "<|reserved_special_token_98|>",
854
+ "lstrip": false,
855
+ "normalized": false,
856
+ "rstrip": false,
857
+ "single_word": false,
858
+ "special": true
859
+ },
860
+ "128107": {
861
+ "content": "<|reserved_special_token_99|>",
862
+ "lstrip": false,
863
+ "normalized": false,
864
+ "rstrip": false,
865
+ "single_word": false,
866
+ "special": true
867
+ },
868
+ "128108": {
869
+ "content": "<|reserved_special_token_100|>",
870
+ "lstrip": false,
871
+ "normalized": false,
872
+ "rstrip": false,
873
+ "single_word": false,
874
+ "special": true
875
+ },
876
+ "128109": {
877
+ "content": "<|reserved_special_token_101|>",
878
+ "lstrip": false,
879
+ "normalized": false,
880
+ "rstrip": false,
881
+ "single_word": false,
882
+ "special": true
883
+ },
884
+ "128110": {
885
+ "content": "<|reserved_special_token_102|>",
886
+ "lstrip": false,
887
+ "normalized": false,
888
+ "rstrip": false,
889
+ "single_word": false,
890
+ "special": true
891
+ },
892
+ "128111": {
893
+ "content": "<|reserved_special_token_103|>",
894
+ "lstrip": false,
895
+ "normalized": false,
896
+ "rstrip": false,
897
+ "single_word": false,
898
+ "special": true
899
+ },
900
+ "128112": {
901
+ "content": "<|reserved_special_token_104|>",
902
+ "lstrip": false,
903
+ "normalized": false,
904
+ "rstrip": false,
905
+ "single_word": false,
906
+ "special": true
907
+ },
908
+ "128113": {
909
+ "content": "<|reserved_special_token_105|>",
910
+ "lstrip": false,
911
+ "normalized": false,
912
+ "rstrip": false,
913
+ "single_word": false,
914
+ "special": true
915
+ },
916
+ "128114": {
917
+ "content": "<|reserved_special_token_106|>",
918
+ "lstrip": false,
919
+ "normalized": false,
920
+ "rstrip": false,
921
+ "single_word": false,
922
+ "special": true
923
+ },
924
+ "128115": {
925
+ "content": "<|reserved_special_token_107|>",
926
+ "lstrip": false,
927
+ "normalized": false,
928
+ "rstrip": false,
929
+ "single_word": false,
930
+ "special": true
931
+ },
932
+ "128116": {
933
+ "content": "<|reserved_special_token_108|>",
934
+ "lstrip": false,
935
+ "normalized": false,
936
+ "rstrip": false,
937
+ "single_word": false,
938
+ "special": true
939
+ },
940
+ "128117": {
941
+ "content": "<|reserved_special_token_109|>",
942
+ "lstrip": false,
943
+ "normalized": false,
944
+ "rstrip": false,
945
+ "single_word": false,
946
+ "special": true
947
+ },
948
+ "128118": {
949
+ "content": "<|reserved_special_token_110|>",
950
+ "lstrip": false,
951
+ "normalized": false,
952
+ "rstrip": false,
953
+ "single_word": false,
954
+ "special": true
955
+ },
956
+ "128119": {
957
+ "content": "<|reserved_special_token_111|>",
958
+ "lstrip": false,
959
+ "normalized": false,
960
+ "rstrip": false,
961
+ "single_word": false,
962
+ "special": true
963
+ },
964
+ "128120": {
965
+ "content": "<|reserved_special_token_112|>",
966
+ "lstrip": false,
967
+ "normalized": false,
968
+ "rstrip": false,
969
+ "single_word": false,
970
+ "special": true
971
+ },
972
+ "128121": {
973
+ "content": "<|reserved_special_token_113|>",
974
+ "lstrip": false,
975
+ "normalized": false,
976
+ "rstrip": false,
977
+ "single_word": false,
978
+ "special": true
979
+ },
980
+ "128122": {
981
+ "content": "<|reserved_special_token_114|>",
982
+ "lstrip": false,
983
+ "normalized": false,
984
+ "rstrip": false,
985
+ "single_word": false,
986
+ "special": true
987
+ },
988
+ "128123": {
989
+ "content": "<|reserved_special_token_115|>",
990
+ "lstrip": false,
991
+ "normalized": false,
992
+ "rstrip": false,
993
+ "single_word": false,
994
+ "special": true
995
+ },
996
+ "128124": {
997
+ "content": "<|reserved_special_token_116|>",
998
+ "lstrip": false,
999
+ "normalized": false,
1000
+ "rstrip": false,
1001
+ "single_word": false,
1002
+ "special": true
1003
+ },
1004
+ "128125": {
1005
+ "content": "<|reserved_special_token_117|>",
1006
+ "lstrip": false,
1007
+ "normalized": false,
1008
+ "rstrip": false,
1009
+ "single_word": false,
1010
+ "special": true
1011
+ },
1012
+ "128126": {
1013
+ "content": "<|reserved_special_token_118|>",
1014
+ "lstrip": false,
1015
+ "normalized": false,
1016
+ "rstrip": false,
1017
+ "single_word": false,
1018
+ "special": true
1019
+ },
1020
+ "128127": {
1021
+ "content": "<|reserved_special_token_119|>",
1022
+ "lstrip": false,
1023
+ "normalized": false,
1024
+ "rstrip": false,
1025
+ "single_word": false,
1026
+ "special": true
1027
+ },
1028
+ "128128": {
1029
+ "content": "<|reserved_special_token_120|>",
1030
+ "lstrip": false,
1031
+ "normalized": false,
1032
+ "rstrip": false,
1033
+ "single_word": false,
1034
+ "special": true
1035
+ },
1036
+ "128129": {
1037
+ "content": "<|reserved_special_token_121|>",
1038
+ "lstrip": false,
1039
+ "normalized": false,
1040
+ "rstrip": false,
1041
+ "single_word": false,
1042
+ "special": true
1043
+ },
1044
+ "128130": {
1045
+ "content": "<|reserved_special_token_122|>",
1046
+ "lstrip": false,
1047
+ "normalized": false,
1048
+ "rstrip": false,
1049
+ "single_word": false,
1050
+ "special": true
1051
+ },
1052
+ "128131": {
1053
+ "content": "<|reserved_special_token_123|>",
1054
+ "lstrip": false,
1055
+ "normalized": false,
1056
+ "rstrip": false,
1057
+ "single_word": false,
1058
+ "special": true
1059
+ },
1060
+ "128132": {
1061
+ "content": "<|reserved_special_token_124|>",
1062
+ "lstrip": false,
1063
+ "normalized": false,
1064
+ "rstrip": false,
1065
+ "single_word": false,
1066
+ "special": true
1067
+ },
1068
+ "128133": {
1069
+ "content": "<|reserved_special_token_125|>",
1070
+ "lstrip": false,
1071
+ "normalized": false,
1072
+ "rstrip": false,
1073
+ "single_word": false,
1074
+ "special": true
1075
+ },
1076
+ "128134": {
1077
+ "content": "<|reserved_special_token_126|>",
1078
+ "lstrip": false,
1079
+ "normalized": false,
1080
+ "rstrip": false,
1081
+ "single_word": false,
1082
+ "special": true
1083
+ },
1084
+ "128135": {
1085
+ "content": "<|reserved_special_token_127|>",
1086
+ "lstrip": false,
1087
+ "normalized": false,
1088
+ "rstrip": false,
1089
+ "single_word": false,
1090
+ "special": true
1091
+ },
1092
+ "128136": {
1093
+ "content": "<|reserved_special_token_128|>",
1094
+ "lstrip": false,
1095
+ "normalized": false,
1096
+ "rstrip": false,
1097
+ "single_word": false,
1098
+ "special": true
1099
+ },
1100
+ "128137": {
1101
+ "content": "<|reserved_special_token_129|>",
1102
+ "lstrip": false,
1103
+ "normalized": false,
1104
+ "rstrip": false,
1105
+ "single_word": false,
1106
+ "special": true
1107
+ },
1108
+ "128138": {
1109
+ "content": "<|reserved_special_token_130|>",
1110
+ "lstrip": false,
1111
+ "normalized": false,
1112
+ "rstrip": false,
1113
+ "single_word": false,
1114
+ "special": true
1115
+ },
1116
+ "128139": {
1117
+ "content": "<|reserved_special_token_131|>",
1118
+ "lstrip": false,
1119
+ "normalized": false,
1120
+ "rstrip": false,
1121
+ "single_word": false,
1122
+ "special": true
1123
+ },
1124
+ "128140": {
1125
+ "content": "<|reserved_special_token_132|>",
1126
+ "lstrip": false,
1127
+ "normalized": false,
1128
+ "rstrip": false,
1129
+ "single_word": false,
1130
+ "special": true
1131
+ },
1132
+ "128141": {
1133
+ "content": "<|reserved_special_token_133|>",
1134
+ "lstrip": false,
1135
+ "normalized": false,
1136
+ "rstrip": false,
1137
+ "single_word": false,
1138
+ "special": true
1139
+ },
1140
+ "128142": {
1141
+ "content": "<|reserved_special_token_134|>",
1142
+ "lstrip": false,
1143
+ "normalized": false,
1144
+ "rstrip": false,
1145
+ "single_word": false,
1146
+ "special": true
1147
+ },
1148
+ "128143": {
1149
+ "content": "<|reserved_special_token_135|>",
1150
+ "lstrip": false,
1151
+ "normalized": false,
1152
+ "rstrip": false,
1153
+ "single_word": false,
1154
+ "special": true
1155
+ },
1156
+ "128144": {
1157
+ "content": "<|reserved_special_token_136|>",
1158
+ "lstrip": false,
1159
+ "normalized": false,
1160
+ "rstrip": false,
1161
+ "single_word": false,
1162
+ "special": true
1163
+ },
1164
+ "128145": {
1165
+ "content": "<|reserved_special_token_137|>",
1166
+ "lstrip": false,
1167
+ "normalized": false,
1168
+ "rstrip": false,
1169
+ "single_word": false,
1170
+ "special": true
1171
+ },
1172
+ "128146": {
1173
+ "content": "<|reserved_special_token_138|>",
1174
+ "lstrip": false,
1175
+ "normalized": false,
1176
+ "rstrip": false,
1177
+ "single_word": false,
1178
+ "special": true
1179
+ },
1180
+ "128147": {
1181
+ "content": "<|reserved_special_token_139|>",
1182
+ "lstrip": false,
1183
+ "normalized": false,
1184
+ "rstrip": false,
1185
+ "single_word": false,
1186
+ "special": true
1187
+ },
1188
+ "128148": {
1189
+ "content": "<|reserved_special_token_140|>",
1190
+ "lstrip": false,
1191
+ "normalized": false,
1192
+ "rstrip": false,
1193
+ "single_word": false,
1194
+ "special": true
1195
+ },
1196
+ "128149": {
1197
+ "content": "<|reserved_special_token_141|>",
1198
+ "lstrip": false,
1199
+ "normalized": false,
1200
+ "rstrip": false,
1201
+ "single_word": false,
1202
+ "special": true
1203
+ },
1204
+ "128150": {
1205
+ "content": "<|reserved_special_token_142|>",
1206
+ "lstrip": false,
1207
+ "normalized": false,
1208
+ "rstrip": false,
1209
+ "single_word": false,
1210
+ "special": true
1211
+ },
1212
+ "128151": {
1213
+ "content": "<|reserved_special_token_143|>",
1214
+ "lstrip": false,
1215
+ "normalized": false,
1216
+ "rstrip": false,
1217
+ "single_word": false,
1218
+ "special": true
1219
+ },
1220
+ "128152": {
1221
+ "content": "<|reserved_special_token_144|>",
1222
+ "lstrip": false,
1223
+ "normalized": false,
1224
+ "rstrip": false,
1225
+ "single_word": false,
1226
+ "special": true
1227
+ },
1228
+ "128153": {
1229
+ "content": "<|reserved_special_token_145|>",
1230
+ "lstrip": false,
1231
+ "normalized": false,
1232
+ "rstrip": false,
1233
+ "single_word": false,
1234
+ "special": true
1235
+ },
1236
+ "128154": {
1237
+ "content": "<|reserved_special_token_146|>",
1238
+ "lstrip": false,
1239
+ "normalized": false,
1240
+ "rstrip": false,
1241
+ "single_word": false,
1242
+ "special": true
1243
+ },
1244
+ "128155": {
1245
+ "content": "<|reserved_special_token_147|>",
1246
+ "lstrip": false,
1247
+ "normalized": false,
1248
+ "rstrip": false,
1249
+ "single_word": false,
1250
+ "special": true
1251
+ },
1252
+ "128156": {
1253
+ "content": "<|reserved_special_token_148|>",
1254
+ "lstrip": false,
1255
+ "normalized": false,
1256
+ "rstrip": false,
1257
+ "single_word": false,
1258
+ "special": true
1259
+ },
1260
+ "128157": {
1261
+ "content": "<|reserved_special_token_149|>",
1262
+ "lstrip": false,
1263
+ "normalized": false,
1264
+ "rstrip": false,
1265
+ "single_word": false,
1266
+ "special": true
1267
+ },
1268
+ "128158": {
1269
+ "content": "<|reserved_special_token_150|>",
1270
+ "lstrip": false,
1271
+ "normalized": false,
1272
+ "rstrip": false,
1273
+ "single_word": false,
1274
+ "special": true
1275
+ },
1276
+ "128159": {
1277
+ "content": "<|reserved_special_token_151|>",
1278
+ "lstrip": false,
1279
+ "normalized": false,
1280
+ "rstrip": false,
1281
+ "single_word": false,
1282
+ "special": true
1283
+ },
1284
+ "128160": {
1285
+ "content": "<|reserved_special_token_152|>",
1286
+ "lstrip": false,
1287
+ "normalized": false,
1288
+ "rstrip": false,
1289
+ "single_word": false,
1290
+ "special": true
1291
+ },
1292
+ "128161": {
1293
+ "content": "<|reserved_special_token_153|>",
1294
+ "lstrip": false,
1295
+ "normalized": false,
1296
+ "rstrip": false,
1297
+ "single_word": false,
1298
+ "special": true
1299
+ },
1300
+ "128162": {
1301
+ "content": "<|reserved_special_token_154|>",
1302
+ "lstrip": false,
1303
+ "normalized": false,
1304
+ "rstrip": false,
1305
+ "single_word": false,
1306
+ "special": true
1307
+ },
1308
+ "128163": {
1309
+ "content": "<|reserved_special_token_155|>",
1310
+ "lstrip": false,
1311
+ "normalized": false,
1312
+ "rstrip": false,
1313
+ "single_word": false,
1314
+ "special": true
1315
+ },
1316
+ "128164": {
1317
+ "content": "<|reserved_special_token_156|>",
1318
+ "lstrip": false,
1319
+ "normalized": false,
1320
+ "rstrip": false,
1321
+ "single_word": false,
1322
+ "special": true
1323
+ },
1324
+ "128165": {
1325
+ "content": "<|reserved_special_token_157|>",
1326
+ "lstrip": false,
1327
+ "normalized": false,
1328
+ "rstrip": false,
1329
+ "single_word": false,
1330
+ "special": true
1331
+ },
1332
+ "128166": {
1333
+ "content": "<|reserved_special_token_158|>",
1334
+ "lstrip": false,
1335
+ "normalized": false,
1336
+ "rstrip": false,
1337
+ "single_word": false,
1338
+ "special": true
1339
+ },
1340
+ "128167": {
1341
+ "content": "<|reserved_special_token_159|>",
1342
+ "lstrip": false,
1343
+ "normalized": false,
1344
+ "rstrip": false,
1345
+ "single_word": false,
1346
+ "special": true
1347
+ },
1348
+ "128168": {
1349
+ "content": "<|reserved_special_token_160|>",
1350
+ "lstrip": false,
1351
+ "normalized": false,
1352
+ "rstrip": false,
1353
+ "single_word": false,
1354
+ "special": true
1355
+ },
1356
+ "128169": {
1357
+ "content": "<|reserved_special_token_161|>",
1358
+ "lstrip": false,
1359
+ "normalized": false,
1360
+ "rstrip": false,
1361
+ "single_word": false,
1362
+ "special": true
1363
+ },
1364
+ "128170": {
1365
+ "content": "<|reserved_special_token_162|>",
1366
+ "lstrip": false,
1367
+ "normalized": false,
1368
+ "rstrip": false,
1369
+ "single_word": false,
1370
+ "special": true
1371
+ },
1372
+ "128171": {
1373
+ "content": "<|reserved_special_token_163|>",
1374
+ "lstrip": false,
1375
+ "normalized": false,
1376
+ "rstrip": false,
1377
+ "single_word": false,
1378
+ "special": true
1379
+ },
1380
+ "128172": {
1381
+ "content": "<|reserved_special_token_164|>",
1382
+ "lstrip": false,
1383
+ "normalized": false,
1384
+ "rstrip": false,
1385
+ "single_word": false,
1386
+ "special": true
1387
+ },
1388
+ "128173": {
1389
+ "content": "<|reserved_special_token_165|>",
1390
+ "lstrip": false,
1391
+ "normalized": false,
1392
+ "rstrip": false,
1393
+ "single_word": false,
1394
+ "special": true
1395
+ },
1396
+ "128174": {
1397
+ "content": "<|reserved_special_token_166|>",
1398
+ "lstrip": false,
1399
+ "normalized": false,
1400
+ "rstrip": false,
1401
+ "single_word": false,
1402
+ "special": true
1403
+ },
1404
+ "128175": {
1405
+ "content": "<|reserved_special_token_167|>",
1406
+ "lstrip": false,
1407
+ "normalized": false,
1408
+ "rstrip": false,
1409
+ "single_word": false,
1410
+ "special": true
1411
+ },
1412
+ "128176": {
1413
+ "content": "<|reserved_special_token_168|>",
1414
+ "lstrip": false,
1415
+ "normalized": false,
1416
+ "rstrip": false,
1417
+ "single_word": false,
1418
+ "special": true
1419
+ },
1420
+ "128177": {
1421
+ "content": "<|reserved_special_token_169|>",
1422
+ "lstrip": false,
1423
+ "normalized": false,
1424
+ "rstrip": false,
1425
+ "single_word": false,
1426
+ "special": true
1427
+ },
1428
+ "128178": {
1429
+ "content": "<|reserved_special_token_170|>",
1430
+ "lstrip": false,
1431
+ "normalized": false,
1432
+ "rstrip": false,
1433
+ "single_word": false,
1434
+ "special": true
1435
+ },
1436
+ "128179": {
1437
+ "content": "<|reserved_special_token_171|>",
1438
+ "lstrip": false,
1439
+ "normalized": false,
1440
+ "rstrip": false,
1441
+ "single_word": false,
1442
+ "special": true
1443
+ },
1444
+ "128180": {
1445
+ "content": "<|reserved_special_token_172|>",
1446
+ "lstrip": false,
1447
+ "normalized": false,
1448
+ "rstrip": false,
1449
+ "single_word": false,
1450
+ "special": true
1451
+ },
1452
+ "128181": {
1453
+ "content": "<|reserved_special_token_173|>",
1454
+ "lstrip": false,
1455
+ "normalized": false,
1456
+ "rstrip": false,
1457
+ "single_word": false,
1458
+ "special": true
1459
+ },
1460
+ "128182": {
1461
+ "content": "<|reserved_special_token_174|>",
1462
+ "lstrip": false,
1463
+ "normalized": false,
1464
+ "rstrip": false,
1465
+ "single_word": false,
1466
+ "special": true
1467
+ },
1468
+ "128183": {
1469
+ "content": "<|reserved_special_token_175|>",
1470
+ "lstrip": false,
1471
+ "normalized": false,
1472
+ "rstrip": false,
1473
+ "single_word": false,
1474
+ "special": true
1475
+ },
1476
+ "128184": {
1477
+ "content": "<|reserved_special_token_176|>",
1478
+ "lstrip": false,
1479
+ "normalized": false,
1480
+ "rstrip": false,
1481
+ "single_word": false,
1482
+ "special": true
1483
+ },
1484
+ "128185": {
1485
+ "content": "<|reserved_special_token_177|>",
1486
+ "lstrip": false,
1487
+ "normalized": false,
1488
+ "rstrip": false,
1489
+ "single_word": false,
1490
+ "special": true
1491
+ },
1492
+ "128186": {
1493
+ "content": "<|reserved_special_token_178|>",
1494
+ "lstrip": false,
1495
+ "normalized": false,
1496
+ "rstrip": false,
1497
+ "single_word": false,
1498
+ "special": true
1499
+ },
1500
+ "128187": {
1501
+ "content": "<|reserved_special_token_179|>",
1502
+ "lstrip": false,
1503
+ "normalized": false,
1504
+ "rstrip": false,
1505
+ "single_word": false,
1506
+ "special": true
1507
+ },
1508
+ "128188": {
1509
+ "content": "<|reserved_special_token_180|>",
1510
+ "lstrip": false,
1511
+ "normalized": false,
1512
+ "rstrip": false,
1513
+ "single_word": false,
1514
+ "special": true
1515
+ },
1516
+ "128189": {
1517
+ "content": "<|reserved_special_token_181|>",
1518
+ "lstrip": false,
1519
+ "normalized": false,
1520
+ "rstrip": false,
1521
+ "single_word": false,
1522
+ "special": true
1523
+ },
1524
+ "128190": {
1525
+ "content": "<|reserved_special_token_182|>",
1526
+ "lstrip": false,
1527
+ "normalized": false,
1528
+ "rstrip": false,
1529
+ "single_word": false,
1530
+ "special": true
1531
+ },
1532
+ "128191": {
1533
+ "content": "<|reserved_special_token_183|>",
1534
+ "lstrip": false,
1535
+ "normalized": false,
1536
+ "rstrip": false,
1537
+ "single_word": false,
1538
+ "special": true
1539
+ },
1540
+ "128192": {
1541
+ "content": "<|reserved_special_token_184|>",
1542
+ "lstrip": false,
1543
+ "normalized": false,
1544
+ "rstrip": false,
1545
+ "single_word": false,
1546
+ "special": true
1547
+ },
1548
+ "128193": {
1549
+ "content": "<|reserved_special_token_185|>",
1550
+ "lstrip": false,
1551
+ "normalized": false,
1552
+ "rstrip": false,
1553
+ "single_word": false,
1554
+ "special": true
1555
+ },
1556
+ "128194": {
1557
+ "content": "<|reserved_special_token_186|>",
1558
+ "lstrip": false,
1559
+ "normalized": false,
1560
+ "rstrip": false,
1561
+ "single_word": false,
1562
+ "special": true
1563
+ },
1564
+ "128195": {
1565
+ "content": "<|reserved_special_token_187|>",
1566
+ "lstrip": false,
1567
+ "normalized": false,
1568
+ "rstrip": false,
1569
+ "single_word": false,
1570
+ "special": true
1571
+ },
1572
+ "128196": {
1573
+ "content": "<|reserved_special_token_188|>",
1574
+ "lstrip": false,
1575
+ "normalized": false,
1576
+ "rstrip": false,
1577
+ "single_word": false,
1578
+ "special": true
1579
+ },
1580
+ "128197": {
1581
+ "content": "<|reserved_special_token_189|>",
1582
+ "lstrip": false,
1583
+ "normalized": false,
1584
+ "rstrip": false,
1585
+ "single_word": false,
1586
+ "special": true
1587
+ },
1588
+ "128198": {
1589
+ "content": "<|reserved_special_token_190|>",
1590
+ "lstrip": false,
1591
+ "normalized": false,
1592
+ "rstrip": false,
1593
+ "single_word": false,
1594
+ "special": true
1595
+ },
1596
+ "128199": {
1597
+ "content": "<|reserved_special_token_191|>",
1598
+ "lstrip": false,
1599
+ "normalized": false,
1600
+ "rstrip": false,
1601
+ "single_word": false,
1602
+ "special": true
1603
+ },
1604
+ "128200": {
1605
+ "content": "<|reserved_special_token_192|>",
1606
+ "lstrip": false,
1607
+ "normalized": false,
1608
+ "rstrip": false,
1609
+ "single_word": false,
1610
+ "special": true
1611
+ },
1612
+ "128201": {
1613
+ "content": "<|reserved_special_token_193|>",
1614
+ "lstrip": false,
1615
+ "normalized": false,
1616
+ "rstrip": false,
1617
+ "single_word": false,
1618
+ "special": true
1619
+ },
1620
+ "128202": {
1621
+ "content": "<|reserved_special_token_194|>",
1622
+ "lstrip": false,
1623
+ "normalized": false,
1624
+ "rstrip": false,
1625
+ "single_word": false,
1626
+ "special": true
1627
+ },
1628
+ "128203": {
1629
+ "content": "<|reserved_special_token_195|>",
1630
+ "lstrip": false,
1631
+ "normalized": false,
1632
+ "rstrip": false,
1633
+ "single_word": false,
1634
+ "special": true
1635
+ },
1636
+ "128204": {
1637
+ "content": "<|reserved_special_token_196|>",
1638
+ "lstrip": false,
1639
+ "normalized": false,
1640
+ "rstrip": false,
1641
+ "single_word": false,
1642
+ "special": true
1643
+ },
1644
+ "128205": {
1645
+ "content": "<|reserved_special_token_197|>",
1646
+ "lstrip": false,
1647
+ "normalized": false,
1648
+ "rstrip": false,
1649
+ "single_word": false,
1650
+ "special": true
1651
+ },
1652
+ "128206": {
1653
+ "content": "<|reserved_special_token_198|>",
1654
+ "lstrip": false,
1655
+ "normalized": false,
1656
+ "rstrip": false,
1657
+ "single_word": false,
1658
+ "special": true
1659
+ },
1660
+ "128207": {
1661
+ "content": "<|reserved_special_token_199|>",
1662
+ "lstrip": false,
1663
+ "normalized": false,
1664
+ "rstrip": false,
1665
+ "single_word": false,
1666
+ "special": true
1667
+ },
1668
+ "128208": {
1669
+ "content": "<|reserved_special_token_200|>",
1670
+ "lstrip": false,
1671
+ "normalized": false,
1672
+ "rstrip": false,
1673
+ "single_word": false,
1674
+ "special": true
1675
+ },
1676
+ "128209": {
1677
+ "content": "<|reserved_special_token_201|>",
1678
+ "lstrip": false,
1679
+ "normalized": false,
1680
+ "rstrip": false,
1681
+ "single_word": false,
1682
+ "special": true
1683
+ },
1684
+ "128210": {
1685
+ "content": "<|reserved_special_token_202|>",
1686
+ "lstrip": false,
1687
+ "normalized": false,
1688
+ "rstrip": false,
1689
+ "single_word": false,
1690
+ "special": true
1691
+ },
1692
+ "128211": {
1693
+ "content": "<|reserved_special_token_203|>",
1694
+ "lstrip": false,
1695
+ "normalized": false,
1696
+ "rstrip": false,
1697
+ "single_word": false,
1698
+ "special": true
1699
+ },
1700
+ "128212": {
1701
+ "content": "<|reserved_special_token_204|>",
1702
+ "lstrip": false,
1703
+ "normalized": false,
1704
+ "rstrip": false,
1705
+ "single_word": false,
1706
+ "special": true
1707
+ },
1708
+ "128213": {
1709
+ "content": "<|reserved_special_token_205|>",
1710
+ "lstrip": false,
1711
+ "normalized": false,
1712
+ "rstrip": false,
1713
+ "single_word": false,
1714
+ "special": true
1715
+ },
1716
+ "128214": {
1717
+ "content": "<|reserved_special_token_206|>",
1718
+ "lstrip": false,
1719
+ "normalized": false,
1720
+ "rstrip": false,
1721
+ "single_word": false,
1722
+ "special": true
1723
+ },
1724
+ "128215": {
1725
+ "content": "<|reserved_special_token_207|>",
1726
+ "lstrip": false,
1727
+ "normalized": false,
1728
+ "rstrip": false,
1729
+ "single_word": false,
1730
+ "special": true
1731
+ },
1732
+ "128216": {
1733
+ "content": "<|reserved_special_token_208|>",
1734
+ "lstrip": false,
1735
+ "normalized": false,
1736
+ "rstrip": false,
1737
+ "single_word": false,
1738
+ "special": true
1739
+ },
1740
+ "128217": {
1741
+ "content": "<|reserved_special_token_209|>",
1742
+ "lstrip": false,
1743
+ "normalized": false,
1744
+ "rstrip": false,
1745
+ "single_word": false,
1746
+ "special": true
1747
+ },
1748
+ "128218": {
1749
+ "content": "<|reserved_special_token_210|>",
1750
+ "lstrip": false,
1751
+ "normalized": false,
1752
+ "rstrip": false,
1753
+ "single_word": false,
1754
+ "special": true
1755
+ },
1756
+ "128219": {
1757
+ "content": "<|reserved_special_token_211|>",
1758
+ "lstrip": false,
1759
+ "normalized": false,
1760
+ "rstrip": false,
1761
+ "single_word": false,
1762
+ "special": true
1763
+ },
1764
+ "128220": {
1765
+ "content": "<|reserved_special_token_212|>",
1766
+ "lstrip": false,
1767
+ "normalized": false,
1768
+ "rstrip": false,
1769
+ "single_word": false,
1770
+ "special": true
1771
+ },
1772
+ "128221": {
1773
+ "content": "<|reserved_special_token_213|>",
1774
+ "lstrip": false,
1775
+ "normalized": false,
1776
+ "rstrip": false,
1777
+ "single_word": false,
1778
+ "special": true
1779
+ },
1780
+ "128222": {
1781
+ "content": "<|reserved_special_token_214|>",
1782
+ "lstrip": false,
1783
+ "normalized": false,
1784
+ "rstrip": false,
1785
+ "single_word": false,
1786
+ "special": true
1787
+ },
1788
+ "128223": {
1789
+ "content": "<|reserved_special_token_215|>",
1790
+ "lstrip": false,
1791
+ "normalized": false,
1792
+ "rstrip": false,
1793
+ "single_word": false,
1794
+ "special": true
1795
+ },
1796
+ "128224": {
1797
+ "content": "<|reserved_special_token_216|>",
1798
+ "lstrip": false,
1799
+ "normalized": false,
1800
+ "rstrip": false,
1801
+ "single_word": false,
1802
+ "special": true
1803
+ },
1804
+ "128225": {
1805
+ "content": "<|reserved_special_token_217|>",
1806
+ "lstrip": false,
1807
+ "normalized": false,
1808
+ "rstrip": false,
1809
+ "single_word": false,
1810
+ "special": true
1811
+ },
1812
+ "128226": {
1813
+ "content": "<|reserved_special_token_218|>",
1814
+ "lstrip": false,
1815
+ "normalized": false,
1816
+ "rstrip": false,
1817
+ "single_word": false,
1818
+ "special": true
1819
+ },
1820
+ "128227": {
1821
+ "content": "<|reserved_special_token_219|>",
1822
+ "lstrip": false,
1823
+ "normalized": false,
1824
+ "rstrip": false,
1825
+ "single_word": false,
1826
+ "special": true
1827
+ },
1828
+ "128228": {
1829
+ "content": "<|reserved_special_token_220|>",
1830
+ "lstrip": false,
1831
+ "normalized": false,
1832
+ "rstrip": false,
1833
+ "single_word": false,
1834
+ "special": true
1835
+ },
1836
+ "128229": {
1837
+ "content": "<|reserved_special_token_221|>",
1838
+ "lstrip": false,
1839
+ "normalized": false,
1840
+ "rstrip": false,
1841
+ "single_word": false,
1842
+ "special": true
1843
+ },
1844
+ "128230": {
1845
+ "content": "<|reserved_special_token_222|>",
1846
+ "lstrip": false,
1847
+ "normalized": false,
1848
+ "rstrip": false,
1849
+ "single_word": false,
1850
+ "special": true
1851
+ },
1852
+ "128231": {
1853
+ "content": "<|reserved_special_token_223|>",
1854
+ "lstrip": false,
1855
+ "normalized": false,
1856
+ "rstrip": false,
1857
+ "single_word": false,
1858
+ "special": true
1859
+ },
1860
+ "128232": {
1861
+ "content": "<|reserved_special_token_224|>",
1862
+ "lstrip": false,
1863
+ "normalized": false,
1864
+ "rstrip": false,
1865
+ "single_word": false,
1866
+ "special": true
1867
+ },
1868
+ "128233": {
1869
+ "content": "<|reserved_special_token_225|>",
1870
+ "lstrip": false,
1871
+ "normalized": false,
1872
+ "rstrip": false,
1873
+ "single_word": false,
1874
+ "special": true
1875
+ },
1876
+ "128234": {
1877
+ "content": "<|reserved_special_token_226|>",
1878
+ "lstrip": false,
1879
+ "normalized": false,
1880
+ "rstrip": false,
1881
+ "single_word": false,
1882
+ "special": true
1883
+ },
1884
+ "128235": {
1885
+ "content": "<|reserved_special_token_227|>",
1886
+ "lstrip": false,
1887
+ "normalized": false,
1888
+ "rstrip": false,
1889
+ "single_word": false,
1890
+ "special": true
1891
+ },
1892
+ "128236": {
1893
+ "content": "<|reserved_special_token_228|>",
1894
+ "lstrip": false,
1895
+ "normalized": false,
1896
+ "rstrip": false,
1897
+ "single_word": false,
1898
+ "special": true
1899
+ },
1900
+ "128237": {
1901
+ "content": "<|reserved_special_token_229|>",
1902
+ "lstrip": false,
1903
+ "normalized": false,
1904
+ "rstrip": false,
1905
+ "single_word": false,
1906
+ "special": true
1907
+ },
1908
+ "128238": {
1909
+ "content": "<|reserved_special_token_230|>",
1910
+ "lstrip": false,
1911
+ "normalized": false,
1912
+ "rstrip": false,
1913
+ "single_word": false,
1914
+ "special": true
1915
+ },
1916
+ "128239": {
1917
+ "content": "<|reserved_special_token_231|>",
1918
+ "lstrip": false,
1919
+ "normalized": false,
1920
+ "rstrip": false,
1921
+ "single_word": false,
1922
+ "special": true
1923
+ },
1924
+ "128240": {
1925
+ "content": "<|reserved_special_token_232|>",
1926
+ "lstrip": false,
1927
+ "normalized": false,
1928
+ "rstrip": false,
1929
+ "single_word": false,
1930
+ "special": true
1931
+ },
1932
+ "128241": {
1933
+ "content": "<|reserved_special_token_233|>",
1934
+ "lstrip": false,
1935
+ "normalized": false,
1936
+ "rstrip": false,
1937
+ "single_word": false,
1938
+ "special": true
1939
+ },
1940
+ "128242": {
1941
+ "content": "<|reserved_special_token_234|>",
1942
+ "lstrip": false,
1943
+ "normalized": false,
1944
+ "rstrip": false,
1945
+ "single_word": false,
1946
+ "special": true
1947
+ },
1948
+ "128243": {
1949
+ "content": "<|reserved_special_token_235|>",
1950
+ "lstrip": false,
1951
+ "normalized": false,
1952
+ "rstrip": false,
1953
+ "single_word": false,
1954
+ "special": true
1955
+ },
1956
+ "128244": {
1957
+ "content": "<|reserved_special_token_236|>",
1958
+ "lstrip": false,
1959
+ "normalized": false,
1960
+ "rstrip": false,
1961
+ "single_word": false,
1962
+ "special": true
1963
+ },
1964
+ "128245": {
1965
+ "content": "<|reserved_special_token_237|>",
1966
+ "lstrip": false,
1967
+ "normalized": false,
1968
+ "rstrip": false,
1969
+ "single_word": false,
1970
+ "special": true
1971
+ },
1972
+ "128246": {
1973
+ "content": "<|reserved_special_token_238|>",
1974
+ "lstrip": false,
1975
+ "normalized": false,
1976
+ "rstrip": false,
1977
+ "single_word": false,
1978
+ "special": true
1979
+ },
1980
+ "128247": {
1981
+ "content": "<|reserved_special_token_239|>",
1982
+ "lstrip": false,
1983
+ "normalized": false,
1984
+ "rstrip": false,
1985
+ "single_word": false,
1986
+ "special": true
1987
+ },
1988
+ "128248": {
1989
+ "content": "<|reserved_special_token_240|>",
1990
+ "lstrip": false,
1991
+ "normalized": false,
1992
+ "rstrip": false,
1993
+ "single_word": false,
1994
+ "special": true
1995
+ },
1996
+ "128249": {
1997
+ "content": "<|reserved_special_token_241|>",
1998
+ "lstrip": false,
1999
+ "normalized": false,
2000
+ "rstrip": false,
2001
+ "single_word": false,
2002
+ "special": true
2003
+ },
2004
+ "128250": {
2005
+ "content": "<|reserved_special_token_242|>",
2006
+ "lstrip": false,
2007
+ "normalized": false,
2008
+ "rstrip": false,
2009
+ "single_word": false,
2010
+ "special": true
2011
+ },
2012
+ "128251": {
2013
+ "content": "<|reserved_special_token_243|>",
2014
+ "lstrip": false,
2015
+ "normalized": false,
2016
+ "rstrip": false,
2017
+ "single_word": false,
2018
+ "special": true
2019
+ },
2020
+ "128252": {
2021
+ "content": "<|reserved_special_token_244|>",
2022
+ "lstrip": false,
2023
+ "normalized": false,
2024
+ "rstrip": false,
2025
+ "single_word": false,
2026
+ "special": true
2027
+ },
2028
+ "128253": {
2029
+ "content": "<|reserved_special_token_245|>",
2030
+ "lstrip": false,
2031
+ "normalized": false,
2032
+ "rstrip": false,
2033
+ "single_word": false,
2034
+ "special": true
2035
+ },
2036
+ "128254": {
2037
+ "content": "<|reserved_special_token_246|>",
2038
+ "lstrip": false,
2039
+ "normalized": false,
2040
+ "rstrip": false,
2041
+ "single_word": false,
2042
+ "special": true
2043
+ },
2044
+ "128255": {
2045
+ "content": "<|reserved_special_token_247|>",
2046
+ "lstrip": false,
2047
+ "normalized": false,
2048
+ "rstrip": false,
2049
+ "single_word": false,
2050
+ "special": true
2051
+ },
2052
+ "128256": {
2053
+ "content": "<img>",
2054
+ "lstrip": false,
2055
+ "normalized": false,
2056
+ "rstrip": false,
2057
+ "single_word": false,
2058
+ "special": true
2059
+ },
2060
+ "128257": {
2061
+ "content": "</img>",
2062
+ "lstrip": false,
2063
+ "normalized": false,
2064
+ "rstrip": false,
2065
+ "single_word": false,
2066
+ "special": true
2067
+ },
2068
+ "128258": {
2069
+ "content": "<IMG_CONTEXT>",
2070
+ "lstrip": false,
2071
+ "normalized": false,
2072
+ "rstrip": false,
2073
+ "single_word": false,
2074
+ "special": true
2075
+ },
2076
+ "128259": {
2077
+ "content": "<quad>",
2078
+ "lstrip": false,
2079
+ "normalized": false,
2080
+ "rstrip": false,
2081
+ "single_word": false,
2082
+ "special": true
2083
+ },
2084
+ "128260": {
2085
+ "content": "</quad>",
2086
+ "lstrip": false,
2087
+ "normalized": false,
2088
+ "rstrip": false,
2089
+ "single_word": false,
2090
+ "special": true
2091
+ },
2092
+ "128261": {
2093
+ "content": "<ref>",
2094
+ "lstrip": false,
2095
+ "normalized": false,
2096
+ "rstrip": false,
2097
+ "single_word": false,
2098
+ "special": true
2099
+ },
2100
+ "128262": {
2101
+ "content": "</ref>",
2102
+ "lstrip": false,
2103
+ "normalized": false,
2104
+ "rstrip": false,
2105
+ "single_word": false,
2106
+ "special": true
2107
+ },
2108
+ "128263": {
2109
+ "content": "<box>",
2110
+ "lstrip": false,
2111
+ "normalized": false,
2112
+ "rstrip": false,
2113
+ "single_word": false,
2114
+ "special": true
2115
+ },
2116
+ "128264": {
2117
+ "content": "</box>",
2118
+ "lstrip": false,
2119
+ "normalized": false,
2120
+ "rstrip": false,
2121
+ "single_word": false,
2122
+ "special": true
2123
+ },
2124
+ "128265": {
2125
+ "content": "<interval>",
2126
+ "lstrip": false,
2127
+ "normalized": false,
2128
+ "rstrip": false,
2129
+ "single_word": false,
2130
+ "special": true
2131
+ },
2132
+ "128266": {
2133
+ "content": "</interval>",
2134
+ "lstrip": false,
2135
+ "normalized": false,
2136
+ "rstrip": false,
2137
+ "single_word": false,
2138
+ "special": true
2139
+ }
2140
+ },
2141
+ "bos_token": "<|begin_of_text|>",
2142
+ "chat_template": "{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now(\"%d %b %Y\") %}\n {%- else %}\n {%- set date_string = \"26 Jul 2024\" %}\n {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- System message #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{%- if tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n{%- endif %}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- \"}\" }}\n {{- \"<|eot_id|>\" }}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n",
2143
+ "clean_up_tokenization_spaces": true,
2144
+ "eos_token": "<|eot_id|>",
2145
+ "model_input_names": [
2146
+ "input_ids",
2147
+ "attention_mask"
2148
+ ],
2149
+ "model_max_length": 16384,
2150
+ "pad_token": "<|finetune_right_pad_id|>",
2151
+ "tokenizer_class": "PreTrainedTokenizerFast"
2152
+ }
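The tokenizer_config.json added above extends the Llama 3 vocabulary with vision-related special tokens (<img>, </img>, <IMG_CONTEXT>, <quad>, </quad>, <ref>, </ref>, <box>, </box>, <interval>, </interval> at IDs 128256-128266) and ships a Llama-3.1-style chat template. A minimal sketch of how these entries could be checked after cloning the repository; the local path "./eagle2_checkout" and the example message are placeholders, not part of this commit:

# Sketch: load the tokenizer from a local checkout (placeholder path) and
# confirm the added special tokens and chat template from tokenizer_config.json.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./eagle2_checkout")

# IDs should match the added_tokens_decoder entries above (128256-128266).
for tok in ["<img>", "</img>", "<IMG_CONTEXT>", "<quad>", "</quad>",
            "<ref>", "</ref>", "<box>", "</box>", "<interval>", "</interval>"]:
    print(tok, tokenizer.convert_tokens_to_ids(tok))

# The chat_template wraps each turn in Llama-3 header/eot markers.
messages = [{"role": "user", "content": "Hello!"}]
print(tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True))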
train_results.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "epoch": 1.0,
3
+ "train_loss": 0.06401226358036113,
4
+ "train_runtime": 6205.6143,
5
+ "train_samples": 9129380,
6
+ "train_samples_per_second": 1471.148,
7
+ "train_steps_per_second": 1.874
8
+ }
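The throughput figures reported in train_results.json are internally consistent; a quick arithmetic check using only the values above:

# Sanity-check the reported numbers: samples / runtime should reproduce
# train_samples_per_second, and steps_per_second * runtime gives total steps.
train_samples = 9_129_380
train_runtime_s = 6205.6143
train_steps_per_second = 1.874

print(train_samples / train_runtime_s)            # ~1471.1, matches train_samples_per_second
print(train_steps_per_second * train_runtime_s)   # ~11629 optimizer steps over the run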
trainer_state.json ADDED

The diff for this file is too large to render.
 
training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fee6e377288b050dc4aeeafc8768f1c26beb461a3b3da0d8eace2506afa5fda8
3
+ size 6264
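training_args.bin is stored here as a Git LFS pointer; the underlying object is normally the torch-pickled TrainingArguments saved by transformers.Trainer. A sketch of inspecting it after pulling the LFS object, assuming that convention holds for this file:

# Sketch: load the pickled TrainingArguments (an object, not tensor weights,
# so weights_only must be False on recent PyTorch versions).
import torch

args = torch.load("training_args.bin", weights_only=False)
print(type(args))
print(args.num_train_epochs, args.learning_rate, args.per_device_train_batch_size)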