H committed on
Commit 3365e9f · verified · 1 Parent(s): bc75b44

Fix missing `{% generation %}` keyword when calling `tokenizer.apply_chat_template(..., return_assistant_tokens_mask=True)`

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("/opt/tiger/gpt-oss-20b")

messages = [
    {
        "role": "user",
        "content": "hi",
    },
    {
        "role": "assistant",
        "thinking": "think a moment",
        "content": "Hello",
    },
]

# Print the rendered text that follows the first <|end|> marker.
print(tokenizer.apply_chat_template(messages, tokenize=False).split("<|end|>", 1)[1])

processed = tokenizer.apply_chat_template(
    messages,
    reasoning_effort="high",
    return_assistant_tokens_mask=True,
    return_dict=True,
)

# 200007 is the <|end|> token id; skip everything up to and including its first occurrence.
first_end = processed["input_ids"].index(200007) + 1
print(processed["input_ids"][first_end:])
print(processed["attention_mask"][first_end:])
print(processed["assistant_masks"][first_end:])
```
Original Output:
```plain
return_assistant_tokens_mask==True but chat template does not contain `{% generation %}` keyword.
<|start|>user<|message|>hi<|end|><|start|>assistant<|channel|>analysis<|message|>think a moment<|end|><|start|>assistant<|channel|>final<|message|>Hello<|return|>
[200006, 1428, 200008, 3686, 200007, 200006, 173781, 200005, 35644, 200008, 49631, 261, 4205, 200007, 200006, 173781, 200005, 17196, 200008, 13225, 200002]
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
```
Output after fixing:
```plain
<|start|>user<|message|>hi<|end|><|start|>assistant<|channel|>analysis<|message|>think a moment<|end|><|start|>assistant<|channel|>final<|message|>Hello<|return|>
[200006, 1428, 200008, 3686, 200007, 200006, 173781, 200005, 35644, 200008, 49631, 261, 4205, 200007, 200006, 173781, 200005, 17196, 200008, 13225, 200002]
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
```
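
With a non-zero `assistant_masks`, the mask can be used to restrict a training loss to assistant tokens. The following is a minimal sketch, not part of this commit: `build_labels` and `IGNORE_INDEX` are hypothetical names, and `-100` is assumed because it is the default `ignore_index` of PyTorch's `CrossEntropyLoss`. The token ids and mask are copied from the output above.

```python
# Token ids and assistant mask copied from the "Output after fixing" block.
input_ids = [
    200006, 1428, 200008, 3686, 200007,
    200006, 173781, 200005, 35644, 200008, 49631, 261, 4205, 200007,
    200006, 173781, 200005, 17196, 200008, 13225, 200002,
]
assistant_masks = [0] * 5 + [1] * 16  # 5 user tokens, then 16 assistant tokens

# Assumption: the loss skips positions labelled -100
# (the default ignore_index of torch.nn.CrossEntropyLoss).
IGNORE_INDEX = -100


def build_labels(input_ids, assistant_masks, ignore_index=IGNORE_INDEX):
    """Keep assistant token ids as labels; replace all other positions with ignore_index."""
    if len(input_ids) != len(assistant_masks):
        raise ValueError("input_ids and assistant_masks must have the same length")
    return [
        token if mask == 1 else ignore_index
        for token, mask in zip(input_ids, assistant_masks)
    ]


labels = build_labels(input_ids, assistant_masks)
print(labels)
```

Before the fix the mask was all zeros, so the same helper would have masked out every position and left nothing to train on.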

Files changed (1)
  1. chat_template.jinja +2 -0
@@ -259,6 +259,7 @@
  {%- for message in loop_messages -%}
  {#- At this point only assistant/user/tool messages should remain #}
  {%- if message.role == 'assistant' -%}
+ {% generation %}
  {#- Checks to ensure the messages are being passed in the format we expect #}
  {%- if "content" in message %}
  {%- if "<|channel|>analysis<|message|>" in message.content or "<|channel|>final<|message|>" in message.content %}
@@ -314,6 +315,7 @@
  {{- "<|start|>assistant<|channel|>final<|message|>" + message.content + "<|end|>" }}
  {%- set last_tool_call.name = none %}
  {%- endif %}
+ {% endgeneration %}
  {%- elif message.role == 'tool' -%}
  {%- if last_tool_call.name is none %}
  {{- raise_exception("Message has tool role, but there was no previous assistant message with a tool call!") }}
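
A quick way to check whether a template string declares a generation block before requesting the assistant mask is a regex search. This is only a sketch: `has_generation_block` is a hypothetical helper, not a `transformers` API, and the two template strings are simplified stand-ins for the real `chat_template.jinja`.

```python
import re

# Hypothetical pattern: `{% generation %}` with optional whitespace-control dashes.
# It deliberately does not match `{% endgeneration %}`.
GENERATION_TAG = re.compile(r"\{%-?\s*generation\s*-?%\}")


def has_generation_block(template: str) -> bool:
    """Return True if the template contains an opening generation tag."""
    return GENERATION_TAG.search(template) is not None


# Simplified stand-ins for the assistant branch before and after this change.
before = (
    "{%- if message.role == 'assistant' -%}"
    "{{- message.content }}"
    "{%- endif %}"
)
after = (
    "{%- if message.role == 'assistant' -%}"
    "{% generation %}"
    "{{- message.content }}"
    "{% endgeneration %}"
    "{%- endif %}"
)

print(has_generation_block(before))
print(has_generation_block(after))
```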