Upload folder using huggingface_hub
Browse files- .gitattributes +1 -0
- chat_template.jinja +397 -0
- config.json +444 -0
- generation_config.json +6 -0
- model-00001-of-00003.safetensors +3 -0
- model-00002-of-00003.safetensors +3 -0
- model-00003-of-00003.safetensors +3 -0
- model.safetensors.index.json +0 -0
- modeling_gpt_oss.py +893 -0
- quantization_config.json +372 -0
- special_tokens_map.json +23 -0
- tokenizer.json +3 -0
- tokenizer_config.json +183 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
chat_template.jinja
ADDED
@@ -0,0 +1,397 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{#-
|
2 |
+
In addition to the normal inputs of `messages` and `tools`, this template also accepts the
|
3 |
+
following kwargs:
|
4 |
+
- "builtin_tools": A list, can contain "browser" and/or "python".
|
5 |
+
- "model_identity": A string that optionally describes the model identity.
|
6 |
+
- "reasoning_effort": A string that describes the reasoning effort, defaults to "medium".
|
7 |
+
#}
|
8 |
+
|
9 |
+
{#- Tool Definition Rendering ============================================== #}
|
10 |
+
{%- macro render_typescript_type(param_spec, required_params, is_nullable=false) -%}
|
11 |
+
{%- if param_spec.type == "array" -%}
|
12 |
+
{%- if param_spec['items'] -%}
|
13 |
+
{%- if param_spec['items']['type'] == "string" -%}
|
14 |
+
{{- "string[]" }}
|
15 |
+
{%- elif param_spec['items']['type'] == "number" -%}
|
16 |
+
{{- "number[]" }}
|
17 |
+
{%- elif param_spec['items']['type'] == "integer" -%}
|
18 |
+
{{- "number[]" }}
|
19 |
+
{%- elif param_spec['items']['type'] == "boolean" -%}
|
20 |
+
{{- "boolean[]" }}
|
21 |
+
{%- else -%}
|
22 |
+
{%- set inner_type = render_typescript_type(param_spec['items'], required_params) -%}
|
23 |
+
{%- if inner_type == "object | object" or inner_type|length > 50 -%}
|
24 |
+
{{- "any[]" }}
|
25 |
+
{%- else -%}
|
26 |
+
{{- inner_type + "[]" }}
|
27 |
+
{%- endif -%}
|
28 |
+
{%- endif -%}
|
29 |
+
{%- if param_spec.nullable -%}
|
30 |
+
{{- " | null" }}
|
31 |
+
{%- endif -%}
|
32 |
+
{%- else -%}
|
33 |
+
{{- "any[]" }}
|
34 |
+
{%- if param_spec.nullable -%}
|
35 |
+
{{- " | null" }}
|
36 |
+
{%- endif -%}
|
37 |
+
{%- endif -%}
|
38 |
+
{%- elif param_spec.type is defined and param_spec.type is iterable and param_spec.type is not string and param_spec.type is not mapping and param_spec.type[0] is defined -%}
|
39 |
+
{#- Handle array of types like ["object", "object"] from Union[dict, list] #}
|
40 |
+
{%- if param_spec.type | length > 1 -%}
|
41 |
+
{{- param_spec.type | join(" | ") }}
|
42 |
+
{%- else -%}
|
43 |
+
{{- param_spec.type[0] }}
|
44 |
+
{%- endif -%}
|
45 |
+
{%- elif param_spec.oneOf -%}
|
46 |
+
{#- Handle oneOf schemas - check for complex unions and fallback to any #}
|
47 |
+
{%- set has_object_variants = false -%}
|
48 |
+
{%- for variant in param_spec.oneOf -%}
|
49 |
+
{%- if variant.type == "object" -%}
|
50 |
+
{%- set has_object_variants = true -%}
|
51 |
+
{%- endif -%}
|
52 |
+
{%- endfor -%}
|
53 |
+
{%- if has_object_variants and param_spec.oneOf|length > 1 -%}
|
54 |
+
{{- "any" }}
|
55 |
+
{%- else -%}
|
56 |
+
{%- for variant in param_spec.oneOf -%}
|
57 |
+
{{- render_typescript_type(variant, required_params) -}}
|
58 |
+
{%- if variant.description %}
|
59 |
+
{{- "// " + variant.description }}
|
60 |
+
{%- endif -%}
|
61 |
+
{%- if variant.default is defined %}
|
62 |
+
{{ "// default: " + variant.default|tojson }}
|
63 |
+
{%- endif -%}
|
64 |
+
{%- if not loop.last %}
|
65 |
+
{{- " | " }}
|
66 |
+
{% endif -%}
|
67 |
+
{%- endfor -%}
|
68 |
+
{%- endif -%}
|
69 |
+
{%- elif param_spec.type == "string" -%}
|
70 |
+
{%- if param_spec.enum -%}
|
71 |
+
{{- '"' + param_spec.enum|join('" | "') + '"' -}}
|
72 |
+
{%- else -%}
|
73 |
+
{{- "string" }}
|
74 |
+
{%- if param_spec.nullable %}
|
75 |
+
{{- " | null" }}
|
76 |
+
{%- endif -%}
|
77 |
+
{%- endif -%}
|
78 |
+
{%- elif param_spec.type == "number" -%}
|
79 |
+
{{- "number" }}
|
80 |
+
{%- elif param_spec.type == "integer" -%}
|
81 |
+
{{- "number" }}
|
82 |
+
{%- elif param_spec.type == "boolean" -%}
|
83 |
+
{{- "boolean" }}
|
84 |
+
|
85 |
+
{%- elif param_spec.type == "object" -%}
|
86 |
+
{%- if param_spec.properties -%}
|
87 |
+
{{- "{
|
88 |
+
" }}
|
89 |
+
{%- for prop_name, prop_spec in param_spec.properties.items() -%}
|
90 |
+
{{- prop_name -}}
|
91 |
+
{%- if prop_name not in (param_spec.required or []) -%}
|
92 |
+
{{- "?" }}
|
93 |
+
{%- endif -%}
|
94 |
+
{{- ": " }}
|
95 |
+
{{ render_typescript_type(prop_spec, param_spec.required or []) }}
|
96 |
+
{%- if not loop.last -%}
|
97 |
+
{{-", " }}
|
98 |
+
{%- endif -%}
|
99 |
+
{%- endfor -%}
|
100 |
+
{{- "}" }}
|
101 |
+
{%- else -%}
|
102 |
+
{{- "object" }}
|
103 |
+
{%- endif -%}
|
104 |
+
{%- else -%}
|
105 |
+
{{- "any" }}
|
106 |
+
{%- endif -%}
|
107 |
+
{%- endmacro -%}
|
108 |
+
|
109 |
+
{%- macro render_tool_namespace(namespace_name, tools) -%}
|
110 |
+
{{- "## " + namespace_name + "
|
111 |
+
|
112 |
+
" }}
|
113 |
+
{{- "namespace " + namespace_name + " {
|
114 |
+
|
115 |
+
" }}
|
116 |
+
{%- for tool in tools %}
|
117 |
+
{%- set tool = tool.function %}
|
118 |
+
{{- "// " + tool.description + "
|
119 |
+
" }}
|
120 |
+
{{- "type "+ tool.name + " = " }}
|
121 |
+
{%- if tool.parameters and tool.parameters.properties %}
|
122 |
+
{{- "(_: {
|
123 |
+
" }}
|
124 |
+
{%- for param_name, param_spec in tool.parameters.properties.items() %}
|
125 |
+
{%- if param_spec.description %}
|
126 |
+
{{- "// " + param_spec.description + "
|
127 |
+
" }}
|
128 |
+
{%- endif %}
|
129 |
+
{{- param_name }}
|
130 |
+
{%- if param_name not in (tool.parameters.required or []) -%}
|
131 |
+
{{- "?" }}
|
132 |
+
{%- endif -%}
|
133 |
+
{{- ": " }}
|
134 |
+
{{- render_typescript_type(param_spec, tool.parameters.required or []) }}
|
135 |
+
{%- if param_spec.default is defined -%}
|
136 |
+
{%- if param_spec.enum %}
|
137 |
+
{{- ", // default: " + param_spec.default }}
|
138 |
+
{%- elif param_spec.oneOf %}
|
139 |
+
{{- "// default: " + param_spec.default }}
|
140 |
+
{%- else %}
|
141 |
+
{{- ", // default: " + param_spec.default|tojson }}
|
142 |
+
{%- endif -%}
|
143 |
+
{%- endif -%}
|
144 |
+
{%- if not loop.last %}
|
145 |
+
{{- ",
|
146 |
+
" }}
|
147 |
+
{%- else %}
|
148 |
+
{{- "
|
149 |
+
" }}
|
150 |
+
{%- endif -%}
|
151 |
+
{%- endfor %}
|
152 |
+
{{- "}) => any;
|
153 |
+
|
154 |
+
" }}
|
155 |
+
{%- else -%}
|
156 |
+
{{- "() => any;
|
157 |
+
|
158 |
+
" }}
|
159 |
+
{%- endif -%}
|
160 |
+
{%- endfor %}
|
161 |
+
{{- "} // namespace " + namespace_name }}
|
162 |
+
{%- endmacro -%}
|
163 |
+
|
164 |
+
{%- macro render_builtin_tools(browser_tool, python_tool) -%}
|
165 |
+
{%- if browser_tool %}
|
166 |
+
{{- "## browser
|
167 |
+
|
168 |
+
" }}
|
169 |
+
{{- "// Tool for browsing.
|
170 |
+
" }}
|
171 |
+
{{- "// The `cursor` appears in brackets before each browsing display: `[{cursor}]`.
|
172 |
+
" }}
|
173 |
+
{{- "// Cite information from the tool using the following format:
|
174 |
+
" }}
|
175 |
+
{{- "// `【{cursor}†L{line_start}(-L{line_end})?】`, for example: `【6†L9-L11】` or `【8†L3】`.
|
176 |
+
" }}
|
177 |
+
{{- "// Do not quote more than 10 words directly from the tool output.
|
178 |
+
" }}
|
179 |
+
{{- "// sources=web (default: web)
|
180 |
+
" }}
|
181 |
+
{{- "namespace browser {
|
182 |
+
|
183 |
+
" }}
|
184 |
+
{{- "// Searches for information related to `query` and displays `topn` results.
|
185 |
+
" }}
|
186 |
+
{{- "type search = (_: {
|
187 |
+
" }}
|
188 |
+
{{- "query: string,
|
189 |
+
" }}
|
190 |
+
{{- "topn?: number, // default: 10
|
191 |
+
" }}
|
192 |
+
{{- "source?: string,
|
193 |
+
" }}
|
194 |
+
{{- "}) => any;
|
195 |
+
|
196 |
+
" }}
|
197 |
+
{{- "// Opens the link `id` from the page indicated by `cursor` starting at line number `loc`, showing `num_lines` lines.
|
198 |
+
" }}
|
199 |
+
{{- "// Valid link ids are displayed with the formatting: `【{id}†.*】`.
|
200 |
+
" }}
|
201 |
+
{{- "// If `cursor` is not provided, the most recent page is implied.
|
202 |
+
" }}
|
203 |
+
{{- "// If `id` is a string, it is treated as a fully qualified URL associated with `source`.
|
204 |
+
" }}
|
205 |
+
{{- "// If `loc` is not provided, the viewport will be positioned at the beginning of the document or centered on the most relevant passage, if available.
|
206 |
+
" }}
|
207 |
+
{{- "// Use this function without `id` to scroll to a new location of an opened page.
|
208 |
+
" }}
|
209 |
+
{{- "type open = (_: {
|
210 |
+
" }}
|
211 |
+
{{- "id?: number | string, // default: -1
|
212 |
+
" }}
|
213 |
+
{{- "cursor?: number, // default: -1
|
214 |
+
" }}
|
215 |
+
{{- "loc?: number, // default: -1
|
216 |
+
" }}
|
217 |
+
{{- "num_lines?: number, // default: -1
|
218 |
+
" }}
|
219 |
+
{{- "view_source?: boolean, // default: false
|
220 |
+
" }}
|
221 |
+
{{- "source?: string,
|
222 |
+
" }}
|
223 |
+
{{- "}) => any;
|
224 |
+
|
225 |
+
" }}
|
226 |
+
{{- "// Finds exact matches of `pattern` in the current page, or the page given by `cursor`.
|
227 |
+
" }}
|
228 |
+
{{- "type find = (_: {
|
229 |
+
" }}
|
230 |
+
{{- "pattern: string,
|
231 |
+
" }}
|
232 |
+
{{- "cursor?: number, // default: -1
|
233 |
+
" }}
|
234 |
+
{{- "}) => any;
|
235 |
+
|
236 |
+
" }}
|
237 |
+
{{- "} // namespace browser
|
238 |
+
|
239 |
+
" }}
|
240 |
+
{%- endif -%}
|
241 |
+
|
242 |
+
{%- if python_tool %}
|
243 |
+
{{- "## python
|
244 |
+
|
245 |
+
" }}
|
246 |
+
{{- "Use this tool to execute Python code in your chain of thought. The code will not be shown to the user. This tool should be used for internal reasoning, but not for code that is intended to be visible to the user (e.g. when creating plots, tables, or files).
|
247 |
+
|
248 |
+
" }}
|
249 |
+
{{- "When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 120.0 seconds. The drive at '/mnt/data' can be used to save and persist user files. Internet access for this session is UNKNOWN. Depends on the cluster.
|
250 |
+
|
251 |
+
" }}
|
252 |
+
{%- endif -%}
|
253 |
+
{%- endmacro -%}
|
254 |
+
|
255 |
+
{#- System Message Construction ============================================ #}
|
256 |
+
{%- macro build_system_message() -%}
|
257 |
+
{%- if model_identity is not defined %}
|
258 |
+
{%- set model_identity = "You are ChatGPT, a large language model trained by OpenAI." %}
|
259 |
+
{%- endif %}
|
260 |
+
{{- model_identity + "
|
261 |
+
" }}
|
262 |
+
{{- "Knowledge cutoff: 2024-06
|
263 |
+
" }}
|
264 |
+
{{- "Current date: " + strftime_now("%Y-%m-%d") + "
|
265 |
+
|
266 |
+
" }}
|
267 |
+
{%- if reasoning_effort is not defined %}
|
268 |
+
{%- set reasoning_effort = "medium" %}
|
269 |
+
{%- endif %}
|
270 |
+
{{- "Reasoning: " + reasoning_effort + "
|
271 |
+
|
272 |
+
" }}
|
273 |
+
{%- if builtin_tools %}
|
274 |
+
{{- "# Tools
|
275 |
+
|
276 |
+
" }}
|
277 |
+
{%- set available_builtin_tools = namespace(browser=false, python=false) %}
|
278 |
+
{%- for tool in builtin_tools %}
|
279 |
+
{%- if tool == "browser" %}
|
280 |
+
{%- set available_builtin_tools.browser = true %}
|
281 |
+
{%- elif tool == "python" %}
|
282 |
+
{%- set available_builtin_tools.python = true %}
|
283 |
+
{%- endif %}
|
284 |
+
{%- endfor %}
|
285 |
+
{{- render_builtin_tools(available_builtin_tools.browser, available_builtin_tools.python) }}
|
286 |
+
{%- endif -%}
|
287 |
+
{{- "# Valid channels: analysis, commentary, final. Channel must be included for every message." }}
|
288 |
+
{%- if tools -%}
|
289 |
+
{{- "
|
290 |
+
Calls to these tools must go to the commentary channel: 'functions'." }}
|
291 |
+
{%- endif -%}
|
292 |
+
{%- endmacro -%}
|
293 |
+
|
294 |
+
{#- Main Template Logic ================================================= #}
|
295 |
+
{#- Set defaults #}
|
296 |
+
|
297 |
+
{#- Render system message #}
|
298 |
+
{{- "<|start|>system<|message|>" }}
|
299 |
+
{{- build_system_message() }}
|
300 |
+
{{- "<|end|>" }}
|
301 |
+
|
302 |
+
{#- Extract developer message #}
|
303 |
+
{%- if messages[0].role == "developer" or messages[0].role == "system" %}
|
304 |
+
{%- set developer_message = messages[0].content %}
|
305 |
+
{%- set loop_messages = messages[1:] %}
|
306 |
+
{%- else %}
|
307 |
+
{%- set developer_message = "" %}
|
308 |
+
{%- set loop_messages = messages %}
|
309 |
+
{%- endif %}
|
310 |
+
|
311 |
+
{#- Render developer message #}
|
312 |
+
{%- if developer_message or tools %}
|
313 |
+
{{- "<|start|>developer<|message|>" }}
|
314 |
+
{%- if developer_message %}
|
315 |
+
{{- "# Instructions
|
316 |
+
|
317 |
+
" }}
|
318 |
+
{{- developer_message }}
|
319 |
+
{%- endif %}
|
320 |
+
{%- if tools -%}
|
321 |
+
{{- "
|
322 |
+
|
323 |
+
" }}
|
324 |
+
{{- "# Tools
|
325 |
+
|
326 |
+
" }}
|
327 |
+
{{- render_tool_namespace("functions", tools) }}
|
328 |
+
{%- endif -%}
|
329 |
+
{{- "<|end|>" }}
|
330 |
+
{%- endif %}
|
331 |
+
|
332 |
+
{#- Render messages #}
|
333 |
+
{%- set last_tool_call = namespace(name=none) %}
|
334 |
+
{%- for message in loop_messages -%}
|
335 |
+
{#- At this point only assistant/user/tool messages should remain #}
|
336 |
+
{%- if message.role == 'assistant' -%}
|
337 |
+
{#- Checks to ensure the messages are being passed in the format we expect #}
|
338 |
+
{%- if "content" in message %}
|
339 |
+
{%- if "<|channel|>analysis<|message|>" in message.content or "<|channel|>final<|message|>" in message.content %}
|
340 |
+
{{- raise_exception("You have passed a message containing <|channel|> tags in the content field. Instead of doing this, you should pass analysis messages (the string between '<|message|>' and '<|end|>') in the 'thinking' field, and final messages (the string between '<|message|>' and '<|end|>') in the 'content' field.") }}
|
341 |
+
{%- endif %}
|
342 |
+
{%- endif %}
|
343 |
+
{%- if "thinking" in message %}
|
344 |
+
{%- if "<|channel|>analysis<|message|>" in message.thinking or "<|channel|>final<|message|>" in message.thinking %}
|
345 |
+
{{- raise_exception("You have passed a message containing <|channel|> tags in the thinking field. Instead of doing this, you should pass analysis messages (the string between '<|message|>' and '<|end|>') in the 'thinking' field, and final messages (the string between '<|message|>' and '<|end|>') in the 'content' field.") }}
|
346 |
+
{%- endif %}
|
347 |
+
{%- endif %}
|
348 |
+
{%- if "tool_calls" in message %}
|
349 |
+
{#- We assume max 1 tool call per message, and so we infer the tool call name #}
|
350 |
+
{#- in "tool" messages from the most recent assistant tool call name #}
|
351 |
+
{%- set tool_call = message.tool_calls[0] %}
|
352 |
+
{%- if tool_call.function %}
|
353 |
+
{%- set tool_call = tool_call.function %}
|
354 |
+
{%- endif %}
|
355 |
+
{%- if message.content and message.thinking %}
|
356 |
+
{{- raise_exception("Cannot pass both content and thinking in an assistant message with tool calls! Put the analysis message in one or the other, but not both.") }}
|
357 |
+
{%- elif message.content %}
|
358 |
+
{{- "<|start|>assistant<|channel|>analysis<|message|>" + message.content + "<|end|>" }}
|
359 |
+
{%- elif message.thinking %}
|
360 |
+
{{- "<|start|>assistant<|channel|>analysis<|message|>" + message.thinking + "<|end|>" }}
|
361 |
+
{%- endif %}
|
362 |
+
{{- "<|start|>assistant to=" }}
|
363 |
+
{{- "functions." + tool_call.name + "<|channel|>commentary " }}
|
364 |
+
{{- (tool_call.content_type if tool_call.content_type is defined else "json") + "<|message|>" }}
|
365 |
+
{{- tool_call.arguments|tojson }}
|
366 |
+
{{- "<|call|>" }}
|
367 |
+
{%- set last_tool_call.name = tool_call.name %}
|
368 |
+
{%- elif loop.last and not add_generation_prompt %}
|
369 |
+
{#- Only render the CoT if the final turn is an assistant turn and add_generation_prompt is false #}
|
370 |
+
{#- This is a situation that should only occur in training, never in inference. #}
|
371 |
+
{%- if "thinking" in message %}
|
372 |
+
{{- "<|start|>assistant<|channel|>analysis<|message|>" + message.thinking + "<|end|>" }}
|
373 |
+
{%- endif %}
|
374 |
+
{#- <|return|> indicates the end of generation, but <|end|> does not #}
|
375 |
+
{#- <|return|> should never be an input to the model, but we include it as the final token #}
|
376 |
+
{#- when training, so the model learns to emit it. #}
|
377 |
+
{{- "<|start|>assistant<|channel|>final<|message|>" + message.content + "<|return|>" }}
|
378 |
+
{%- else %}
|
379 |
+
{#- CoT is dropped during all previous turns, so we never render it for inference #}
|
380 |
+
{{- "<|start|>assistant<|channel|>final<|message|>" + message.content + "<|end|>" }}
|
381 |
+
{%- set last_tool_call.name = none %}
|
382 |
+
{%- endif %}
|
383 |
+
{%- elif message.role == 'tool' -%}
|
384 |
+
{%- if last_tool_call.name is none %}
|
385 |
+
{{- raise_exception("Message has tool role, but there was no previous assistant message with a tool call!") }}
|
386 |
+
{%- endif %}
|
387 |
+
{{- "<|start|>functions." + last_tool_call.name }}
|
388 |
+
{{- " to=assistant<|channel|>commentary<|message|>" + message.content|tojson + "<|end|>" }}
|
389 |
+
{%- elif message.role == 'user' -%}
|
390 |
+
{{- "<|start|>user<|message|>" + message.content + "<|end|>" }}
|
391 |
+
{%- endif -%}
|
392 |
+
{%- endfor -%}
|
393 |
+
|
394 |
+
{#- Generation prompt #}
|
395 |
+
{%- if add_generation_prompt -%}
|
396 |
+
<|start|>assistant
|
397 |
+
{%- endif -%}
|
config.json
ADDED
@@ -0,0 +1,444 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"architectures": [
|
3 |
+
"GptOssForCausalLM"
|
4 |
+
],
|
5 |
+
"attention_bias": true,
|
6 |
+
"attention_dropout": 0.0,
|
7 |
+
"auto_map": {
|
8 |
+
"AutoModel": "modeling_gpt_oss.GptOssModel",
|
9 |
+
"AutoModelForCausalLM": "modeling_gpt_oss.GptOssForCausalLM"
|
10 |
+
},
|
11 |
+
"eos_token_id": 200002,
|
12 |
+
"experts_per_token": 4,
|
13 |
+
"head_dim": 64,
|
14 |
+
"hidden_act": "silu",
|
15 |
+
"hidden_size": 2880,
|
16 |
+
"initial_context_length": 4096,
|
17 |
+
"initializer_range": 0.02,
|
18 |
+
"intermediate_size": 2880,
|
19 |
+
"layer_types": [
|
20 |
+
"sliding_attention",
|
21 |
+
"full_attention",
|
22 |
+
"sliding_attention",
|
23 |
+
"full_attention",
|
24 |
+
"sliding_attention",
|
25 |
+
"full_attention",
|
26 |
+
"sliding_attention",
|
27 |
+
"full_attention",
|
28 |
+
"sliding_attention",
|
29 |
+
"full_attention",
|
30 |
+
"sliding_attention",
|
31 |
+
"full_attention",
|
32 |
+
"sliding_attention",
|
33 |
+
"full_attention",
|
34 |
+
"sliding_attention",
|
35 |
+
"full_attention",
|
36 |
+
"sliding_attention",
|
37 |
+
"full_attention",
|
38 |
+
"sliding_attention",
|
39 |
+
"full_attention",
|
40 |
+
"sliding_attention",
|
41 |
+
"full_attention",
|
42 |
+
"sliding_attention",
|
43 |
+
"full_attention"
|
44 |
+
],
|
45 |
+
"max_position_embeddings": 131072,
|
46 |
+
"model_type": "gpt_oss",
|
47 |
+
"num_attention_heads": 64,
|
48 |
+
"num_experts_per_tok": 4,
|
49 |
+
"num_hidden_layers": 24,
|
50 |
+
"num_key_value_heads": 8,
|
51 |
+
"num_local_experts": 32,
|
52 |
+
"output_router_logits": false,
|
53 |
+
"pad_token_id": 199999,
|
54 |
+
"quantization_config": {
|
55 |
+
"autoround_version": "0.6.1.dev",
|
56 |
+
"bits": 4,
|
57 |
+
"data_type": "int",
|
58 |
+
"extra_config": {
|
59 |
+
"model.layers.0.mlp.router.router": {
|
60 |
+
"bits": 16
|
61 |
+
},
|
62 |
+
"model.layers.0.self_attn.k_proj": {
|
63 |
+
"bits": 16
|
64 |
+
},
|
65 |
+
"model.layers.0.self_attn.o_proj": {
|
66 |
+
"bits": 16
|
67 |
+
},
|
68 |
+
"model.layers.0.self_attn.q_proj": {
|
69 |
+
"bits": 16
|
70 |
+
},
|
71 |
+
"model.layers.0.self_attn.v_proj": {
|
72 |
+
"bits": 16
|
73 |
+
},
|
74 |
+
"model.layers.1.mlp.router.router": {
|
75 |
+
"bits": 16
|
76 |
+
},
|
77 |
+
"model.layers.1.self_attn.k_proj": {
|
78 |
+
"bits": 16
|
79 |
+
},
|
80 |
+
"model.layers.1.self_attn.o_proj": {
|
81 |
+
"bits": 16
|
82 |
+
},
|
83 |
+
"model.layers.1.self_attn.q_proj": {
|
84 |
+
"bits": 16
|
85 |
+
},
|
86 |
+
"model.layers.1.self_attn.v_proj": {
|
87 |
+
"bits": 16
|
88 |
+
},
|
89 |
+
"model.layers.10.mlp.router.router": {
|
90 |
+
"bits": 16
|
91 |
+
},
|
92 |
+
"model.layers.10.self_attn.k_proj": {
|
93 |
+
"bits": 16
|
94 |
+
},
|
95 |
+
"model.layers.10.self_attn.o_proj": {
|
96 |
+
"bits": 16
|
97 |
+
},
|
98 |
+
"model.layers.10.self_attn.q_proj": {
|
99 |
+
"bits": 16
|
100 |
+
},
|
101 |
+
"model.layers.10.self_attn.v_proj": {
|
102 |
+
"bits": 16
|
103 |
+
},
|
104 |
+
"model.layers.11.mlp.router.router": {
|
105 |
+
"bits": 16
|
106 |
+
},
|
107 |
+
"model.layers.11.self_attn.k_proj": {
|
108 |
+
"bits": 16
|
109 |
+
},
|
110 |
+
"model.layers.11.self_attn.o_proj": {
|
111 |
+
"bits": 16
|
112 |
+
},
|
113 |
+
"model.layers.11.self_attn.q_proj": {
|
114 |
+
"bits": 16
|
115 |
+
},
|
116 |
+
"model.layers.11.self_attn.v_proj": {
|
117 |
+
"bits": 16
|
118 |
+
},
|
119 |
+
"model.layers.12.mlp.router.router": {
|
120 |
+
"bits": 16
|
121 |
+
},
|
122 |
+
"model.layers.12.self_attn.k_proj": {
|
123 |
+
"bits": 16
|
124 |
+
},
|
125 |
+
"model.layers.12.self_attn.o_proj": {
|
126 |
+
"bits": 16
|
127 |
+
},
|
128 |
+
"model.layers.12.self_attn.q_proj": {
|
129 |
+
"bits": 16
|
130 |
+
},
|
131 |
+
"model.layers.12.self_attn.v_proj": {
|
132 |
+
"bits": 16
|
133 |
+
},
|
134 |
+
"model.layers.13.mlp.router.router": {
|
135 |
+
"bits": 16
|
136 |
+
},
|
137 |
+
"model.layers.13.self_attn.k_proj": {
|
138 |
+
"bits": 16
|
139 |
+
},
|
140 |
+
"model.layers.13.self_attn.o_proj": {
|
141 |
+
"bits": 16
|
142 |
+
},
|
143 |
+
"model.layers.13.self_attn.q_proj": {
|
144 |
+
"bits": 16
|
145 |
+
},
|
146 |
+
"model.layers.13.self_attn.v_proj": {
|
147 |
+
"bits": 16
|
148 |
+
},
|
149 |
+
"model.layers.14.mlp.router.router": {
|
150 |
+
"bits": 16
|
151 |
+
},
|
152 |
+
"model.layers.14.self_attn.k_proj": {
|
153 |
+
"bits": 16
|
154 |
+
},
|
155 |
+
"model.layers.14.self_attn.o_proj": {
|
156 |
+
"bits": 16
|
157 |
+
},
|
158 |
+
"model.layers.14.self_attn.q_proj": {
|
159 |
+
"bits": 16
|
160 |
+
},
|
161 |
+
"model.layers.14.self_attn.v_proj": {
|
162 |
+
"bits": 16
|
163 |
+
},
|
164 |
+
"model.layers.15.mlp.router.router": {
|
165 |
+
"bits": 16
|
166 |
+
},
|
167 |
+
"model.layers.15.self_attn.k_proj": {
|
168 |
+
"bits": 16
|
169 |
+
},
|
170 |
+
"model.layers.15.self_attn.o_proj": {
|
171 |
+
"bits": 16
|
172 |
+
},
|
173 |
+
"model.layers.15.self_attn.q_proj": {
|
174 |
+
"bits": 16
|
175 |
+
},
|
176 |
+
"model.layers.15.self_attn.v_proj": {
|
177 |
+
"bits": 16
|
178 |
+
},
|
179 |
+
"model.layers.16.mlp.router.router": {
|
180 |
+
"bits": 16
|
181 |
+
},
|
182 |
+
"model.layers.16.self_attn.k_proj": {
|
183 |
+
"bits": 16
|
184 |
+
},
|
185 |
+
"model.layers.16.self_attn.o_proj": {
|
186 |
+
"bits": 16
|
187 |
+
},
|
188 |
+
"model.layers.16.self_attn.q_proj": {
|
189 |
+
"bits": 16
|
190 |
+
},
|
191 |
+
"model.layers.16.self_attn.v_proj": {
|
192 |
+
"bits": 16
|
193 |
+
},
|
194 |
+
"model.layers.17.mlp.router.router": {
|
195 |
+
"bits": 16
|
196 |
+
},
|
197 |
+
"model.layers.17.self_attn.k_proj": {
|
198 |
+
"bits": 16
|
199 |
+
},
|
200 |
+
"model.layers.17.self_attn.o_proj": {
|
201 |
+
"bits": 16
|
202 |
+
},
|
203 |
+
"model.layers.17.self_attn.q_proj": {
|
204 |
+
"bits": 16
|
205 |
+
},
|
206 |
+
"model.layers.17.self_attn.v_proj": {
|
207 |
+
"bits": 16
|
208 |
+
},
|
209 |
+
"model.layers.18.mlp.router.router": {
|
210 |
+
"bits": 16
|
211 |
+
},
|
212 |
+
"model.layers.18.self_attn.k_proj": {
|
213 |
+
"bits": 16
|
214 |
+
},
|
215 |
+
"model.layers.18.self_attn.o_proj": {
|
216 |
+
"bits": 16
|
217 |
+
},
|
218 |
+
"model.layers.18.self_attn.q_proj": {
|
219 |
+
"bits": 16
|
220 |
+
},
|
221 |
+
"model.layers.18.self_attn.v_proj": {
|
222 |
+
"bits": 16
|
223 |
+
},
|
224 |
+
"model.layers.19.mlp.router.router": {
|
225 |
+
"bits": 16
|
226 |
+
},
|
227 |
+
"model.layers.19.self_attn.k_proj": {
|
228 |
+
"bits": 16
|
229 |
+
},
|
230 |
+
"model.layers.19.self_attn.o_proj": {
|
231 |
+
"bits": 16
|
232 |
+
},
|
233 |
+
"model.layers.19.self_attn.q_proj": {
|
234 |
+
"bits": 16
|
235 |
+
},
|
236 |
+
"model.layers.19.self_attn.v_proj": {
|
237 |
+
"bits": 16
|
238 |
+
},
|
239 |
+
"model.layers.2.mlp.router.router": {
|
240 |
+
"bits": 16
|
241 |
+
},
|
242 |
+
"model.layers.2.self_attn.k_proj": {
|
243 |
+
"bits": 16
|
244 |
+
},
|
245 |
+
"model.layers.2.self_attn.o_proj": {
|
246 |
+
"bits": 16
|
247 |
+
},
|
248 |
+
"model.layers.2.self_attn.q_proj": {
|
249 |
+
"bits": 16
|
250 |
+
},
|
251 |
+
"model.layers.2.self_attn.v_proj": {
|
252 |
+
"bits": 16
|
253 |
+
},
|
254 |
+
"model.layers.20.mlp.router.router": {
|
255 |
+
"bits": 16
|
256 |
+
},
|
257 |
+
"model.layers.20.self_attn.k_proj": {
|
258 |
+
"bits": 16
|
259 |
+
},
|
260 |
+
"model.layers.20.self_attn.o_proj": {
|
261 |
+
"bits": 16
|
262 |
+
},
|
263 |
+
"model.layers.20.self_attn.q_proj": {
|
264 |
+
"bits": 16
|
265 |
+
},
|
266 |
+
"model.layers.20.self_attn.v_proj": {
|
267 |
+
"bits": 16
|
268 |
+
},
|
269 |
+
"model.layers.21.mlp.router.router": {
|
270 |
+
"bits": 16
|
271 |
+
},
|
272 |
+
"model.layers.21.self_attn.k_proj": {
|
273 |
+
"bits": 16
|
274 |
+
},
|
275 |
+
"model.layers.21.self_attn.o_proj": {
|
276 |
+
"bits": 16
|
277 |
+
},
|
278 |
+
"model.layers.21.self_attn.q_proj": {
|
279 |
+
"bits": 16
|
280 |
+
},
|
281 |
+
"model.layers.21.self_attn.v_proj": {
|
282 |
+
"bits": 16
|
283 |
+
},
|
284 |
+
"model.layers.22.mlp.router.router": {
|
285 |
+
"bits": 16
|
286 |
+
},
|
287 |
+
"model.layers.22.self_attn.k_proj": {
|
288 |
+
"bits": 16
|
289 |
+
},
|
290 |
+
"model.layers.22.self_attn.o_proj": {
|
291 |
+
"bits": 16
|
292 |
+
},
|
293 |
+
"model.layers.22.self_attn.q_proj": {
|
294 |
+
"bits": 16
|
295 |
+
},
|
296 |
+
"model.layers.22.self_attn.v_proj": {
|
297 |
+
"bits": 16
|
298 |
+
},
|
299 |
+
"model.layers.23.mlp.router.router": {
|
300 |
+
"bits": 16
|
301 |
+
},
|
302 |
+
"model.layers.23.self_attn.k_proj": {
|
303 |
+
"bits": 16
|
304 |
+
},
|
305 |
+
"model.layers.23.self_attn.o_proj": {
|
306 |
+
"bits": 16
|
307 |
+
},
|
308 |
+
"model.layers.23.self_attn.q_proj": {
|
309 |
+
"bits": 16
|
310 |
+
},
|
311 |
+
"model.layers.23.self_attn.v_proj": {
|
312 |
+
"bits": 16
|
313 |
+
},
|
314 |
+
"model.layers.3.mlp.router.router": {
|
315 |
+
"bits": 16
|
316 |
+
},
|
317 |
+
"model.layers.3.self_attn.k_proj": {
|
318 |
+
"bits": 16
|
319 |
+
},
|
320 |
+
"model.layers.3.self_attn.o_proj": {
|
321 |
+
"bits": 16
|
322 |
+
},
|
323 |
+
"model.layers.3.self_attn.q_proj": {
|
324 |
+
"bits": 16
|
325 |
+
},
|
326 |
+
"model.layers.3.self_attn.v_proj": {
|
327 |
+
"bits": 16
|
328 |
+
},
|
329 |
+
"model.layers.4.mlp.router.router": {
|
330 |
+
"bits": 16
|
331 |
+
},
|
332 |
+
"model.layers.4.self_attn.k_proj": {
|
333 |
+
"bits": 16
|
334 |
+
},
|
335 |
+
"model.layers.4.self_attn.o_proj": {
|
336 |
+
"bits": 16
|
337 |
+
},
|
338 |
+
"model.layers.4.self_attn.q_proj": {
|
339 |
+
"bits": 16
|
340 |
+
},
|
341 |
+
"model.layers.4.self_attn.v_proj": {
|
342 |
+
"bits": 16
|
343 |
+
},
|
344 |
+
"model.layers.5.mlp.router.router": {
|
345 |
+
"bits": 16
|
346 |
+
},
|
347 |
+
"model.layers.5.self_attn.k_proj": {
|
348 |
+
"bits": 16
|
349 |
+
},
|
350 |
+
"model.layers.5.self_attn.o_proj": {
|
351 |
+
"bits": 16
|
352 |
+
},
|
353 |
+
"model.layers.5.self_attn.q_proj": {
|
354 |
+
"bits": 16
|
355 |
+
},
|
356 |
+
"model.layers.5.self_attn.v_proj": {
|
357 |
+
"bits": 16
|
358 |
+
},
|
359 |
+
"model.layers.6.mlp.router.router": {
|
360 |
+
"bits": 16
|
361 |
+
},
|
362 |
+
"model.layers.6.self_attn.k_proj": {
|
363 |
+
"bits": 16
|
364 |
+
},
|
365 |
+
"model.layers.6.self_attn.o_proj": {
|
366 |
+
"bits": 16
|
367 |
+
},
|
368 |
+
"model.layers.6.self_attn.q_proj": {
|
369 |
+
"bits": 16
|
370 |
+
},
|
371 |
+
"model.layers.6.self_attn.v_proj": {
|
372 |
+
"bits": 16
|
373 |
+
},
|
374 |
+
"model.layers.7.mlp.router.router": {
|
375 |
+
"bits": 16
|
376 |
+
},
|
377 |
+
"model.layers.7.self_attn.k_proj": {
|
378 |
+
"bits": 16
|
379 |
+
},
|
380 |
+
"model.layers.7.self_attn.o_proj": {
|
381 |
+
"bits": 16
|
382 |
+
},
|
383 |
+
"model.layers.7.self_attn.q_proj": {
|
384 |
+
"bits": 16
|
385 |
+
},
|
386 |
+
"model.layers.7.self_attn.v_proj": {
|
387 |
+
"bits": 16
|
388 |
+
},
|
389 |
+
"model.layers.8.mlp.router.router": {
|
390 |
+
"bits": 16
|
391 |
+
},
|
392 |
+
"model.layers.8.self_attn.k_proj": {
|
393 |
+
"bits": 16
|
394 |
+
},
|
395 |
+
"model.layers.8.self_attn.o_proj": {
|
396 |
+
"bits": 16
|
397 |
+
},
|
398 |
+
"model.layers.8.self_attn.q_proj": {
|
399 |
+
"bits": 16
|
400 |
+
},
|
401 |
+
"model.layers.8.self_attn.v_proj": {
|
402 |
+
"bits": 16
|
403 |
+
},
|
404 |
+
"model.layers.9.mlp.router.router": {
|
405 |
+
"bits": 16
|
406 |
+
},
|
407 |
+
"model.layers.9.self_attn.k_proj": {
|
408 |
+
"bits": 16
|
409 |
+
},
|
410 |
+
"model.layers.9.self_attn.o_proj": {
|
411 |
+
"bits": 16
|
412 |
+
},
|
413 |
+
"model.layers.9.self_attn.q_proj": {
|
414 |
+
"bits": 16
|
415 |
+
},
|
416 |
+
"model.layers.9.self_attn.v_proj": {
|
417 |
+
"bits": 16
|
418 |
+
}
|
419 |
+
},
|
420 |
+
"group_size": 128,
|
421 |
+
"nsamples": 512,
|
422 |
+
"packing_format": "auto_round:auto_gptq",
|
423 |
+
"quant_method": "auto-round",
|
424 |
+
"sym": true
|
425 |
+
},
|
426 |
+
"rms_norm_eps": 1e-05,
|
427 |
+
"rope_scaling": {
|
428 |
+
"beta_fast": 32.0,
|
429 |
+
"beta_slow": 1.0,
|
430 |
+
"factor": 32.0,
|
431 |
+
"original_max_position_embeddings": 4096,
|
432 |
+
"rope_type": "yarn",
|
433 |
+
"truncate": false
|
434 |
+
},
|
435 |
+
"rope_theta": 150000,
|
436 |
+
"router_aux_loss_coef": 0.9,
|
437 |
+
"sliding_window": 128,
|
438 |
+
"swiglu_limit": 7.0,
|
439 |
+
"tie_word_embeddings": false,
|
440 |
+
"torch_dtype": "bfloat16",
|
441 |
+
"transformers_version": "4.55.0",
|
442 |
+
"use_cache": true,
|
443 |
+
"vocab_size": 201088
|
444 |
+
}
|
generation_config.json
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_from_model_config": true,
|
3 |
+
"eos_token_id": 200002,
|
4 |
+
"pad_token_id": 199999,
|
5 |
+
"transformers_version": "4.55.0"
|
6 |
+
}
|
model-00001-of-00003.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f33ef6959fa47f5cc284f01604b950539f3ad7bfed26088d9b41393f40872fc8
|
3 |
+
size 4998027832
|
model-00002-of-00003.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e6b8d50c63cfe36ce7caecc7b3375e846ef94bf52afe1748261528655f9343a2
|
3 |
+
size 4998707776
|
model-00003-of-00003.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a014ea3832980752f03a933deb1493c6ac54b999086d0f4395f13b327fadc3b0
|
3 |
+
size 3549668544
|
model.safetensors.index.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
modeling_gpt_oss.py
ADDED
@@ -0,0 +1,893 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
from typing import Callable, Optional, Union
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from accelerate import init_empty_weights
|
6 |
+
from torch import nn
|
7 |
+
from torch.nn import functional as F
|
8 |
+
|
9 |
+
from transformers.cache_utils import Cache, DynamicCache
|
10 |
+
from transformers.generation import GenerationMixin
|
11 |
+
from transformers.integrations.hub_kernels import use_kernel_forward_from_hub
|
12 |
+
from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
|
13 |
+
from transformers.modeling_layers import GradientCheckpointingLayer
|
14 |
+
from transformers.modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
|
15 |
+
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
16 |
+
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
17 |
+
from transformers.processing_utils import Unpack
|
18 |
+
from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
|
19 |
+
from transformers.utils.generic import OutputRecorder, check_model_inputs
|
20 |
+
|
21 |
+
from transformers.configuration_utils import PretrainedConfig, layer_type_validation
|
22 |
+
from transformers.modeling_rope_utils import rope_config_validation
|
23 |
+
|
24 |
+
class GptOssConfig(PretrainedConfig):
|
25 |
+
r"""
|
26 |
+
This will yield a configuration to that of the BERT
|
27 |
+
[google-bert/bert-base-uncased](https://huggingface.co/google-bert/bert-base-uncased) architecture.
|
28 |
+
|
29 |
+
"""
|
30 |
+
|
31 |
+
model_type = "gpt_oss"
|
32 |
+
base_model_pp_plan = {
|
33 |
+
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
34 |
+
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
35 |
+
"norm": (["hidden_states"], ["hidden_states"]),
|
36 |
+
}
|
37 |
+
base_model_tp_plan = {
|
38 |
+
"layers.*.self_attn.q_proj": "colwise",
|
39 |
+
"layers.*.self_attn.k_proj": "colwise",
|
40 |
+
"layers.*.self_attn.v_proj": "colwise",
|
41 |
+
"layers.*.self_attn.o_proj": "rowwise",
|
42 |
+
"layers.*.self_attn.sinks": "local_rowwise",
|
43 |
+
"layers.*.mlp.experts": "gather",
|
44 |
+
"layers.*.mlp.router": "ep_router",
|
45 |
+
"layers.*.mlp.experts.gate_up_proj": "grouped_gemm",
|
46 |
+
"layers.*.mlp.experts.gate_up_proj_bias": "grouped_gemm",
|
47 |
+
"layers.*.mlp.experts.down_proj": "grouped_gemm",
|
48 |
+
"layers.*.mlp.experts.down_proj_bias": "grouped_gemm",
|
49 |
+
}
|
50 |
+
|
51 |
+
def __init__(
|
52 |
+
self,
|
53 |
+
num_hidden_layers: int = 36,
|
54 |
+
num_local_experts: int = 128,
|
55 |
+
vocab_size: int = 201088,
|
56 |
+
hidden_size: int = 2880,
|
57 |
+
intermediate_size: int = 2880,
|
58 |
+
head_dim: int = 64,
|
59 |
+
num_attention_heads: int = 64,
|
60 |
+
num_key_value_heads: int = 8,
|
61 |
+
sliding_window: int = 128,
|
62 |
+
rope_theta: float = 150000.0,
|
63 |
+
tie_word_embeddings=False,
|
64 |
+
hidden_act: str = "silu",
|
65 |
+
initializer_range: float = 0.02,
|
66 |
+
max_position_embeddings=131072,
|
67 |
+
rms_norm_eps: float = 1e-5,
|
68 |
+
rope_scaling={"rope_type": "yarn", "factor": 32.0, "beta_fast": 32.0, "beta_slow": 1.0, "truncate": False},
|
69 |
+
attention_dropout: float = 0.0,
|
70 |
+
num_experts_per_tok=4,
|
71 |
+
router_aux_loss_coef: float = 0.9,
|
72 |
+
output_router_logits=False,
|
73 |
+
use_cache=True,
|
74 |
+
layer_types=None,
|
75 |
+
**kwargs,
|
76 |
+
):
|
77 |
+
self.vocab_size = vocab_size
|
78 |
+
self.hidden_size = hidden_size
|
79 |
+
self.intermediate_size = intermediate_size
|
80 |
+
self.num_hidden_layers = num_hidden_layers
|
81 |
+
self.num_attention_heads = num_attention_heads
|
82 |
+
self.num_local_experts = num_local_experts
|
83 |
+
self.sliding_window = sliding_window
|
84 |
+
self.num_experts_per_tok = num_experts_per_tok
|
85 |
+
# for backward compatibility
|
86 |
+
if num_key_value_heads is None:
|
87 |
+
num_key_value_heads = num_attention_heads
|
88 |
+
|
89 |
+
self.num_key_value_heads = num_key_value_heads
|
90 |
+
self.hidden_act = hidden_act
|
91 |
+
self.initializer_range = initializer_range
|
92 |
+
self.rms_norm_eps = rms_norm_eps
|
93 |
+
self.rope_theta = rope_theta
|
94 |
+
self.rope_scaling = rope_scaling
|
95 |
+
self.attention_dropout = attention_dropout
|
96 |
+
self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
|
97 |
+
self.layer_types = layer_types
|
98 |
+
if self.layer_types is None:
|
99 |
+
self.layer_types = [
|
100 |
+
"sliding_attention" if bool((i + 1) % 2) else "full_attention" for i in range(self.num_hidden_layers)
|
101 |
+
]
|
102 |
+
layer_type_validation(self.layer_types)
|
103 |
+
|
104 |
+
# Validate the correctness of rotary position embeddings parameters
|
105 |
+
# BC: if there is a 'type' field, copy it it to 'rope_type'.
|
106 |
+
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
107 |
+
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
108 |
+
rope_config_validation(self)
|
109 |
+
|
110 |
+
self.attention_bias = True
|
111 |
+
self.max_position_embeddings = max_position_embeddings
|
112 |
+
self.router_aux_loss_coef = router_aux_loss_coef
|
113 |
+
self.output_router_logits = output_router_logits
|
114 |
+
self.use_cache = use_cache
|
115 |
+
super().__init__(
|
116 |
+
tie_word_embeddings=tie_word_embeddings,
|
117 |
+
**kwargs,
|
118 |
+
)
|
119 |
+
|
120 |
+
@use_kernel_forward_from_hub("RMSNorm")
|
121 |
+
class GptOssRMSNorm(nn.Module):
|
122 |
+
def __init__(self, hidden_size, eps=1e-6):
|
123 |
+
"""
|
124 |
+
GptOssRMSNorm is equivalent to T5LayerNorm
|
125 |
+
"""
|
126 |
+
super().__init__()
|
127 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
128 |
+
self.variance_epsilon = eps
|
129 |
+
|
130 |
+
def forward(self, hidden_states):
|
131 |
+
input_dtype = hidden_states.dtype
|
132 |
+
hidden_states = hidden_states.to(torch.float32)
|
133 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
134 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
135 |
+
return (self.weight * hidden_states).to(input_dtype) # main diff with Llama
|
136 |
+
|
137 |
+
def extra_repr(self):
|
138 |
+
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
139 |
+
|
140 |
+
#
|
141 |
+
# class GptOssExperts(nn.Module):
|
142 |
+
# def __init__(self, config):
|
143 |
+
# super().__init__()
|
144 |
+
# self.intermediate_size = config.intermediate_size
|
145 |
+
# self.num_experts = config.num_local_experts
|
146 |
+
# self.hidden_size = config.hidden_size
|
147 |
+
# self.expert_dim = self.intermediate_size
|
148 |
+
# self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))
|
149 |
+
# self.gate_up_proj_bias = nn.Parameter(torch.empty(self.num_experts, 2 * self.expert_dim))
|
150 |
+
# self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size)))
|
151 |
+
# self.down_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.hidden_size))
|
152 |
+
# self.alpha = 1.702
|
153 |
+
# self.limit = 7.0
|
154 |
+
#
|
155 |
+
#
|
156 |
+
#
|
157 |
+
# def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
|
158 |
+
# """
|
159 |
+
# When training is is more efficient to just loop over the experts and compute the output for each expert
|
160 |
+
# as otherwise the memory would explode.
|
161 |
+
#
|
162 |
+
# For inference we can sacrifice some memory and compute the output for all experts at once. By repeating the inputs.
|
163 |
+
#
|
164 |
+
# Args:
|
165 |
+
# hidden_states (torch.Tensor): (batch_size, seq_len, hidden_size)
|
166 |
+
# selected_experts (torch.Tensor): (batch_size * token_num, top_k)
|
167 |
+
# routing_weights (torch.Tensor): (batch_size * token_num, num_experts)
|
168 |
+
# Returns:
|
169 |
+
# torch.Tensor
|
170 |
+
# """
|
171 |
+
# batch_size = hidden_states.shape[0]
|
172 |
+
# hidden_states = hidden_states.reshape(-1, self.hidden_size) # (num_tokens, hidden_size)
|
173 |
+
# num_experts = routing_weights.shape[1]
|
174 |
+
# if self.training:
|
175 |
+
# next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
|
176 |
+
# with torch.no_grad():
|
177 |
+
# expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts)
|
178 |
+
# expert_mask = expert_mask.permute(2, 1, 0)
|
179 |
+
# # we sum on the top_k and on the sequence lenght to get which experts
|
180 |
+
# # are hit this time around
|
181 |
+
# expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
|
182 |
+
# for expert_idx in expert_hitted[:]:
|
183 |
+
# with torch.no_grad():
|
184 |
+
# _, token_idx = torch.where(expert_mask[expert_idx[0]])
|
185 |
+
# current_state = hidden_states[token_idx]
|
186 |
+
# gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx]
|
187 |
+
# gate, up = gate_up[..., ::2], gate_up[..., 1::2]
|
188 |
+
# gate = gate.clamp(min=None, max=self.limit)
|
189 |
+
# up = up.clamp(min=-self.limit, max=self.limit)
|
190 |
+
# glu = gate * torch.sigmoid(gate * self.alpha)
|
191 |
+
# gated_output = (up + 1) * glu
|
192 |
+
# out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx]
|
193 |
+
# weighted_output = out[0] * routing_weights[token_idx, expert_idx, None]
|
194 |
+
# next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
|
195 |
+
# next_states = next_states.view(batch_size, -1, self.hidden_size)
|
196 |
+
# else:
|
197 |
+
# hidden_states = hidden_states.repeat(num_experts, 1)
|
198 |
+
# hidden_states = hidden_states.view(num_experts, -1, self.hidden_size)
|
199 |
+
# gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :]
|
200 |
+
# gate, up = gate_up[..., ::2], gate_up[..., 1::2]
|
201 |
+
# gate = gate.clamp(min=None, max=self.limit)
|
202 |
+
# up = up.clamp(min=-self.limit, max=self.limit)
|
203 |
+
# glu = gate * torch.sigmoid(gate * self.alpha)
|
204 |
+
# next_states = torch.bmm(((up + 1) * glu), self.down_proj)
|
205 |
+
# next_states = next_states + self.down_proj_bias[..., None, :]
|
206 |
+
# next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size)
|
207 |
+
# next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None]
|
208 |
+
# next_states = next_states.sum(dim=0)
|
209 |
+
# return next_states
|
210 |
+
|
211 |
+
class GptOssExperts(nn.Module):
|
212 |
+
def __init__(self, config):
|
213 |
+
super().__init__()
|
214 |
+
self.intermediate_size = config.intermediate_size
|
215 |
+
self.num_experts = config.num_local_experts
|
216 |
+
self.hidden_size = config.hidden_size
|
217 |
+
self.expert_dim = self.intermediate_size
|
218 |
+
|
219 |
+
# 使用nn.Linear替代手动矩阵乘法
|
220 |
+
self.gate_up_projs = nn.ModuleList([
|
221 |
+
nn.Linear(self.hidden_size, 2 * self.expert_dim)
|
222 |
+
for _ in range(self.num_experts)
|
223 |
+
])
|
224 |
+
|
225 |
+
self.down_projs = nn.ModuleList([
|
226 |
+
nn.Linear(self.expert_dim, self.hidden_size)
|
227 |
+
for _ in range(self.num_experts)
|
228 |
+
])
|
229 |
+
|
230 |
+
self.alpha = 1.702
|
231 |
+
self.limit = 7.0
|
232 |
+
|
233 |
+
def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
|
234 |
+
batch_size = hidden_states.shape[0]
|
235 |
+
hidden_states = hidden_states.reshape(-1, self.hidden_size) # (num_tokens, hidden_size)
|
236 |
+
num_experts = routing_weights.shape[1]
|
237 |
+
|
238 |
+
if self.training:
|
239 |
+
next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
|
240 |
+
|
241 |
+
with torch.no_grad():
|
242 |
+
expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts)
|
243 |
+
expert_mask = expert_mask.permute(2, 1, 0)
|
244 |
+
expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
|
245 |
+
|
246 |
+
for expert_idx in expert_hitted[:]:
|
247 |
+
with torch.no_grad():
|
248 |
+
_, token_idx = torch.where(expert_mask[expert_idx[0]])
|
249 |
+
|
250 |
+
current_state = hidden_states[token_idx]
|
251 |
+
|
252 |
+
# 使用Linear层替代手动矩阵乘法
|
253 |
+
gate_up = self.gate_up_projs[expert_idx](current_state)
|
254 |
+
gate, up = gate_up[..., ::2], gate_up[..., 1::2]
|
255 |
+
gate = gate.clamp(min=None, max=self.limit)
|
256 |
+
up = up.clamp(min=-self.limit, max=self.limit)
|
257 |
+
|
258 |
+
glu = gate * torch.sigmoid(gate * self.alpha)
|
259 |
+
gated_output = (up + 1) * glu
|
260 |
+
|
261 |
+
# 使用Linear层替代手动矩阵乘法
|
262 |
+
out = self.down_projs[expert_idx](gated_output)
|
263 |
+
|
264 |
+
weighted_output = out[0] * routing_weights[token_idx, expert_idx, None]
|
265 |
+
next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
|
266 |
+
|
267 |
+
next_states = next_states.view(batch_size, -1, self.hidden_size)
|
268 |
+
else:
|
269 |
+
hidden_states = hidden_states.repeat(num_experts, 1)
|
270 |
+
hidden_states = hidden_states.view(num_experts, -1, self.hidden_size)
|
271 |
+
|
272 |
+
# 批量处理所有专家
|
273 |
+
gate_up = torch.stack([proj(hidden_states[i]) for i, proj in enumerate(self.gate_up_projs)])
|
274 |
+
gate, up = gate_up[..., ::2], gate_up[..., 1::2]
|
275 |
+
gate = gate.clamp(min=None, max=self.limit)
|
276 |
+
up = up.clamp(min=-self.limit, max=self.limit)
|
277 |
+
|
278 |
+
glu = gate * torch.sigmoid(gate * self.alpha)
|
279 |
+
next_states = torch.stack([proj((up[i] + 1) * glu[i]) for i, proj in enumerate(self.down_projs)])
|
280 |
+
|
281 |
+
next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size)
|
282 |
+
next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None]
|
283 |
+
next_states = next_states.sum(dim=0)
|
284 |
+
|
285 |
+
return next_states
|
286 |
+
|
287 |
+
# class GptOssTopKRouter(nn.Module):
|
288 |
+
# def __init__(self, config):
|
289 |
+
# super().__init__()
|
290 |
+
# self.top_k = config.num_experts_per_tok
|
291 |
+
# self.num_experts = config.num_local_experts
|
292 |
+
# self.hidden_dim = config.hidden_size
|
293 |
+
# self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim))
|
294 |
+
# self.bias = nn.Parameter(torch.empty(self.num_experts))
|
295 |
+
#
|
296 |
+
# def forward(self, hidden_states):
|
297 |
+
# hidden_states = hidden_states.reshape(-1, self.hidden_dim)
|
298 |
+
# router_logits = F.linear(hidden_states, self.weight, self.bias) # (seq_len, num_experts)
|
299 |
+
# router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k)
|
300 |
+
# router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
|
301 |
+
# router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
|
302 |
+
# return router_scores, router_indices
|
303 |
+
|
304 |
+
|
305 |
+
class GptOssTopKRouter(nn.Module):
|
306 |
+
def __init__(self, config):
|
307 |
+
super().__init__()
|
308 |
+
self.top_k = config.num_experts_per_tok
|
309 |
+
self.num_experts = config.num_local_experts
|
310 |
+
self.hidden_dim = config.hidden_size
|
311 |
+
|
312 |
+
# 使用nn.Linear替代手动参数
|
313 |
+
self.router = nn.Linear(self.hidden_dim, self.num_experts)
|
314 |
+
|
315 |
+
def forward(self, hidden_states):
|
316 |
+
# 展平输入 (batch_size * seq_len, hidden_dim)
|
317 |
+
hidden_states = hidden_states.reshape(-1, self.hidden_dim)
|
318 |
+
router_logits = self.router(hidden_states) # (num_tokens, num_experts)
|
319 |
+
|
320 |
+
router_top_value, router_indices = torch.topk(
|
321 |
+
router_logits,
|
322 |
+
self.top_k,
|
323 |
+
dim=-1
|
324 |
+
) # (num_tokens, top_k)
|
325 |
+
|
326 |
+
router_top_value = F.softmax(router_top_value, dim=-1, dtype=router_top_value.dtype)
|
327 |
+
|
328 |
+
router_scores = torch.zeros_like(router_logits).scatter_(
|
329 |
+
dim=1,
|
330 |
+
index=router_indices,
|
331 |
+
src=router_top_value
|
332 |
+
)
|
333 |
+
|
334 |
+
return router_scores, router_indices
|
335 |
+
|
336 |
+
|
337 |
+
|
338 |
+
|
339 |
+
@use_kernel_forward_from_hub("MegaBlocksMoeMLP")
|
340 |
+
class GptOssMLP(nn.Module):
|
341 |
+
def __init__(self, config):
|
342 |
+
super().__init__()
|
343 |
+
self.router = GptOssTopKRouter(config)
|
344 |
+
self.experts = GptOssExperts(config)
|
345 |
+
|
346 |
+
def forward(self, hidden_states):
|
347 |
+
router_scores, router_indices = self.router(hidden_states) # (num_experts, seq_len)
|
348 |
+
routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores)
|
349 |
+
return routed_out, router_scores
|
350 |
+
|
351 |
+
|
352 |
+
class GptOssRotaryEmbedding(nn.Module):
|
353 |
+
def __init__(self, config: GptOssConfig, device=None):
|
354 |
+
super().__init__()
|
355 |
+
# BC: "rope_type" was originally "type"
|
356 |
+
if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
|
357 |
+
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
358 |
+
else:
|
359 |
+
self.rope_type = "default"
|
360 |
+
self.max_seq_len_cached = config.max_position_embeddings
|
361 |
+
self.original_max_seq_len = config.max_position_embeddings
|
362 |
+
|
363 |
+
self.config = config
|
364 |
+
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
365 |
+
|
366 |
+
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
367 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
368 |
+
self.original_inv_freq = self.inv_freq
|
369 |
+
|
370 |
+
@torch.no_grad()
|
371 |
+
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
|
372 |
+
def forward(self, x, position_ids):
|
373 |
+
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
|
374 |
+
position_ids_expanded = position_ids[:, None, :].float()
|
375 |
+
|
376 |
+
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
377 |
+
with torch.autocast(device_type=device_type, enabled=False): # Force float32
|
378 |
+
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
379 |
+
emb = freqs
|
380 |
+
cos = emb.cos() * self.attention_scaling
|
381 |
+
sin = emb.sin() * self.attention_scaling
|
382 |
+
|
383 |
+
return cos.to(x.dtype), sin.to(x.dtype)
|
384 |
+
|
385 |
+
|
386 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
387 |
+
"""
|
388 |
+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
389 |
+
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
390 |
+
"""
|
391 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
392 |
+
if n_rep == 1:
|
393 |
+
return hidden_states
|
394 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
395 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
396 |
+
|
397 |
+
|
398 |
+
def _apply_rotary_emb(
|
399 |
+
x: torch.Tensor,
|
400 |
+
cos: torch.Tensor,
|
401 |
+
sin: torch.Tensor,
|
402 |
+
) -> torch.Tensor:
|
403 |
+
first_half, second_half = torch.chunk(x, 2, dim=-1)
|
404 |
+
first_ = first_half * cos - second_half * sin
|
405 |
+
second_ = second_half * cos + first_half * sin
|
406 |
+
return torch.cat((first_, second_), dim=-1)
|
407 |
+
|
408 |
+
|
409 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
410 |
+
cos = cos.unsqueeze(unsqueeze_dim)
|
411 |
+
sin = sin.unsqueeze(unsqueeze_dim)
|
412 |
+
q_embed = _apply_rotary_emb(q, cos, sin)
|
413 |
+
k_embed = _apply_rotary_emb(k, cos, sin)
|
414 |
+
return q_embed, k_embed
|
415 |
+
|
416 |
+
|
417 |
+
def eager_attention_forward(
|
418 |
+
module: nn.Module,
|
419 |
+
query: torch.Tensor,
|
420 |
+
key: torch.Tensor,
|
421 |
+
value: torch.Tensor,
|
422 |
+
attention_mask: Optional[torch.Tensor],
|
423 |
+
scaling: float,
|
424 |
+
dropout: float = 0.0,
|
425 |
+
**kwargs,
|
426 |
+
):
|
427 |
+
key_states = repeat_kv(key, module.num_key_value_groups)
|
428 |
+
value_states = repeat_kv(value, module.num_key_value_groups)
|
429 |
+
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
430 |
+
if attention_mask is not None:
|
431 |
+
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
432 |
+
attn_weights = attn_weights + causal_mask
|
433 |
+
|
434 |
+
sinks = module.sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1)
|
435 |
+
combined_logits = torch.cat([attn_weights, sinks], dim=-1)
|
436 |
+
|
437 |
+
# This was not in the original implementation and slightly affect results; it prevents overflow in BF16/FP16
|
438 |
+
# when training with bsz>1 we clamp max values.
|
439 |
+
|
440 |
+
combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values
|
441 |
+
probs = F.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype)
|
442 |
+
scores = probs[..., :-1] # we drop the sink here
|
443 |
+
attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training)
|
444 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
445 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
446 |
+
return attn_output, attn_weights
|
447 |
+
|
448 |
+
|
449 |
+
class GptOssAttention(nn.Module):
|
450 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
451 |
+
|
452 |
+
def __init__(self, config: GptOssConfig, layer_idx: int):
|
453 |
+
super().__init__()
|
454 |
+
self.config = config
|
455 |
+
self.layer_idx = layer_idx
|
456 |
+
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
457 |
+
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
|
458 |
+
self.scaling = self.head_dim**-0.5
|
459 |
+
self.attention_dropout = config.attention_dropout
|
460 |
+
self.is_causal = True
|
461 |
+
self.q_proj = nn.Linear(
|
462 |
+
config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
|
463 |
+
)
|
464 |
+
self.k_proj = nn.Linear(
|
465 |
+
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
|
466 |
+
)
|
467 |
+
self.v_proj = nn.Linear(
|
468 |
+
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
|
469 |
+
)
|
470 |
+
self.o_proj = nn.Linear(
|
471 |
+
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
|
472 |
+
)
|
473 |
+
self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
|
474 |
+
self.sinks = nn.Parameter(torch.empty(config.num_attention_heads))
|
475 |
+
|
476 |
+
def forward(
|
477 |
+
self,
|
478 |
+
hidden_states: torch.Tensor,
|
479 |
+
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
480 |
+
attention_mask: Optional[torch.Tensor],
|
481 |
+
past_key_value: Optional[Cache] = None,
|
482 |
+
cache_position: Optional[torch.LongTensor] = None,
|
483 |
+
**kwargs: Unpack[TransformersKwargs],
|
484 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
485 |
+
input_shape = hidden_states.shape[:-1]
|
486 |
+
hidden_shape = (*input_shape, -1, self.head_dim)
|
487 |
+
|
488 |
+
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
489 |
+
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
490 |
+
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
491 |
+
|
492 |
+
cos, sin = position_embeddings
|
493 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
494 |
+
|
495 |
+
if past_key_value is not None:
|
496 |
+
cache_kwargs = {"cache_position": cache_position}
|
497 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
498 |
+
|
499 |
+
attention_interface: Callable = eager_attention_forward
|
500 |
+
if self.config._attn_implementation != "eager":
|
501 |
+
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
502 |
+
|
503 |
+
attn_output, attn_weights = attention_interface(
|
504 |
+
self,
|
505 |
+
query_states,
|
506 |
+
key_states,
|
507 |
+
value_states,
|
508 |
+
attention_mask,
|
509 |
+
dropout=0.0 if not self.training else self.attention_dropout,
|
510 |
+
scaling=self.scaling,
|
511 |
+
sliding_window=self.sliding_window,
|
512 |
+
s_aux=self.sinks, # diff with Llama
|
513 |
+
**kwargs,
|
514 |
+
)
|
515 |
+
|
516 |
+
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
517 |
+
attn_output = self.o_proj(attn_output)
|
518 |
+
return attn_output, attn_weights
|
519 |
+
|
520 |
+
|
521 |
+
class GptOssDecoderLayer(GradientCheckpointingLayer):
|
522 |
+
def __init__(self, config: GptOssConfig, layer_idx: int):
|
523 |
+
super().__init__()
|
524 |
+
self.hidden_size = config.hidden_size
|
525 |
+
self.self_attn = GptOssAttention(config=config, layer_idx=layer_idx)
|
526 |
+
self.mlp = GptOssMLP(config)
|
527 |
+
self.input_layernorm = GptOssRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
528 |
+
self.post_attention_layernorm = GptOssRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
529 |
+
self.attention_type = config.layer_types[layer_idx]
|
530 |
+
|
531 |
+
def forward(
|
532 |
+
self,
|
533 |
+
hidden_states: torch.Tensor,
|
534 |
+
attention_mask: Optional[torch.Tensor] = None,
|
535 |
+
position_ids: Optional[torch.LongTensor] = None,
|
536 |
+
past_key_value: Optional[Cache] = None,
|
537 |
+
use_cache: Optional[bool] = False,
|
538 |
+
cache_position: Optional[torch.LongTensor] = None,
|
539 |
+
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
540 |
+
**kwargs: Unpack[TransformersKwargs],
|
541 |
+
) -> tuple[torch.Tensor]:
|
542 |
+
residual = hidden_states
|
543 |
+
hidden_states = self.input_layernorm(hidden_states)
|
544 |
+
# Self Attention
|
545 |
+
hidden_states, _ = self.self_attn(
|
546 |
+
hidden_states=hidden_states,
|
547 |
+
attention_mask=attention_mask,
|
548 |
+
position_ids=position_ids,
|
549 |
+
past_key_value=past_key_value,
|
550 |
+
use_cache=use_cache,
|
551 |
+
cache_position=cache_position,
|
552 |
+
position_embeddings=position_embeddings,
|
553 |
+
**kwargs,
|
554 |
+
)
|
555 |
+
hidden_states = residual + hidden_states
|
556 |
+
|
557 |
+
# Fully Connected
|
558 |
+
residual = hidden_states
|
559 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
560 |
+
hidden_states, _ = self.mlp(hidden_states) # diff with llama: router scores
|
561 |
+
hidden_states = residual + hidden_states
|
562 |
+
return hidden_states
|
563 |
+
|
564 |
+
|
565 |
+
@auto_docstring
|
566 |
+
class GptOssPreTrainedModel(PreTrainedModel):
|
567 |
+
config: GptOssConfig
|
568 |
+
base_model_prefix = "model"
|
569 |
+
supports_gradient_checkpointing = True
|
570 |
+
_no_split_modules = ["GptOssDecoderLayer"]
|
571 |
+
_skip_keys_device_placement = ["past_key_values"]
|
572 |
+
_supports_flash_attn = True
|
573 |
+
_supports_sdpa = False
|
574 |
+
_supports_flex_attn = True
|
575 |
+
|
576 |
+
_can_compile_fullgraph = True
|
577 |
+
_supports_attention_backend = True
|
578 |
+
_can_record_outputs = {
|
579 |
+
"router_logits": OutputRecorder(GptOssTopKRouter, index=0),
|
580 |
+
"hidden_states": GptOssDecoderLayer,
|
581 |
+
"attentions": GptOssAttention,
|
582 |
+
}
|
583 |
+
_keep_in_fp32_modules = ["post_attention_layernorm", "input_layernorm", "norm"]
|
584 |
+
_supports_flash_attention = False
|
585 |
+
_supports_flex_attention = False
|
586 |
+
|
587 |
+
def _init_weights(self, module):
|
588 |
+
std = self.config.initializer_range
|
589 |
+
if isinstance(module, nn.Linear):
|
590 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
591 |
+
if module.bias is not None:
|
592 |
+
module.bias.data.zero_()
|
593 |
+
elif isinstance(module, nn.Parameter):
|
594 |
+
module.data.normal_(mean=0.0, std=std)
|
595 |
+
elif isinstance(module, nn.Embedding):
|
596 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
597 |
+
if module.padding_idx is not None:
|
598 |
+
module.weight.data[module.padding_idx].zero_()
|
599 |
+
elif isinstance(module, GptOssRMSNorm):
|
600 |
+
module.weight.data.fill_(1.0)
|
601 |
+
# elif isinstance(module, GptOssExperts):##too slow
|
602 |
+
# for gate_up_proj in module.gate_up_projs:
|
603 |
+
# gate_up_proj.weight.normal_(mean=0.0, std=std)
|
604 |
+
# gate_up_proj.bias.data.zero_()
|
605 |
+
# for down_proj in module.down_projs:
|
606 |
+
# down_proj.weight.data.normal_(mean=0.0, std=std)
|
607 |
+
# down_proj.bias.data.zero_()
|
608 |
+
elif isinstance(module, GptOssAttention):
|
609 |
+
module.sinks.data.normal_(mean=0.0, std=std)
|
610 |
+
# elif isinstance(module, GptOssTopKRouter):
|
611 |
+
# module.weight.data.normal_(mean=0.0, std=std)
|
612 |
+
# module.bias.data.normal_(mean=0.0, std=std)
|
613 |
+
|
614 |
+
|
615 |
+
@auto_docstring
|
616 |
+
class GptOssModel(GptOssPreTrainedModel):
|
617 |
+
_no_split_modules = ["GptOssDecoderLayer"]
|
618 |
+
|
619 |
+
def __init__(self, config: GptOssConfig):
|
620 |
+
super().__init__(config)
|
621 |
+
self.padding_idx = config.pad_token_id
|
622 |
+
self.vocab_size = config.vocab_size
|
623 |
+
|
624 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
625 |
+
self.layers = nn.ModuleList(
|
626 |
+
[GptOssDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
627 |
+
)
|
628 |
+
self.norm = GptOssRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
629 |
+
self.rotary_emb = GptOssRotaryEmbedding(config=config)
|
630 |
+
self.gradient_checkpointing = False
|
631 |
+
|
632 |
+
# Initialize weights and apply final processing
|
633 |
+
self.post_init()
|
634 |
+
|
635 |
+
@check_model_inputs
|
636 |
+
@auto_docstring
|
637 |
+
def forward(
|
638 |
+
self,
|
639 |
+
input_ids: Optional[torch.LongTensor] = None,
|
640 |
+
attention_mask: Optional[torch.Tensor] = None,
|
641 |
+
position_ids: Optional[torch.LongTensor] = None,
|
642 |
+
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
643 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
644 |
+
use_cache: Optional[bool] = None,
|
645 |
+
cache_position: Optional[torch.LongTensor] = None,
|
646 |
+
**kwargs: Unpack[TransformersKwargs],
|
647 |
+
) -> MoeModelOutputWithPast:
|
648 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
649 |
+
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
650 |
+
|
651 |
+
if use_cache and past_key_values is None:
|
652 |
+
past_key_values = DynamicCache()
|
653 |
+
|
654 |
+
if inputs_embeds is None:
|
655 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
656 |
+
|
657 |
+
if cache_position is None:
|
658 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
659 |
+
cache_position = torch.arange(
|
660 |
+
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
661 |
+
)
|
662 |
+
if position_ids is None:
|
663 |
+
position_ids = cache_position.unsqueeze(0)
|
664 |
+
|
665 |
+
# It may already have been prepared by e.g. `generate`
|
666 |
+
if not isinstance(causal_mask_mapping := attention_mask, dict):
|
667 |
+
mask_kwargs = {
|
668 |
+
"config": self.config,
|
669 |
+
"input_embeds": inputs_embeds,
|
670 |
+
"attention_mask": attention_mask,
|
671 |
+
"cache_position": cache_position,
|
672 |
+
"past_key_values": past_key_values,
|
673 |
+
}
|
674 |
+
causal_mask_mapping = {
|
675 |
+
"full_attention": create_causal_mask(**mask_kwargs),
|
676 |
+
"sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
|
677 |
+
}
|
678 |
+
|
679 |
+
hidden_states = inputs_embeds
|
680 |
+
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
681 |
+
|
682 |
+
for decoder_layer in self.layers:
|
683 |
+
hidden_states = decoder_layer(
|
684 |
+
hidden_states,
|
685 |
+
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
|
686 |
+
position_ids=position_ids,
|
687 |
+
past_key_value=past_key_values,
|
688 |
+
use_cache=use_cache,
|
689 |
+
cache_position=cache_position,
|
690 |
+
position_embeddings=position_embeddings,
|
691 |
+
**kwargs,
|
692 |
+
)
|
693 |
+
hidden_states = self.norm(hidden_states)
|
694 |
+
return MoeModelOutputWithPast(
|
695 |
+
last_hidden_state=hidden_states,
|
696 |
+
past_key_values=past_key_values,
|
697 |
+
)
|
698 |
+
|
699 |
+
|
700 |
+
def load_balancing_loss_func(
|
701 |
+
gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None],
|
702 |
+
num_experts: Optional[int] = None,
|
703 |
+
top_k=2,
|
704 |
+
attention_mask: Optional[torch.Tensor] = None,
|
705 |
+
) -> Union[torch.Tensor, int]:
|
706 |
+
r"""
|
707 |
+
Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
|
708 |
+
|
709 |
+
See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss
|
710 |
+
function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
|
711 |
+
experts is too unbalanced.
|
712 |
+
|
713 |
+
Args:
|
714 |
+
gate_logits:
|
715 |
+
Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
|
716 |
+
shape [batch_size X sequence_length, num_experts].
|
717 |
+
num_experts:
|
718 |
+
Number of experts
|
719 |
+
top_k:
|
720 |
+
The number of experts to route per-token, can be also interpreted as the `top-k` routing
|
721 |
+
parameter.
|
722 |
+
attention_mask (`torch.Tensor`, *optional*):
|
723 |
+
The attention_mask used in forward function
|
724 |
+
shape [batch_size X sequence_length] if not None.
|
725 |
+
|
726 |
+
Returns:
|
727 |
+
The auxiliary loss.
|
728 |
+
"""
|
729 |
+
if gate_logits is None or not isinstance(gate_logits, tuple):
|
730 |
+
return 0
|
731 |
+
|
732 |
+
if isinstance(gate_logits, tuple):
|
733 |
+
compute_device = gate_logits[0].device
|
734 |
+
concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
|
735 |
+
|
736 |
+
routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
|
737 |
+
|
738 |
+
_, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
|
739 |
+
|
740 |
+
expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
|
741 |
+
|
742 |
+
if attention_mask is None:
|
743 |
+
# Compute the percentage of tokens routed to each experts
|
744 |
+
tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
|
745 |
+
|
746 |
+
# Compute the average probability of routing to these experts
|
747 |
+
router_prob_per_expert = torch.mean(routing_weights, dim=0)
|
748 |
+
else:
|
749 |
+
batch_size, sequence_length = attention_mask.shape
|
750 |
+
num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
|
751 |
+
|
752 |
+
# Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
|
753 |
+
expert_attention_mask = (
|
754 |
+
attention_mask[None, :, :, None, None]
|
755 |
+
.expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
|
756 |
+
.reshape(-1, top_k, num_experts)
|
757 |
+
.to(compute_device)
|
758 |
+
)
|
759 |
+
|
760 |
+
# Compute the percentage of tokens routed to each experts
|
761 |
+
tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
|
762 |
+
expert_attention_mask, dim=0
|
763 |
+
)
|
764 |
+
|
765 |
+
# Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
|
766 |
+
router_per_expert_attention_mask = (
|
767 |
+
attention_mask[None, :, :, None]
|
768 |
+
.expand((num_hidden_layers, batch_size, sequence_length, num_experts))
|
769 |
+
.reshape(-1, num_experts)
|
770 |
+
.to(compute_device)
|
771 |
+
)
|
772 |
+
|
773 |
+
# Compute the average probability of routing to these experts
|
774 |
+
router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
|
775 |
+
router_per_expert_attention_mask, dim=0
|
776 |
+
)
|
777 |
+
|
778 |
+
overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
|
779 |
+
return overall_loss * num_experts
|
780 |
+
|
781 |
+
|
782 |
+
@auto_docstring
|
783 |
+
class GptOssForCausalLM(GptOssPreTrainedModel, GenerationMixin):
|
784 |
+
_tied_weights_keys = ["lm_head.weight"]
|
785 |
+
_tp_plan = {"lm_head": "colwise_rep"}
|
786 |
+
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
787 |
+
|
788 |
+
def __init__(self, config):
|
789 |
+
super().__init__(config)
|
790 |
+
self.model = GptOssModel(config)
|
791 |
+
self.vocab_size = config.vocab_size
|
792 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
793 |
+
self.router_aux_loss_coef = config.router_aux_loss_coef
|
794 |
+
self.num_experts = config.num_local_experts
|
795 |
+
self.num_experts_per_tok = config.num_experts_per_tok
|
796 |
+
|
797 |
+
# Initialize weights and apply final processing
|
798 |
+
self.post_init()
|
799 |
+
|
800 |
+
def set_decoder(self, decoder):
|
801 |
+
self.model = decoder
|
802 |
+
|
803 |
+
def get_decoder(self):
|
804 |
+
return self.model
|
805 |
+
|
806 |
+
@can_return_tuple
|
807 |
+
@auto_docstring
|
808 |
+
def forward(
|
809 |
+
self,
|
810 |
+
input_ids: Optional[torch.LongTensor] = None,
|
811 |
+
attention_mask: Optional[torch.Tensor] = None,
|
812 |
+
position_ids: Optional[torch.LongTensor] = None,
|
813 |
+
past_key_values: Optional[Cache] = None,
|
814 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
815 |
+
labels: Optional[torch.LongTensor] = None,
|
816 |
+
use_cache: Optional[bool] = None,
|
817 |
+
output_router_logits: Optional[bool] = None,
|
818 |
+
cache_position: Optional[torch.LongTensor] = None,
|
819 |
+
logits_to_keep: Union[int, torch.Tensor] = 0,
|
820 |
+
**kwargs: Unpack[TransformersKwargs],
|
821 |
+
) -> MoeCausalLMOutputWithPast:
|
822 |
+
r"""
|
823 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
824 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
825 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
826 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
827 |
+
|
828 |
+
Example:
|
829 |
+
|
830 |
+
```python
|
831 |
+
>>> from transformers import AutoTokenizer, GptOssForCausalLM
|
832 |
+
|
833 |
+
>>> model = GptOssForCausalLM.from_pretrained("mistralai/GptOss-8x7B-v0.1")
|
834 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("mistralai/GptOss-8x7B-v0.1")
|
835 |
+
|
836 |
+
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
837 |
+
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
838 |
+
|
839 |
+
>>> # Generate
|
840 |
+
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
841 |
+
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
842 |
+
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
843 |
+
```"""
|
844 |
+
|
845 |
+
output_router_logits = (
|
846 |
+
output_router_logits if output_router_logits is not None else self.config.output_router_logits
|
847 |
+
)
|
848 |
+
|
849 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
850 |
+
outputs: MoeModelOutputWithPast = self.model(
|
851 |
+
input_ids=input_ids,
|
852 |
+
attention_mask=attention_mask,
|
853 |
+
position_ids=position_ids,
|
854 |
+
past_key_values=past_key_values,
|
855 |
+
inputs_embeds=inputs_embeds,
|
856 |
+
use_cache=use_cache,
|
857 |
+
output_router_logits=output_router_logits,
|
858 |
+
cache_position=cache_position,
|
859 |
+
**kwargs,
|
860 |
+
)
|
861 |
+
|
862 |
+
hidden_states = outputs.last_hidden_state
|
863 |
+
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
864 |
+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
865 |
+
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
866 |
+
|
867 |
+
loss = None
|
868 |
+
if labels is not None:
|
869 |
+
loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
|
870 |
+
|
871 |
+
aux_loss = None
|
872 |
+
if output_router_logits:
|
873 |
+
aux_loss = load_balancing_loss_func(
|
874 |
+
outputs.router_logits,
|
875 |
+
self.num_experts,
|
876 |
+
self.num_experts_per_tok,
|
877 |
+
attention_mask,
|
878 |
+
)
|
879 |
+
if labels is not None:
|
880 |
+
loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device
|
881 |
+
|
882 |
+
return MoeCausalLMOutputWithPast(
|
883 |
+
loss=loss,
|
884 |
+
aux_loss=aux_loss,
|
885 |
+
logits=logits,
|
886 |
+
past_key_values=outputs.past_key_values,
|
887 |
+
hidden_states=outputs.hidden_states,
|
888 |
+
attentions=outputs.attentions,
|
889 |
+
router_logits=outputs.router_logits,
|
890 |
+
)
|
891 |
+
|
892 |
+
|
893 |
+
__all__ = ["GptOssForCausalLM", "GptOssModel", "GptOssPreTrainedModel"]
|
quantization_config.json
ADDED
@@ -0,0 +1,372 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"bits": 4,
|
3 |
+
"group_size": 128,
|
4 |
+
"sym": true,
|
5 |
+
"data_type": "int",
|
6 |
+
"nsamples": 512,
|
7 |
+
"autoround_version": "0.6.1.dev",
|
8 |
+
"quant_method": "auto-round",
|
9 |
+
"packing_format": "auto_round:auto_gptq",
|
10 |
+
"extra_config": {
|
11 |
+
"model.layers.0.self_attn.q_proj": {
|
12 |
+
"bits": 16
|
13 |
+
},
|
14 |
+
"model.layers.0.self_attn.k_proj": {
|
15 |
+
"bits": 16
|
16 |
+
},
|
17 |
+
"model.layers.0.self_attn.v_proj": {
|
18 |
+
"bits": 16
|
19 |
+
},
|
20 |
+
"model.layers.0.self_attn.o_proj": {
|
21 |
+
"bits": 16
|
22 |
+
},
|
23 |
+
"model.layers.0.mlp.router.router": {
|
24 |
+
"bits": 16
|
25 |
+
},
|
26 |
+
"model.layers.1.self_attn.q_proj": {
|
27 |
+
"bits": 16
|
28 |
+
},
|
29 |
+
"model.layers.1.self_attn.k_proj": {
|
30 |
+
"bits": 16
|
31 |
+
},
|
32 |
+
"model.layers.1.self_attn.v_proj": {
|
33 |
+
"bits": 16
|
34 |
+
},
|
35 |
+
"model.layers.1.self_attn.o_proj": {
|
36 |
+
"bits": 16
|
37 |
+
},
|
38 |
+
"model.layers.1.mlp.router.router": {
|
39 |
+
"bits": 16
|
40 |
+
},
|
41 |
+
"model.layers.2.self_attn.q_proj": {
|
42 |
+
"bits": 16
|
43 |
+
},
|
44 |
+
"model.layers.2.self_attn.k_proj": {
|
45 |
+
"bits": 16
|
46 |
+
},
|
47 |
+
"model.layers.2.self_attn.v_proj": {
|
48 |
+
"bits": 16
|
49 |
+
},
|
50 |
+
"model.layers.2.self_attn.o_proj": {
|
51 |
+
"bits": 16
|
52 |
+
},
|
53 |
+
"model.layers.2.mlp.router.router": {
|
54 |
+
"bits": 16
|
55 |
+
},
|
56 |
+
"model.layers.3.self_attn.q_proj": {
|
57 |
+
"bits": 16
|
58 |
+
},
|
59 |
+
"model.layers.3.self_attn.k_proj": {
|
60 |
+
"bits": 16
|
61 |
+
},
|
62 |
+
"model.layers.3.self_attn.v_proj": {
|
63 |
+
"bits": 16
|
64 |
+
},
|
65 |
+
"model.layers.3.self_attn.o_proj": {
|
66 |
+
"bits": 16
|
67 |
+
},
|
68 |
+
"model.layers.3.mlp.router.router": {
|
69 |
+
"bits": 16
|
70 |
+
},
|
71 |
+
"model.layers.4.self_attn.q_proj": {
|
72 |
+
"bits": 16
|
73 |
+
},
|
74 |
+
"model.layers.4.self_attn.k_proj": {
|
75 |
+
"bits": 16
|
76 |
+
},
|
77 |
+
"model.layers.4.self_attn.v_proj": {
|
78 |
+
"bits": 16
|
79 |
+
},
|
80 |
+
"model.layers.4.self_attn.o_proj": {
|
81 |
+
"bits": 16
|
82 |
+
},
|
83 |
+
"model.layers.4.mlp.router.router": {
|
84 |
+
"bits": 16
|
85 |
+
},
|
86 |
+
"model.layers.5.self_attn.q_proj": {
|
87 |
+
"bits": 16
|
88 |
+
},
|
89 |
+
"model.layers.5.self_attn.k_proj": {
|
90 |
+
"bits": 16
|
91 |
+
},
|
92 |
+
"model.layers.5.self_attn.v_proj": {
|
93 |
+
"bits": 16
|
94 |
+
},
|
95 |
+
"model.layers.5.self_attn.o_proj": {
|
96 |
+
"bits": 16
|
97 |
+
},
|
98 |
+
"model.layers.5.mlp.router.router": {
|
99 |
+
"bits": 16
|
100 |
+
},
|
101 |
+
"model.layers.6.self_attn.q_proj": {
|
102 |
+
"bits": 16
|
103 |
+
},
|
104 |
+
"model.layers.6.self_attn.k_proj": {
|
105 |
+
"bits": 16
|
106 |
+
},
|
107 |
+
"model.layers.6.self_attn.v_proj": {
|
108 |
+
"bits": 16
|
109 |
+
},
|
110 |
+
"model.layers.6.self_attn.o_proj": {
|
111 |
+
"bits": 16
|
112 |
+
},
|
113 |
+
"model.layers.6.mlp.router.router": {
|
114 |
+
"bits": 16
|
115 |
+
},
|
116 |
+
"model.layers.7.self_attn.q_proj": {
|
117 |
+
"bits": 16
|
118 |
+
},
|
119 |
+
"model.layers.7.self_attn.k_proj": {
|
120 |
+
"bits": 16
|
121 |
+
},
|
122 |
+
"model.layers.7.self_attn.v_proj": {
|
123 |
+
"bits": 16
|
124 |
+
},
|
125 |
+
"model.layers.7.self_attn.o_proj": {
|
126 |
+
"bits": 16
|
127 |
+
},
|
128 |
+
"model.layers.7.mlp.router.router": {
|
129 |
+
"bits": 16
|
130 |
+
},
|
131 |
+
"model.layers.8.self_attn.q_proj": {
|
132 |
+
"bits": 16
|
133 |
+
},
|
134 |
+
"model.layers.8.self_attn.k_proj": {
|
135 |
+
"bits": 16
|
136 |
+
},
|
137 |
+
"model.layers.8.self_attn.v_proj": {
|
138 |
+
"bits": 16
|
139 |
+
},
|
140 |
+
"model.layers.8.self_attn.o_proj": {
|
141 |
+
"bits": 16
|
142 |
+
},
|
143 |
+
"model.layers.8.mlp.router.router": {
|
144 |
+
"bits": 16
|
145 |
+
},
|
146 |
+
"model.layers.9.self_attn.q_proj": {
|
147 |
+
"bits": 16
|
148 |
+
},
|
149 |
+
"model.layers.9.self_attn.k_proj": {
|
150 |
+
"bits": 16
|
151 |
+
},
|
152 |
+
"model.layers.9.self_attn.v_proj": {
|
153 |
+
"bits": 16
|
154 |
+
},
|
155 |
+
"model.layers.9.self_attn.o_proj": {
|
156 |
+
"bits": 16
|
157 |
+
},
|
158 |
+
"model.layers.9.mlp.router.router": {
|
159 |
+
"bits": 16
|
160 |
+
},
|
161 |
+
"model.layers.10.self_attn.q_proj": {
|
162 |
+
"bits": 16
|
163 |
+
},
|
164 |
+
"model.layers.10.self_attn.k_proj": {
|
165 |
+
"bits": 16
|
166 |
+
},
|
167 |
+
"model.layers.10.self_attn.v_proj": {
|
168 |
+
"bits": 16
|
169 |
+
},
|
170 |
+
"model.layers.10.self_attn.o_proj": {
|
171 |
+
"bits": 16
|
172 |
+
},
|
173 |
+
"model.layers.10.mlp.router.router": {
|
174 |
+
"bits": 16
|
175 |
+
},
|
176 |
+
"model.layers.11.self_attn.q_proj": {
|
177 |
+
"bits": 16
|
178 |
+
},
|
179 |
+
"model.layers.11.self_attn.k_proj": {
|
180 |
+
"bits": 16
|
181 |
+
},
|
182 |
+
"model.layers.11.self_attn.v_proj": {
|
183 |
+
"bits": 16
|
184 |
+
},
|
185 |
+
"model.layers.11.self_attn.o_proj": {
|
186 |
+
"bits": 16
|
187 |
+
},
|
188 |
+
"model.layers.11.mlp.router.router": {
|
189 |
+
"bits": 16
|
190 |
+
},
|
191 |
+
"model.layers.12.self_attn.q_proj": {
|
192 |
+
"bits": 16
|
193 |
+
},
|
194 |
+
"model.layers.12.self_attn.k_proj": {
|
195 |
+
"bits": 16
|
196 |
+
},
|
197 |
+
"model.layers.12.self_attn.v_proj": {
|
198 |
+
"bits": 16
|
199 |
+
},
|
200 |
+
"model.layers.12.self_attn.o_proj": {
|
201 |
+
"bits": 16
|
202 |
+
},
|
203 |
+
"model.layers.12.mlp.router.router": {
|
204 |
+
"bits": 16
|
205 |
+
},
|
206 |
+
"model.layers.13.self_attn.q_proj": {
|
207 |
+
"bits": 16
|
208 |
+
},
|
209 |
+
"model.layers.13.self_attn.k_proj": {
|
210 |
+
"bits": 16
|
211 |
+
},
|
212 |
+
"model.layers.13.self_attn.v_proj": {
|
213 |
+
"bits": 16
|
214 |
+
},
|
215 |
+
"model.layers.13.self_attn.o_proj": {
|
216 |
+
"bits": 16
|
217 |
+
},
|
218 |
+
"model.layers.13.mlp.router.router": {
|
219 |
+
"bits": 16
|
220 |
+
},
|
221 |
+
"model.layers.14.self_attn.q_proj": {
|
222 |
+
"bits": 16
|
223 |
+
},
|
224 |
+
"model.layers.14.self_attn.k_proj": {
|
225 |
+
"bits": 16
|
226 |
+
},
|
227 |
+
"model.layers.14.self_attn.v_proj": {
|
228 |
+
"bits": 16
|
229 |
+
},
|
230 |
+
"model.layers.14.self_attn.o_proj": {
|
231 |
+
"bits": 16
|
232 |
+
},
|
233 |
+
"model.layers.14.mlp.router.router": {
|
234 |
+
"bits": 16
|
235 |
+
},
|
236 |
+
"model.layers.15.self_attn.q_proj": {
|
237 |
+
"bits": 16
|
238 |
+
},
|
239 |
+
"model.layers.15.self_attn.k_proj": {
|
240 |
+
"bits": 16
|
241 |
+
},
|
242 |
+
"model.layers.15.self_attn.v_proj": {
|
243 |
+
"bits": 16
|
244 |
+
},
|
245 |
+
"model.layers.15.self_attn.o_proj": {
|
246 |
+
"bits": 16
|
247 |
+
},
|
248 |
+
"model.layers.15.mlp.router.router": {
|
249 |
+
"bits": 16
|
250 |
+
},
|
251 |
+
"model.layers.16.self_attn.q_proj": {
|
252 |
+
"bits": 16
|
253 |
+
},
|
254 |
+
"model.layers.16.self_attn.k_proj": {
|
255 |
+
"bits": 16
|
256 |
+
},
|
257 |
+
"model.layers.16.self_attn.v_proj": {
|
258 |
+
"bits": 16
|
259 |
+
},
|
260 |
+
"model.layers.16.self_attn.o_proj": {
|
261 |
+
"bits": 16
|
262 |
+
},
|
263 |
+
"model.layers.16.mlp.router.router": {
|
264 |
+
"bits": 16
|
265 |
+
},
|
266 |
+
"model.layers.17.self_attn.q_proj": {
|
267 |
+
"bits": 16
|
268 |
+
},
|
269 |
+
"model.layers.17.self_attn.k_proj": {
|
270 |
+
"bits": 16
|
271 |
+
},
|
272 |
+
"model.layers.17.self_attn.v_proj": {
|
273 |
+
"bits": 16
|
274 |
+
},
|
275 |
+
"model.layers.17.self_attn.o_proj": {
|
276 |
+
"bits": 16
|
277 |
+
},
|
278 |
+
"model.layers.17.mlp.router.router": {
|
279 |
+
"bits": 16
|
280 |
+
},
|
281 |
+
"model.layers.18.self_attn.q_proj": {
|
282 |
+
"bits": 16
|
283 |
+
},
|
284 |
+
"model.layers.18.self_attn.k_proj": {
|
285 |
+
"bits": 16
|
286 |
+
},
|
287 |
+
"model.layers.18.self_attn.v_proj": {
|
288 |
+
"bits": 16
|
289 |
+
},
|
290 |
+
"model.layers.18.self_attn.o_proj": {
|
291 |
+
"bits": 16
|
292 |
+
},
|
293 |
+
"model.layers.18.mlp.router.router": {
|
294 |
+
"bits": 16
|
295 |
+
},
|
296 |
+
"model.layers.19.self_attn.q_proj": {
|
297 |
+
"bits": 16
|
298 |
+
},
|
299 |
+
"model.layers.19.self_attn.k_proj": {
|
300 |
+
"bits": 16
|
301 |
+
},
|
302 |
+
"model.layers.19.self_attn.v_proj": {
|
303 |
+
"bits": 16
|
304 |
+
},
|
305 |
+
"model.layers.19.self_attn.o_proj": {
|
306 |
+
"bits": 16
|
307 |
+
},
|
308 |
+
"model.layers.19.mlp.router.router": {
|
309 |
+
"bits": 16
|
310 |
+
},
|
311 |
+
"model.layers.20.self_attn.q_proj": {
|
312 |
+
"bits": 16
|
313 |
+
},
|
314 |
+
"model.layers.20.self_attn.k_proj": {
|
315 |
+
"bits": 16
|
316 |
+
},
|
317 |
+
"model.layers.20.self_attn.v_proj": {
|
318 |
+
"bits": 16
|
319 |
+
},
|
320 |
+
"model.layers.20.self_attn.o_proj": {
|
321 |
+
"bits": 16
|
322 |
+
},
|
323 |
+
"model.layers.20.mlp.router.router": {
|
324 |
+
"bits": 16
|
325 |
+
},
|
326 |
+
"model.layers.21.self_attn.q_proj": {
|
327 |
+
"bits": 16
|
328 |
+
},
|
329 |
+
"model.layers.21.self_attn.k_proj": {
|
330 |
+
"bits": 16
|
331 |
+
},
|
332 |
+
"model.layers.21.self_attn.v_proj": {
|
333 |
+
"bits": 16
|
334 |
+
},
|
335 |
+
"model.layers.21.self_attn.o_proj": {
|
336 |
+
"bits": 16
|
337 |
+
},
|
338 |
+
"model.layers.21.mlp.router.router": {
|
339 |
+
"bits": 16
|
340 |
+
},
|
341 |
+
"model.layers.22.self_attn.q_proj": {
|
342 |
+
"bits": 16
|
343 |
+
},
|
344 |
+
"model.layers.22.self_attn.k_proj": {
|
345 |
+
"bits": 16
|
346 |
+
},
|
347 |
+
"model.layers.22.self_attn.v_proj": {
|
348 |
+
"bits": 16
|
349 |
+
},
|
350 |
+
"model.layers.22.self_attn.o_proj": {
|
351 |
+
"bits": 16
|
352 |
+
},
|
353 |
+
"model.layers.22.mlp.router.router": {
|
354 |
+
"bits": 16
|
355 |
+
},
|
356 |
+
"model.layers.23.self_attn.q_proj": {
|
357 |
+
"bits": 16
|
358 |
+
},
|
359 |
+
"model.layers.23.self_attn.k_proj": {
|
360 |
+
"bits": 16
|
361 |
+
},
|
362 |
+
"model.layers.23.self_attn.v_proj": {
|
363 |
+
"bits": 16
|
364 |
+
},
|
365 |
+
"model.layers.23.self_attn.o_proj": {
|
366 |
+
"bits": 16
|
367 |
+
},
|
368 |
+
"model.layers.23.mlp.router.router": {
|
369 |
+
"bits": 16
|
370 |
+
}
|
371 |
+
}
|
372 |
+
}
|
special_tokens_map.json
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"bos_token": {
|
3 |
+
"content": "<|startoftext|>",
|
4 |
+
"lstrip": false,
|
5 |
+
"normalized": false,
|
6 |
+
"rstrip": false,
|
7 |
+
"single_word": false
|
8 |
+
},
|
9 |
+
"eos_token": {
|
10 |
+
"content": "<|return|>",
|
11 |
+
"lstrip": false,
|
12 |
+
"normalized": false,
|
13 |
+
"rstrip": false,
|
14 |
+
"single_word": false
|
15 |
+
},
|
16 |
+
"pad_token": {
|
17 |
+
"content": "<|endoftext|>",
|
18 |
+
"lstrip": false,
|
19 |
+
"normalized": false,
|
20 |
+
"rstrip": false,
|
21 |
+
"single_word": false
|
22 |
+
}
|
23 |
+
}
|
tokenizer.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e0ca2e99eca05c8a688ec60100806dda193defc5839c985b321d9e8492efcb84
|
3 |
+
size 27868273
|
tokenizer_config.json
ADDED
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"added_tokens_decoder": {
|
3 |
+
"199998": {
|
4 |
+
"content": "<|startoftext|>",
|
5 |
+
"lstrip": false,
|
6 |
+
"normalized": false,
|
7 |
+
"rstrip": false,
|
8 |
+
"single_word": false,
|
9 |
+
"special": true
|
10 |
+
},
|
11 |
+
"199999": {
|
12 |
+
"content": "<|endoftext|>",
|
13 |
+
"lstrip": false,
|
14 |
+
"normalized": false,
|
15 |
+
"rstrip": false,
|
16 |
+
"single_word": false,
|
17 |
+
"special": true
|
18 |
+
},
|
19 |
+
"200000": {
|
20 |
+
"content": "<|reserved_200000|>",
|
21 |
+
"lstrip": false,
|
22 |
+
"normalized": false,
|
23 |
+
"rstrip": false,
|
24 |
+
"single_word": false,
|
25 |
+
"special": true
|
26 |
+
},
|
27 |
+
"200001": {
|
28 |
+
"content": "<|reserved_200001|>",
|
29 |
+
"lstrip": false,
|
30 |
+
"normalized": false,
|
31 |
+
"rstrip": false,
|
32 |
+
"single_word": false,
|
33 |
+
"special": true
|
34 |
+
},
|
35 |
+
"200002": {
|
36 |
+
"content": "<|return|>",
|
37 |
+
"lstrip": false,
|
38 |
+
"normalized": false,
|
39 |
+
"rstrip": false,
|
40 |
+
"single_word": false,
|
41 |
+
"special": true
|
42 |
+
},
|
43 |
+
"200003": {
|
44 |
+
"content": "<|constrain|>",
|
45 |
+
"lstrip": false,
|
46 |
+
"normalized": false,
|
47 |
+
"rstrip": false,
|
48 |
+
"single_word": false,
|
49 |
+
"special": true
|
50 |
+
},
|
51 |
+
"200004": {
|
52 |
+
"content": "<|reserved_200004|>",
|
53 |
+
"lstrip": false,
|
54 |
+
"normalized": false,
|
55 |
+
"rstrip": false,
|
56 |
+
"single_word": false,
|
57 |
+
"special": true
|
58 |
+
},
|
59 |
+
"200005": {
|
60 |
+
"content": "<|channel|>",
|
61 |
+
"lstrip": false,
|
62 |
+
"normalized": false,
|
63 |
+
"rstrip": false,
|
64 |
+
"single_word": false,
|
65 |
+
"special": true
|
66 |
+
},
|
67 |
+
"200006": {
|
68 |
+
"content": "<|start|>",
|
69 |
+
"lstrip": false,
|
70 |
+
"normalized": false,
|
71 |
+
"rstrip": false,
|
72 |
+
"single_word": false,
|
73 |
+
"special": true
|
74 |
+
},
|
75 |
+
"200007": {
|
76 |
+
"content": "<|end|>",
|
77 |
+
"lstrip": false,
|
78 |
+
"normalized": false,
|
79 |
+
"rstrip": false,
|
80 |
+
"single_word": false,
|
81 |
+
"special": true
|
82 |
+
},
|
83 |
+
"200008": {
|
84 |
+
"content": "<|message|>",
|
85 |
+
"lstrip": false,
|
86 |
+
"normalized": false,
|
87 |
+
"rstrip": false,
|
88 |
+
"single_word": false,
|
89 |
+
"special": true
|
90 |
+
},
|
91 |
+
"200009": {
|
92 |
+
"content": "<|reserved_200009|>",
|
93 |
+
"lstrip": false,
|
94 |
+
"normalized": false,
|
95 |
+
"rstrip": false,
|
96 |
+
"single_word": false,
|
97 |
+
"special": true
|
98 |
+
},
|
99 |
+
"200010": {
|
100 |
+
"content": "<|reserved_200010|>",
|
101 |
+
"lstrip": false,
|
102 |
+
"normalized": false,
|
103 |
+
"rstrip": false,
|
104 |
+
"single_word": false,
|
105 |
+
"special": true
|
106 |
+
},
|
107 |
+
"200011": {
|
108 |
+
"content": "<|reserved_200011|>",
|
109 |
+
"lstrip": false,
|
110 |
+
"normalized": false,
|
111 |
+
"rstrip": false,
|
112 |
+
"single_word": false,
|
113 |
+
"special": true
|
114 |
+
},
|
115 |
+
"200012": {
|
116 |
+
"content": "<|call|>",
|
117 |
+
"lstrip": false,
|
118 |
+
"normalized": false,
|
119 |
+
"rstrip": false,
|
120 |
+
"single_word": false,
|
121 |
+
"special": true
|
122 |
+
},
|
123 |
+
"200013": {
|
124 |
+
"content": "<|reserved_200013|>",
|
125 |
+
"lstrip": false,
|
126 |
+
"normalized": false,
|
127 |
+
"rstrip": false,
|
128 |
+
"single_word": false,
|
129 |
+
"special": true
|
130 |
+
},
|
131 |
+
"200014": {
|
132 |
+
"content": "<|reserved_200014|>",
|
133 |
+
"lstrip": false,
|
134 |
+
"normalized": false,
|
135 |
+
"rstrip": false,
|
136 |
+
"single_word": false,
|
137 |
+
"special": true
|
138 |
+
},
|
139 |
+
"200015": {
|
140 |
+
"content": "<|reserved_200015|>",
|
141 |
+
"lstrip": false,
|
142 |
+
"normalized": false,
|
143 |
+
"rstrip": false,
|
144 |
+
"single_word": false,
|
145 |
+
"special": true
|
146 |
+
},
|
147 |
+
"200016": {
|
148 |
+
"content": "<|reserved_200016|>",
|
149 |
+
"lstrip": false,
|
150 |
+
"normalized": false,
|
151 |
+
"rstrip": false,
|
152 |
+
"single_word": false,
|
153 |
+
"special": true
|
154 |
+
},
|
155 |
+
"200017": {
|
156 |
+
"content": "<|reserved_200017|>",
|
157 |
+
"lstrip": false,
|
158 |
+
"normalized": false,
|
159 |
+
"rstrip": false,
|
160 |
+
"single_word": false,
|
161 |
+
"special": true
|
162 |
+
},
|
163 |
+
"200018": {
|
164 |
+
"content": "<|endofprompt|>",
|
165 |
+
"lstrip": false,
|
166 |
+
"normalized": false,
|
167 |
+
"rstrip": false,
|
168 |
+
"single_word": false,
|
169 |
+
"special": true
|
170 |
+
}
|
171 |
+
},
|
172 |
+
"bos_token": "<|startoftext|>",
|
173 |
+
"clean_up_tokenization_spaces": false,
|
174 |
+
"eos_token": "<|return|>",
|
175 |
+
"extra_special_tokens": {},
|
176 |
+
"model_input_names": [
|
177 |
+
"input_ids",
|
178 |
+
"attention_mask"
|
179 |
+
],
|
180 |
+
"model_max_length": 1000000000000000019884624838656,
|
181 |
+
"pad_token": "<|endoftext|>",
|
182 |
+
"tokenizer_class": "PreTrainedTokenizerFast"
|
183 |
+
}
|