Wrap assistant messages inside "{% generation %}" markers in chat_template.jinja
#126, opened by daohanlu
Adding {% generation %} markers enables the TRL SFTTrainer's assistant_only_loss config option. assistant_only_loss tells the SFTTrainer to enable gradients only on the assistant messages, which this PR wraps in {% generation %} ... {% endgeneration %} markers. I confirmed that this behaves as expected by using this custom template for the gpt-oss-20b tokenizer passed as the processing_class for SFTTrainer.
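For illustration only (this is a simplified sketch with gpt-oss-style special tokens, not the actual template, which also handles channels, tools, and system messages), the assistant turn is wrapped like this:

```jinja
{%- for message in messages -%}
    {%- if message['role'] == 'assistant' -%}
        {# Tokens rendered between generation/endgeneration are flagged in the
           assistant_masks returned by apply_chat_template. #}
        {% generation %}<|start|>assistant<|message|>{{ message['content'] }}<|end|>{% endgeneration %}
    {%- else -%}
        <|start|>{{ message['role'] }}<|message|>{{ message['content'] }}<|end|>
    {%- endif -%}
{%- endfor -%}
```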
See this transformers PR that introduced this change.
See also how trl/trainer/sft_trainer.py uses the assistant token mask produced by transformers/utils/chat_template_utils.py.
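Downstream, assistant_only_loss amounts to ignoring non-assistant positions when the labels are built: wherever the mask is 0, the label is set to the standard ignore index -100 so it contributes no loss. A minimal sketch of that masking step (the helper name is illustrative, not TRL's actual API):

```python
# Illustrative helper (not TRL's actual API): build a label vector that
# produces loss only on assistant tokens, using the standard -100 ignore index.
IGNORE_INDEX = -100

def assistant_only_labels(input_ids, assistant_mask):
    """Copy input_ids as labels, masking non-assistant positions with -100."""
    return [tok if m else IGNORE_INDEX for tok, m in zip(input_ids, assistant_mask)]

ids = [200006, 1428, 200008, 13225, 200007]   # made-up token ids
mask = [0, 0, 1, 1, 1]                        # 1 = assistant token
print(assistant_only_labels(ids, mask))
# -> [-100, -100, 200008, 13225, 200007]
```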
Code segment to verify that the masking is done correctly; assistant tokens are printed in green:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('openai/gpt-oss-20b', trust_remote_code=True)
# CORRECTED_JINJA_TEMPLATE holds the patched chat_template.jinja from this PR.
tokenizer.chat_template = CORRECTED_JINJA_TEMPLATE

# sample['messages'] is a chat-format list of {'role': ..., 'content': ...} dicts.
templated_output = tokenizer.apply_chat_template(
    sample['messages'],
    tokenize=True,
    add_generation_prompt=False,
    return_assistant_tokens_mask=True,
    return_dict=True,
)

print("Visualizing token masks. Green text is used for loss calculation.\n")
GREEN = "\033[92m"
RESET = "\033[0m"

input_ids = templated_output['input_ids']
assistant_mask = templated_output['assistant_masks']
if len(input_ids) != len(assistant_mask):
    raise ValueError("Mismatch between input_ids and assistant_masks length.")

current_chunk_tokens = []
current_mask_status = None
for token_id, is_assistant in zip(input_ids, assistant_mask):
    mask_status = bool(is_assistant)
    if current_mask_status is None:
        current_mask_status = mask_status
    if mask_status != current_mask_status:
        # Decode and print the completed chunk
        decoded_text = tokenizer.decode(current_chunk_tokens, skip_special_tokens=False)
        if current_mask_status:
            print(f"{GREEN}{decoded_text}{RESET}", end="")
        else:
            print(decoded_text, end="")
        # Start a new chunk
        current_chunk_tokens = [token_id]
        current_mask_status = mask_status
    else:
        current_chunk_tokens.append(token_id)

# Print the final chunk after the loop
if current_chunk_tokens:
    decoded_text = tokenizer.decode(current_chunk_tokens, skip_special_tokens=False)
    if current_mask_status:
        print(f"{GREEN}{decoded_text}{RESET}", end="")
    else:
        print(decoded_text, end="")
Prints something like:
<|start|>user<|message|>USER_MESSAGE<|end|>[GREEN_STARTS]<|start|>assistant<|channel|>analysis<|message|>...<|call|>[GREEN_ENDS]
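The run-splitting logic in the visualization loop can also be checked in isolation on a synthetic mask; itertools.groupby gives an equivalent, compact formulation (the token ids below are made up):

```python
# Equivalent run-splitting using itertools.groupby, checked on a synthetic mask.
from itertools import groupby

def split_runs(input_ids, assistant_mask):
    """Group consecutive tokens that share the same mask value,
    mirroring the chunking loop in the snippet above."""
    return [
        (is_assistant, [tok for tok, _ in group])
        for is_assistant, group in groupby(zip(input_ids, assistant_mask),
                                           key=lambda pair: bool(pair[1]))
    ]

ids = [10, 11, 12, 20, 21, 22, 23, 30]   # made-up token ids
mask = [0, 0, 0, 1, 1, 1, 1, 0]          # 1 = assistant token
print(split_runs(ids, mask))
# -> [(False, [10, 11, 12]), (True, [20, 21, 22, 23]), (False, [30])]
```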