daohanlu committed · Commit b92455e · verified · 1 Parent(s): d666cf3

Wrap assistant messages inside "{% generation %}" markers in chat_template.jinja


Added "{% generation %}" markers enables the TRL [SFTTrainer](https://huggingface.co/docs/trl/main/en/sft_trainer#trl.SFTConfig)'s `assistant_only_loss` config option. `assistant_only_loss` tells the SFTTrainer to only enable gradients on the assistant messages, which are wrapped around `{% generation %}` by this PR. I confirmed that this behaves as expected by using this custom template for the `gpt-oss-20b` tokenizer as the `processing_class` for SFTTrainer.

See the transformers [PR](https://github.com/huggingface/transformers/pull/30650) that introduced the `{% generation %}` keyword.
See also how [trl/trainer/sft_trainer.py](https://github.com/huggingface/trl/blob/206964ce16e15f2afd4f8f12fe49d1d828312f97/trl/trainer/sft_trainer.py#L845) consumes this marker via [transformers/utils/chat_template_utils.py](https://github.com/huggingface/transformers/blob/52c6c1bb6e27ca87c4faede34a4c2a7404c17c4d/src/transformers/utils/chat_template_utils.py#L475).
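For intuition, here is a minimal, self-contained sketch of the mechanism (the toy template and the `gpt2` tokenizer are illustrative choices, not from this PR): `apply_chat_template(..., return_assistant_tokens_mask=True, return_dict=True)` returns an `assistant_masks` entry that is 1 exactly over the tokens rendered inside `{% generation %}` blocks.

```python
from transformers import AutoTokenizer

# Toy template (illustrative only): user text renders plainly, assistant text
# is wrapped in {% generation %} so its tokens are flagged in assistant_masks.
tok = AutoTokenizer.from_pretrained("gpt2")
tok.chat_template = (
    "{%- for message in messages -%}"
    "{%- if message['role'] == 'assistant' -%}"
    "{% generation %}{{ message['content'] }}{% endgeneration %}"
    "{%- else -%}"
    "{{ message['content'] }}"
    "{%- endif -%}"
    "{%- endfor -%}"
)

out = tok.apply_chat_template(
    [
        {"role": "user", "content": "Hi there. "},
        {"role": "assistant", "content": "Hello!"},
    ],
    tokenize=True,
    return_assistant_tokens_mask=True,
    return_dict=True,
)
print(out["assistant_masks"])  # 0 over "Hi there. ", 1 over "Hello!"
```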

Code segment to verify that the masking is applied correctly; tokens covered by the assistant mask are printed in green:
```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('openai/gpt-oss-20b', trust_remote_code=True)
tokenizer.chat_template = CORRECTED_JINJA_TEMPLATE  # the template from this PR

templated_output = tokenizer.apply_chat_template(
    sample['messages'],  # a conversation: list of {"role": ..., "content": ...} dicts
    tokenize=True,
    add_generation_prompt=False,
    return_assistant_tokens_mask=True,
    return_dict=True,
)

print("Visualizing token masks. Green text is used for loss calculation.\n")
GREEN = "\033[92m"
RESET = "\033[0m"

input_ids = templated_output['input_ids']
assistant_mask = templated_output['assistant_masks']

if len(input_ids) != len(assistant_mask):
    raise ValueError("Mismatch between input_ids and assistant_masks length.")

current_chunk_tokens = []
current_mask_status = None

for token_id, is_assistant in zip(input_ids, assistant_mask):
    mask_status = bool(is_assistant)
    if current_mask_status is None:
        current_mask_status = mask_status

    if mask_status != current_mask_status:
        # Decode and print the completed chunk
        decoded_text = tokenizer.decode(current_chunk_tokens, skip_special_tokens=False)
        if current_mask_status:
            print(f"{GREEN}{decoded_text}{RESET}", end="")
        else:
            print(decoded_text, end="")

        # Start a new chunk
        current_chunk_tokens = [token_id]
        current_mask_status = mask_status
    else:
        current_chunk_tokens.append(token_id)

# Print the final chunk after the loop
if current_chunk_tokens:
    decoded_text = tokenizer.decode(current_chunk_tokens, skip_special_tokens=False)
    if current_mask_status:
        print(f"{GREEN}{decoded_text}{RESET}", end="")
    else:
        print(decoded_text, end="")
```

Prints something like:
```
<|start|>user<|message|>USER_MESSAGE<|end|>[GREEN_STARTS]<|start|>assistant<|channel|>analysis<|message|>...<|call|>[GREEN_ENDS]
```

Files changed (1)
  1. chat_template.jinja +24 -17
chat_template.jinja CHANGED
```diff
@@ -288,30 +288,37 @@
 {%- endif %}
 {%- if message.content and message.thinking %}
 {{- raise_exception("Cannot pass both content and thinking in an assistant message with tool calls! Put the analysis message in one or the other, but not both.") }}
-{%- elif message.content and not future_final_message.found %}
-{{- "<|start|>assistant<|channel|>analysis<|message|>" + message.content + "<|end|>" }}
-{%- elif message.thinking and not future_final_message.found %}
-{{- "<|start|>assistant<|channel|>analysis<|message|>" + message.thinking + "<|end|>" }}
 {%- endif %}
-{{- "<|start|>assistant to=" }}
-{{- "functions." + tool_call.name + "<|channel|>commentary " }}
-{{- (tool_call.content_type if tool_call.content_type is defined else "json") + "<|message|>" }}
-{{- tool_call.arguments|tojson }}
-{{- "<|call|>" }}
+{% generation %}
+{%- if message.content and not future_final_message.found %}
+{{- "<|start|>assistant<|channel|>analysis<|message|>" + message.content + "<|end|>" }}
+{%- elif message.thinking and not future_final_message.found %}
+{{- "<|start|>assistant<|channel|>analysis<|message|>" + message.thinking + "<|end|>" }}
+{%- endif %}
+{{- "<|start|>assistant to=" }}
+{{- "functions." + tool_call.name + "<|channel|>commentary " }}
+{{- (tool_call.content_type if tool_call.content_type is defined else "json") + "<|message|>" }}
+{{- tool_call.arguments|tojson }}
+{{- "<|call|>" }}
+{% endgeneration %}
 {%- set last_tool_call.name = tool_call.name %}
 {%- elif loop.last and not add_generation_prompt %}
 {#- Only render the CoT if the final turn is an assistant turn and add_generation_prompt is false #}
 {#- This is a situation that should only occur in training, never in inference. #}
-{%- if "thinking" in message %}
-{{- "<|start|>assistant<|channel|>analysis<|message|>" + message.thinking + "<|end|>" }}
-{%- endif %}
-{#- <|return|> indicates the end of generation, but <|end|> does not #}
-{#- <|return|> should never be an input to the model, but we include it as the final token #}
-{#- when training, so the model learns to emit it. #}
-{{- "<|start|>assistant<|channel|>final<|message|>" + message.content + "<|return|>" }}
+{% generation %}
+{%- if "thinking" in message %}
+{{- "<|start|>assistant<|channel|>analysis<|message|>" + message.thinking + "<|end|>" }}
+{%- endif %}
+{#- <|return|> indicates the end of generation, but <|end|> does not #}
+{#- <|return|> should never be an input to the model, but we include it as the final token #}
+{#- when training, so the model learns to emit it. #}
+{{- "<|start|>assistant<|channel|>final<|message|>" + message.content + "<|return|>" }}
+{% endgeneration %}
 {%- else %}
 {#- CoT is dropped during all previous turns, so we never render it for inference #}
-{{- "<|start|>assistant<|channel|>final<|message|>" + message.content + "<|end|>" }}
+{% generation %}
+{{- "<|start|>assistant<|channel|>final<|message|>" + message.content + "<|end|>" }}
+{% endgeneration %}
 {%- set last_tool_call.name = none %}
 {%- endif %}
 {%- elif message.role == 'tool' -%}
```