TRL documentation
WinRateCallback
class trl.WinRateCallback( judge, trainer: Trainer, generation_config: transformers.generation.configuration_utils.GenerationConfig | None = None, num_prompts: int | None = None, shuffle_order: bool = True, use_soft_judge: bool = False )
Parameters
- judge (experimental.judges.BasePairwiseJudge): The judge to use for comparing completions.
- trainer (Trainer): Trainer to which the callback will be attached. The trainer's evaluation dataset must include a "prompt" column containing the prompts for generating completions. If the Trainer has a reference model (via the ref_model attribute), it will use this reference model for generating the reference completions; otherwise, it defaults to using the initial model.
- generation_config (GenerationConfig, optional): The generation config to use for generating completions.
- num_prompts (int, optional): The number of prompts to generate completions for. If not provided, defaults to the number of examples in the evaluation dataset.
- shuffle_order (bool, optional, defaults to True): Whether to shuffle the order of the completions before judging.
- use_soft_judge (bool, optional, defaults to False): Whether to use a soft judge that returns a win probability between 0 and 1 for the first completion versus the second.
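To make `shuffle_order` and `use_soft_judge` more concrete, here is a minimal, dependency-free sketch of judging completion pairs with optional order shuffling. It is not TRL's `BasePairwiseJudge` API: `toy_soft_judge` (which simply prefers the longer completion) and `judge_pairs` are hypothetical helpers, and the 0.5 threshold used to turn a probability into a hard winner is an assumption. The key idea is that a swapped pair must be mapped back to its original order before the result is recorded.

```python
import random
from typing import Callable, List, Optional, Sequence


def toy_soft_judge(prompt: str, first: str, second: str) -> float:
    """Hypothetical soft judge: probability that `first` beats `second`.

    It just prefers the longer completion; ties are a coin flip (0.5).
    """
    if len(first) == len(second):
        return 0.5
    return 0.9 if len(first) > len(second) else 0.1


def judge_pairs(
    prompts: Sequence[str],
    pairs: Sequence[Sequence[str]],
    soft_judge: Callable[[str, str, str], float],
    shuffle_order: bool = True,
    use_soft_judge: bool = False,
    rng: Optional[random.Random] = None,
) -> List[float]:
    """Judge each (completion_0, completion_1) pair (hypothetical helper).

    Returns the winner index (0 or 1) per pair when use_soft_judge=False, or
    the probability that completion_0 wins when use_soft_judge=True. Results
    are always expressed in the ORIGINAL pair order.
    """
    rng = rng or random.Random()
    results: List[float] = []
    for prompt, (c0, c1) in zip(prompts, pairs):
        # Randomly swap the pair so the judge cannot rely on position.
        flipped = shuffle_order and rng.random() < 0.5
        first, second = (c1, c0) if flipped else (c0, c1)
        p_first = soft_judge(prompt, first, second)
        # Undo the swap: convert to "probability that completion_0 wins".
        p_c0 = 1.0 - p_first if flipped else p_first
        if use_soft_judge:
            results.append(p_c0)
        else:
            # Assumed decision rule: completion_0 wins if p_c0 >= 0.5.
            results.append(0 if p_c0 >= 0.5 else 1)
    return results
```

Because the swap is undone before the result is stored, shuffling only changes what the judge sees, not how its verdict is interpreted; a position-independent judge gives the same results with or without `shuffle_order`.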
A TrainerCallback that computes the win rate of a model against a reference.
It generates completions using prompts from the evaluation dataset and compares the trained model's outputs against
a reference. The reference is either the initial version of the model (before training) or the reference model, if
one is available in the trainer. At each evaluation, a judge determines how often the trained model's completions
win against the reference completions. The win rate is then logged in the trainer's logs under the key
"eval_win_rate".
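The flow described above can be sketched without any TRL or Transformers dependencies. In this hypothetical `MiniWinRateCallback`, the `generate` callables stand in for the reference/initial model and the model being trained, `longer_wins` is a toy judge, and `self.logs` stands in for the trainer's logs. Generating the reference completions only once, at the start of training, is an assumption of this sketch. Each pair is ordered `[reference, trained model]`, so the win rate is the fraction of prompts where the judge returns index 1.

```python
from typing import Callable, Dict, List, Optional, Sequence

Generate = Callable[[str], str]
# A pairwise judge returns, for each prompt, the index (0 or 1) of the winner.
PairwiseJudge = Callable[[Sequence[str], Sequence[Sequence[str]]], List[int]]


class MiniWinRateCallback:
    """Hypothetical, dependency-free sketch of a win-rate callback."""

    def __init__(
        self,
        judge: PairwiseJudge,
        prompts: Sequence[str],
        num_prompts: Optional[int] = None,
    ) -> None:
        self.judge = judge
        # Optionally evaluate only the first `num_prompts` prompts.
        self.prompts = list(prompts if num_prompts is None else prompts[:num_prompts])
        self.ref_completions: List[str] = []
        self.logs: List[Dict[str, float]] = []

    def on_train_begin(self, reference_generate: Generate) -> None:
        # Reference completions come from the reference model or the initial
        # model and are computed once, before any training step.
        self.ref_completions = [reference_generate(p) for p in self.prompts]

    def on_evaluate(self, model_generate: Generate) -> float:
        completions = [model_generate(p) for p in self.prompts]
        # Pair order is [reference, trained model]: index 1 means the model won.
        pairs = [[ref, new] for ref, new in zip(self.ref_completions, completions)]
        winners = self.judge(self.prompts, pairs)
        win_rate = sum(1 for w in winners if w == 1) / len(winners)
        self.logs.append({"eval_win_rate": win_rate})
        return win_rate


def longer_wins(prompts: Sequence[str], pairs: Sequence[Sequence[str]]) -> List[int]:
    # Toy judge: the longer completion wins; ties go to the reference (index 0).
    return [1 if len(b) > len(a) else 0 for a, b in pairs]


if __name__ == "__main__":
    cb = MiniWinRateCallback(longer_wins, ["a", "bb", "ccc", "dddd"])
    cb.on_train_begin(lambda p: p)  # "initial model": echoes the prompt
    # "Trained model": appends "!" to even-length prompts only.
    print(cb.on_evaluate(lambda p: p + "!" if len(p) % 2 == 0 else p))  # → 0.5
    print(cb.logs)  # → [{'eval_win_rate': 0.5}]
```

Keeping the reference completions fixed means every evaluation compares against the same baseline, so changes in the logged win rate reflect changes in the trained model rather than in the reference.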