fix missing the `{% generation %}` keyword while using tokenizer.apply_chat_template(...return_assistant_tokens_mask=True)
Browse files```python
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("/opt/tiger/gpt-oss-20b")
messages = [
{
"role": "user",
"content": "hi"
},
{
"role": "assistant",
"thinking": "think a moment",
"content": "Hello"
}
]
print(tokenizer.apply_chat_template(messages, tokenize=False).split('<|end|>', 1)[1])
processed = tokenizer.apply_chat_template(
messages,
reasoning_effort="high",
return_assistant_tokens_mask=True,
return_dict=True,)
first_end = processed["input_ids"].index(200007) + 1
print(processed['input_ids'][first_end:])
print(processed['attention_mask'][first_end:])
print(processed['assistant_masks'][first_end:])
```
Original Output:
```plain
return_assistant_tokens_mask==True but chat template does not contain `{% generation %}` keyword.
<|start|>user<|message|>hi<|end|><|start|>assistant<|channel|>analysis<|message|>think a moment<|end|><|start|>assistant<|channel|>final<|message|>Hello<|return|>
[200006, 1428, 200008, 3686, 200007, 200006, 173781, 200005, 35644, 200008, 49631, 261, 4205, 200007, 200006, 173781, 200005, 17196, 200008, 13225, 200002]
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
```
Output after fixing:
```plain
<|start|>user<|message|>hi<|end|><|start|>assistant<|channel|>analysis<|message|>think a moment<|end|><|start|>assistant<|channel|>final<|message|>Hello<|return|>
[200006, 1428, 200008, 3686, 200007, 200006, 173781, 200005, 35644, 200008, 49631, 261, 4205, 200007, 200006, 173781, 200005, 17196, 200008, 13225, 200002]
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
```
- chat_template.jinja +2 -0
@@ -259,6 +259,7 @@
|
|
259 |
{%- for message in loop_messages -%}
|
260 |
{#- At this point only assistant/user/tool messages should remain #}
|
261 |
{%- if message.role == 'assistant' -%}
|
|
|
262 |
{#- Checks to ensure the messages are being passed in the format we expect #}
|
263 |
{%- if "content" in message %}
|
264 |
{%- if "<|channel|>analysis<|message|>" in message.content or "<|channel|>final<|message|>" in message.content %}
|
@@ -314,6 +315,7 @@
|
|
314 |
{{- "<|start|>assistant<|channel|>final<|message|>" + message.content + "<|end|>" }}
|
315 |
{%- set last_tool_call.name = none %}
|
316 |
{%- endif %}
|
|
|
317 |
{%- elif message.role == 'tool' -%}
|
318 |
{%- if last_tool_call.name is none %}
|
319 |
{{- raise_exception("Message has tool role, but there was no previous assistant message with a tool call!") }}
|
|
|
259 |
{%- for message in loop_messages -%}
|
260 |
{#- At this point only assistant/user/tool messages should remain #}
|
261 |
{%- if message.role == 'assistant' -%}
|
262 |
+
{% generation %}
|
263 |
{#- Checks to ensure the messages are being passed in the format we expect #}
|
264 |
{%- if "content" in message %}
|
265 |
{%- if "<|channel|>analysis<|message|>" in message.content or "<|channel|>final<|message|>" in message.content %}
|
|
|
315 |
{{- "<|start|>assistant<|channel|>final<|message|>" + message.content + "<|end|>" }}
|
316 |
{%- set last_tool_call.name = none %}
|
317 |
{%- endif %}
|
318 |
+
{% endgeneration %}
|
319 |
{%- elif message.role == 'tool' -%}
|
320 |
{%- if last_tool_call.name is none %}
|
321 |
{{- raise_exception("Message has tool role, but there was no previous assistant message with a tool call!") }}
|