TRL documentation
Reward Functions
This module contains some useful reward functions, primarily intended for use with the GRPOTrainer and RLOOTrainer.
accuracy_reward
trl.rewards.accuracy_reward
( completions: list, solution: list, **kwargs )
Parameters
- completions (list[list[dict[str, str]]]): List of completions to be evaluated. Each completion must be a list of one message, i.e. a dictionary containing the key "content" with the value being the text of the completion.
- solution (list[str]): List of the raw-text solutions to the questions/problems/prompts.
- **kwargs: Additional keyword arguments. This function does not use them, but they are required in the function signature to ensure compatibility with trainers like GRPOTrainer.
Reward function that checks if the completion is the same as the ground truth.
- If both gold and prediction are parseable → use math verification.
- If not parseable → compare as normalized text.
Example:
>>> from trl.rewards import accuracy_reward
>>> solution = [r"\frac{1}{3}", r"\frac{1}{3}"]
>>> completion = [
... [{"role": "assistant", "content": r"My answer is \boxed{\frac{1}{3}}"}],
... [{"role": "assistant", "content": r"My answer is \boxed{\frac{1}{2}}"}],
... ]
>>> accuracy_reward(completion, solution)
[1.0, 0.0]
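The calling convention described above (one single-message list per completion, a parallel list of solutions, and unused `**kwargs`) can be illustrated with a minimal sketch. The function `exact_match_reward` and its normalization (trimming whitespace and ignoring case) are hypothetical and are not part of `trl.rewards`; the sketch only shows the input and output shapes and does no math parsing.

```python
def exact_match_reward(completions, solution, **kwargs):
    """Hypothetical reward: 1.0 when the normalized completion text equals the solution."""
    rewards = []
    for completion, gold in zip(completions, solution):
        # Each completion is a list holding a single message dict.
        text = completion[0]["content"]
        # Assumed normalization: trim whitespace and ignore case.
        rewards.append(1.0 if text.strip().lower() == gold.strip().lower() else 0.0)
    return rewards


completions = [
    [{"role": "assistant", "content": " Paris "}],
    [{"role": "assistant", "content": "Lyon"}],
]
# Extra keyword arguments (here a made-up `prompts` list) are accepted and ignored.
print(exact_match_reward(completions, ["paris", "Paris"], prompts=["q1", "q2"]))
```

Accepting `**kwargs` lets a trainer pass additional columns without the function having to know about them.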
think_format_reward
trl.rewards.think_format_reward
( completions: list, **kwargs ) → list[float]
Parameters
- completions (list[list[dict[str, str]]]): List of completions to be evaluated. Each completion must be a list of one message, i.e. a dictionary containing the key "content" with the value being the text of the completion.
- **kwargs: Additional keyword arguments. This function does not use them, but they are required in the function signature to ensure compatibility with trainers like GRPOTrainer.
Returns
list[float]
A list of rewards, where each reward is 1.0 if the completion matches the expected format, otherwise 0.0.
Reward function that checks if the reasoning process is enclosed within "<think>" and "</think>" tags. The function returns a reward of 1.0 if the format is correct, otherwise 0.0.
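A format check of this kind can be sketched with a regular expression. The pattern below is an assumption about what "correct format" means (the text opens with exactly one `<think>...</think>` block, optionally followed by an answer); it may differ from the check the library actually performs, and the name `think_format_sketch` is hypothetical.

```python
import re

# Assumed pattern: the text starts with <think>, contains no second <think>,
# and closes the block with </think> before any remaining answer text.
THINK_PATTERN = re.compile(r"^<think>(?!.*<think>)(.*?)</think>.*$", re.DOTALL)


def think_format_sketch(completions, **kwargs):
    """Hypothetical format reward: 1.0 if the completion matches THINK_PATTERN, else 0.0."""
    contents = [completion[0]["content"] for completion in completions]
    return [1.0 if THINK_PATTERN.match(text) else 0.0 for text in contents]


completions = [
    [{"role": "assistant", "content": "<think>\n2 + 2 = 4\n</think>\nThe answer is 4."}],
    [{"role": "assistant", "content": "The answer is 4."}],
]
print(think_format_sketch(completions))
```

`re.DOTALL` is used so that the reasoning inside the tags may span several lines.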
get_soft_overlong_punishment
trl.rewards.get_soft_overlong_punishment
( max_completion_len: int, soft_punish_cache: int )
Returns a reward function that penalizes overlong completions. It only penalizes completions that run too long; it does not reward shorter ones. Reference: Eq. (13) from the DAPO paper (https://huggingface.co/papers/2503.14476).
Example:
from trl.rewards import get_soft_overlong_punishment
soft_overlong_punishment = get_soft_overlong_punishment(max_completion_len=100, soft_punish_cache=20)
completion_ids = [[1] * 90]  # Simulate a 90-token completion; 90 lies between 80 (100 - 20) and 100.
rewards = soft_overlong_punishment(completion_ids)
print(rewards) # [-0.5]
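The -0.5 in the example above follows from the piecewise penalty in Eq. (13) of the DAPO paper. A minimal sketch of that formula for a single completion length is shown below; the helper name `soft_overlong_penalty` is hypothetical, and the code is written from the paper's equation rather than copied from the library.

```python
def soft_overlong_penalty(length: int, max_completion_len: int, soft_punish_cache: int) -> float:
    """Hypothetical helper: length penalty following Eq. (13) of the DAPO paper."""
    threshold = max_completion_len - soft_punish_cache
    if length <= threshold:
        # Within the budget: no penalty (and no bonus for being short).
        return 0.0
    if length <= max_completion_len:
        # Inside the soft window: the penalty grows linearly from 0 to -1.
        return (threshold - length) / soft_punish_cache
    # Beyond the hard limit: the full penalty.
    return -1.0


for n in (70, 90, 100, 120):
    print(n, soft_overlong_penalty(n, max_completion_len=100, soft_punish_cache=20))
```

With `max_completion_len=100` and `soft_punish_cache=20`, a 90-token completion sits halfway through the soft window, which gives (80 - 90) / 20 = -0.5.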