Wrap assistant messages inside "{% generation %}" markers in chat_template.jinja

#126

Adding "{% generation %}" markers enables the TRL SFTTrainer's assistant_only_loss config option. With assistant_only_loss enabled, SFTTrainer computes the loss only on the assistant messages, which this PR wraps in {% generation %} ... {% endgeneration %}. I confirmed that this behaves as expected by applying this custom template to the gpt-oss-20b tokenizer and passing that tokenizer to SFTTrainer as its processing_class.
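Conceptually, assistant-only loss turns the assistant mask into training labels: every position outside an assistant span gets the ignore index -100 (PyTorch's default `ignore_index` for cross-entropy), so it contributes nothing to the loss or its gradient. The stdlib-only sketch below illustrates that idea under simplifying assumptions (no label shifting, and a hypothetical per-token probability list standing in for model outputs). `build_labels` and `masked_mean_nll` are hypothetical helpers, not TRL's implementation.

```python
import math

IGNORE_INDEX = -100  # PyTorch's default ignore_index for cross-entropy


def build_labels(input_ids, assistant_mask):
    """Copy input_ids into labels, replacing non-assistant positions with IGNORE_INDEX."""
    if len(input_ids) != len(assistant_mask):
        raise ValueError("input_ids and assistant_mask must have the same length")
    return [tok if keep else IGNORE_INDEX for tok, keep in zip(input_ids, assistant_mask)]


def masked_mean_nll(token_probs, labels):
    """Mean negative log-likelihood over positions whose label is not IGNORE_INDEX.

    token_probs[i] is a hypothetical probability the model assigned to labels[i];
    the usual causal-LM shift of labels by one position is omitted for clarity.
    """
    losses = [-math.log(p) for p, lab in zip(token_probs, labels) if lab != IGNORE_INDEX]
    if not losses:
        raise ValueError("no assistant tokens: every label is masked")
    return sum(losses) / len(losses)


labels = build_labels([10, 11, 12, 13, 14], [0, 0, 1, 1, 0])
print(labels)  # → [-100, -100, 12, 13, -100]
print(masked_mean_nll([0.1, 0.2, 0.5, 0.25, 0.9], labels))
```

Only positions 2 and 3 carry a real label here, so the loss is the mean of -ln 0.5 and -ln 0.25; the probabilities assigned to the user-side tokens are ignored entirely.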

See this transformers PR that introduced the {% generation %} marker.
See also how trl/trainer/sft_trainer.py relies on the assistant mask that transformers/utils/chat_template_utils.py builds from this marker.
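Roughly speaking, the {% generation %} support records the character ranges that a template renders inside each generation block, and `return_assistant_tokens_mask` then marks the tokens covering those characters. The sketch below imitates that pipeline with a toy renderer and a regex tokenizer that reports character offsets. The message format and all function names (`render_with_spans`, `toy_tokenize`, `spans_to_token_mask`) are hypothetical simplifications, not the code in chat_template_utils.py or the real gpt-oss template.

```python
import re

# Toy tokenizer: special tokens like <|start|>, words, or single punctuation marks
TOKEN_RE = re.compile(r"<\|[a-z]+\|>|\w+|[^\w\s]")


def render_with_spans(messages):
    """Render messages into one string and record (start, end) char spans of assistant turns."""
    parts, spans, pos = [], [], 0
    for msg in messages:
        text = f"<|start|>{msg['role']}<|message|>{msg['content']}<|end|>"
        if msg["role"] == "assistant":
            # Plays the role of a {% generation %} ... {% endgeneration %} block
            spans.append((pos, pos + len(text)))
        parts.append(text)
        pos += len(text)
    return "".join(parts), spans


def toy_tokenize(text):
    """Return (token_string, start_offset) pairs, a stand-in for a fast tokenizer's offsets."""
    return [(m.group(), m.start()) for m in TOKEN_RE.finditer(text)]


def spans_to_token_mask(tokens, spans):
    """Mark a token 1 if it starts inside any assistant span, else 0."""
    return [int(any(s <= start < e for s, e in spans)) for _, start in tokens]


messages = [
    {"role": "user", "content": "hi"},
    {"role": "assistant", "content": "hello there"},
]
text, spans = render_with_spans(messages)
print(spans_to_token_mask(toy_tokenize(text), spans))  # → [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
```

As in the gpt-oss output further down, the whole assistant turn (header tokens included) ends up masked in, because the span is recorded around the entire rendered assistant message.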

Code to verify that the masking is correct (assistant tokens are printed in green):

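from transformers import AutoTokenizer

# CORRECTED_JINJA_TEMPLATE holds the chat_template.jinja from this PR, and
# sample['messages'] is a list of {"role": ..., "content": ...} dicts from your dataset.
tokenizer = AutoTokenizer.from_pretrained('openai/gpt-oss-20b', trust_remote_code=True)
tokenizer.chat_template = CORRECTED_JINJA_TEMPLATE

templated_output = tokenizer.apply_chat_template(
    sample['messages'],
    tokenize=True,
    add_generation_prompt=False,
    return_assistant_tokens_mask=True,  # the mask is built from the {% generation %} blocks
    return_dict=True,
)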

print("Visualizing token masks. Green text is used for loss calculation.\n")
GREEN = "\033[92m"
RESET = "\033[0m"

input_ids = templated_output['input_ids']
assistant_mask = templated_output['assistant_masks']

if len(input_ids) != len(assistant_mask):
    raise ValueError("Mismatch between input_ids and assistant_masks length.")


def print_chunk(tokens, is_assistant):
    # Decode a run of tokens that share the same mask value; color assistant runs green
    decoded_text = tokenizer.decode(tokens, skip_special_tokens=False)
    print(f"{GREEN}{decoded_text}{RESET}" if is_assistant else decoded_text, end="")


current_chunk_tokens = []
current_mask_status = None

for token_id, is_assistant in zip(input_ids, assistant_mask):
    mask_status = bool(is_assistant)
    if current_mask_status is None:
        current_mask_status = mask_status

    if mask_status != current_mask_status:
        # The mask value changed: print the completed chunk and start a new one
        print_chunk(current_chunk_tokens, current_mask_status)
        current_chunk_tokens = [token_id]
        current_mask_status = mask_status
    else:
        current_chunk_tokens.append(token_id)

# Print the final chunk after the loop
if current_chunk_tokens:
    print_chunk(current_chunk_tokens, current_mask_status)

Prints something like:

<|start|>user<|message|>USER_MESSAGE<|end|>[GREEN_STARTS]<|start|>assistant<|channel|>analysis<|message|>...<|call|>[GREEN_ENDS]
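A mask that is all zeros usually means the template has no {% generation %} block, and a visual check is easy to skim past on long conversations. A small programmatic check can complement it. `mask_spans` and `check_mask` below are hypothetical helpers, and the rule they enforce (at least one masked run, and no more runs than assistant messages, since adjacent assistant messages may merge into a single run) is an assumption about what a correct mask looks like, not something transformers or TRL checks.

```python
def mask_spans(assistant_mask):
    """Return [start, end) token-index pairs for each contiguous run of truthy mask values."""
    spans, start = [], None
    for i, flag in enumerate(assistant_mask):
        if flag and start is None:
            start = i  # a masked run begins here
        elif not flag and start is not None:
            spans.append((start, i))  # the run ended at the previous token
            start = None
    if start is not None:
        spans.append((start, len(assistant_mask)))  # the run extends to the end
    return spans


def check_mask(assistant_mask, messages):
    """Hypothetical sanity check: raise ValueError if the mask looks wrong, else return its runs."""
    spans = mask_spans(assistant_mask)
    n_assistant = sum(1 for m in messages if m["role"] == "assistant")
    if n_assistant and not spans:
        raise ValueError("No assistant tokens are masked; is {% generation %} missing from the template?")
    if len(spans) > n_assistant:
        raise ValueError(f"{len(spans)} masked runs but only {n_assistant} assistant messages")
    return spans


print(mask_spans([0, 1, 1, 0, 1]))  # → [(1, 3), (4, 5)]
```

With the variables from the verification snippet, `check_mask(templated_output['assistant_masks'], sample['messages'])` would return the token ranges that contribute to the loss.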