Wrap assistant messages inside "{% generation %}" markers in chat_template.jinja
Browse filesAdded "{% generation %}" markers enables the TRL [SFTTrainer](https://huggingface.co/docs/trl/main/en/sft_trainer#trl.SFTConfig)'s `assistant_only_loss` config option. `assistant_only_loss` tells the SFTTrainer to only enable gradients on the assistant messages, which are wrapped around `{% generation %}` by this PR. I confirmed that this behaves as expected by using this custom template for the `gpt-oss-20b` tokenizer as the `processing_class` for SFTTrainer.
See this transformers [PR](https://github.com/huggingface/transformers/pull/30650) that introduced this change
See also how [trl/trainer/sft_trainer.py](https://github.com/huggingface/trl/blob/206964ce16e15f2afd4f8f12fe49d1d828312f97/trl/trainer/sft_trainer.py#L845) uses this marker in [transformers/utils/chat_template_utils.py](https://github.com/huggingface/transformers/blob/52c6c1bb6e27ca87c4faede34a4c2a7404c17c4d/src/transformers/utils/chat_template_utils.py#L475).
Code segment to verify the masking is done correctly, where assistant tokens are printed in green:
```
tokenizer = AutoTokenizer.from_pretrained('openai/gpt-oss-20b', trust_remote_code=True)
tokenizer.chat_template = CORRECTED_JINJA_TEMPLATE
templated_output = tokenizer.apply_chat_template(
sample['messages'],
tokenize=True,
add_generation_prompt=False,
return_assistant_tokens_mask=True,
return_dict=True,
)
print("Visualizing token masks. Green text is used for loss calculation.\n")
GREEN = "\033[92m"
RESET = "\033[0m"
input_ids = templated_output['input_ids']
assistant_mask = templated_output['assistant_masks']
if len(input_ids) != len(assistant_mask):
raise ValueError("Mismatch between input_ids and assistant_masks length.")
current_chunk_tokens = []
current_mask_status = None
for token_id, is_assistant in zip(input_ids, assistant_mask):
mask_status = bool(is_assistant)
if current_mask_status is None:
current_mask_status = mask_status
if mask_status != current_mask_status:
# Decode and print the completed chunk
decoded_text = tokenizer.decode(current_chunk_tokens, skip_special_tokens=False)
if current_mask_status:
print(f"{GREEN}{decoded_text}{RESET}", end="")
else:
print(decoded_text, end="")
# Start a new chunk
current_chunk_tokens = [token_id]
current_mask_status = mask_status
else:
current_chunk_tokens.append(token_id)
# Print the final chunk after the loop
if current_chunk_tokens:
decoded_text = tokenizer.decode(current_chunk_tokens, skip_special_tokens=False)
if current_mask_status:
print(f"{GREEN}{decoded_text}{RESET}", end="")
else:
print(decoded_text, end="")
```
Prints something like:
```
<|start|>user<|message|>USER_MESSAGE<|end|>[GREEN_STARTS]<|start|>assistant<|channel|>analysis<|message|>...<|call|>[GREEN_ENDS]
```
- chat_template.jinja +24 -17
@@ -288,30 +288,37 @@
|
|
288 |
{%- endif %}
|
289 |
{%- if message.content and message.thinking %}
|
290 |
{{- raise_exception("Cannot pass both content and thinking in an assistant message with tool calls! Put the analysis message in one or the other, but not both.") }}
|
291 |
-
{%- elif message.content and not future_final_message.found %}
|
292 |
-
{{- "<|start|>assistant<|channel|>analysis<|message|>" + message.content + "<|end|>" }}
|
293 |
-
{%- elif message.thinking and not future_final_message.found %}
|
294 |
-
{{- "<|start|>assistant<|channel|>analysis<|message|>" + message.thinking + "<|end|>" }}
|
295 |
{%- endif %}
|
296 |
-
{
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
301 |
{%- set last_tool_call.name = tool_call.name %}
|
302 |
{%- elif loop.last and not add_generation_prompt %}
|
303 |
{#- Only render the CoT if the final turn is an assistant turn and add_generation_prompt is false #}
|
304 |
{#- This is a situation that should only occur in training, never in inference. #}
|
305 |
-
{
|
306 |
-
{
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
|
|
|
|
312 |
{%- else %}
|
313 |
{#- CoT is dropped during all previous turns, so we never render it for inference #}
|
314 |
-
{
|
|
|
|
|
315 |
{%- set last_tool_call.name = none %}
|
316 |
{%- endif %}
|
317 |
{%- elif message.role == 'tool' -%}
|
|
|
288 |
{%- endif %}
|
289 |
{%- if message.content and message.thinking %}
|
290 |
{{- raise_exception("Cannot pass both content and thinking in an assistant message with tool calls! Put the analysis message in one or the other, but not both.") }}
|
|
|
|
|
|
|
|
|
291 |
{%- endif %}
|
292 |
+
{% generation %}
|
293 |
+
{%- if message.content and not future_final_message.found %}
|
294 |
+
{{- "<|start|>assistant<|channel|>analysis<|message|>" + message.content + "<|end|>" }}
|
295 |
+
{%- elif message.thinking and not future_final_message.found %}
|
296 |
+
{{- "<|start|>assistant<|channel|>analysis<|message|>" + message.thinking + "<|end|>" }}
|
297 |
+
{%- endif %}
|
298 |
+
{{- "<|start|>assistant to=" }}
|
299 |
+
{{- "functions." + tool_call.name + "<|channel|>commentary " }}
|
300 |
+
{{- (tool_call.content_type if tool_call.content_type is defined else "json") + "<|message|>" }}
|
301 |
+
{{- tool_call.arguments|tojson }}
|
302 |
+
{{- "<|call|>" }}
|
303 |
+
{% endgeneration %}
|
304 |
{%- set last_tool_call.name = tool_call.name %}
|
305 |
{%- elif loop.last and not add_generation_prompt %}
|
306 |
{#- Only render the CoT if the final turn is an assistant turn and add_generation_prompt is false #}
|
307 |
{#- This is a situation that should only occur in training, never in inference. #}
|
308 |
+
{% generation %}
|
309 |
+
{%- if "thinking" in message %}
|
310 |
+
{{- "<|start|>assistant<|channel|>analysis<|message|>" + message.thinking + "<|end|>" }}
|
311 |
+
{%- endif %}
|
312 |
+
{#- <|return|> indicates the end of generation, but <|end|> does not #}
|
313 |
+
{#- <|return|> should never be an input to the model, but we include it as the final token #}
|
314 |
+
{#- when training, so the model learns to emit it. #}
|
315 |
+
{{- "<|start|>assistant<|channel|>final<|message|>" + message.content + "<|return|>" }}
|
316 |
+
{% endgeneration %}
|
317 |
{%- else %}
|
318 |
{#- CoT is dropped during all previous turns, so we never render it for inference #}
|
319 |
+
{% generation %}
|
320 |
+
{{- "<|start|>assistant<|channel|>final<|message|>" + message.content + "<|end|>" }}
|
321 |
+
{% endgeneration %}
|
322 |
{%- set last_tool_call.name = none %}
|
323 |
{%- endif %}
|
324 |
{%- elif message.role == 'tool' -%}
|