TRL documentation

WinRateCallback

You are viewing the main version, which requires installation from source. If you'd like a regular pip install, check out the latest stable version (v0.25.1).

class trl.WinRateCallback

( judge, trainer: Trainer, generation_config: transformers.generation.configuration_utils.GenerationConfig | None = None, num_prompts: int | None = None, shuffle_order: bool = True, use_soft_judge: bool = False )

Parameters

  • judge (experimental.judges.BasePairwiseJudge) — The judge to use for comparing completions.
  • trainer (Trainer) — Trainer to which the callback will be attached. The trainer’s evaluation dataset must include a "prompt" column containing the prompts for generating completions. If the Trainer has a reference model (via the ref_model attribute), it will use this reference model for generating the reference completions; otherwise, it defaults to using the initial model.
  • generation_config (GenerationConfig, optional) — The generation config to use for generating completions.
  • num_prompts (int, optional) — The number of prompts to generate completions for. If not provided, defaults to the number of examples in the evaluation dataset.
  • shuffle_order (bool, optional, defaults to True) — Whether to shuffle the order of the completions before judging.
  • use_soft_judge (bool, optional, defaults to False) — Whether to use a soft judge that returns a win probability between 0 and 1 for the first completion vs the second.
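
To make the shuffle_order parameter more concrete, here is a minimal sketch of why completion order is shuffled before judging: it keeps a judge with a position bias from systematically favoring one side. This is not TRL's implementation. The helper judge_with_shuffle and the two toy judges are hypothetical names introduced only for illustration, and the judge is assumed to return the index (0 or 1) of the preferred completion in each pair.

```python
import random


def judge_with_shuffle(judge_fn, prompts, pairs, shuffle_order=True, seed=None):
    """Judge pairs, optionally flipping each pair at random (illustrative only).

    judge_fn is assumed to return, for each pair, the index (0 or 1) of the
    completion it prefers *in the order it was shown*. The result is mapped
    back so that indices always refer to the original pair order.
    """
    rng = random.Random(seed)
    # Decide independently for each pair whether to swap the two completions.
    flips = [shuffle_order and rng.random() < 0.5 for _ in pairs]
    shown = [(b, a) if flip else (a, b) for (a, b), flip in zip(pairs, flips)]
    raw = judge_fn(prompts, shown)
    # Undo the swap: a win for shown position 0 is a win for original position 1.
    return [1 - winner if flip else winner for winner, flip in zip(raw, flips)]


def longer_judge(prompts, pairs):
    """Toy judge without position bias: prefers the longer completion."""
    return [0 if len(a) >= len(b) else 1 for a, b in pairs]


def first_position_judge(prompts, pairs):
    """Toy judge with full position bias: always prefers what it sees first."""
    return [0 for _ in pairs]


prompts = ["p1", "p2", "p3"]
pairs = [("short", "much longer"), ("a long answer", "x"), ("ab", "abc")]
print(judge_with_shuffle(longer_judge, prompts, pairs, seed=0))
```

With the unbiased toy judge, the verdicts are the same whether or not the order is shuffled. With the position-biased toy judge, shuffling spreads the bias across both sides instead of always rewarding the first completion.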

A TrainerCallback that computes the win rate of a model against a reference.

It generates completions using prompts from the evaluation dataset and compares the trained model’s outputs against a reference. The reference is either the initial version of the model (before training) or the reference model, if available in the trainer. During each evaluation step, a judge determines how often the trained model’s completions win against the reference. The win rate is then logged in the trainer’s logs under the key "eval_win_rate".
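
The win rate described above can be sketched as a simple aggregate over the judge's verdicts. The functions below are hypothetical helpers, not part of TRL. They assume that each pair is ordered (reference completion, trained-model completion), that a hard judge returns the index of the winner, and that a soft judge returns the probability that the first completion beats the second, as described for use_soft_judge.

```python
def hard_win_rate(winner_indices, trained_index=1):
    """Fraction of pairs in which the judge picked the trained model.

    trained_index is the position of the trained model's completion in each
    pair (assumed to be 1, i.e. the reference completion comes first).
    """
    if not winner_indices:
        raise ValueError("winner_indices must not be empty")
    wins = sum(1 for idx in winner_indices if idx == trained_index)
    return wins / len(winner_indices)


def soft_win_rate(first_win_probs, trained_is_first=False):
    """Average win probability of the trained model.

    first_win_probs holds P(first completion beats second) for each pair.
    If the trained model's completion is second, its win probability is the
    complement of that value.
    """
    if not first_win_probs:
        raise ValueError("first_win_probs must not be empty")
    probs = first_win_probs if trained_is_first else [1.0 - p for p in first_win_probs]
    return sum(probs) / len(probs)


print(hard_win_rate([1, 0, 1, 1]))  # 3 wins out of 4 pairs -> 0.75
print(soft_win_rate([0.25, 0.75, 0.5]))
```

A hard win rate only counts which side won each pair, while the soft variant also reflects how confident the judge was in each comparison.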

Usage:

trainer = DPOTrainer(...)
judge = PairRMJudge()
win_rate_callback = WinRateCallback(judge=judge, trainer=trainer)
trainer.add_callback(win_rate_callback)
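
After training, the logged win rates can be collected for inspection. The sketch below is a hypothetical helper, not a TRL function. It assumes the logs are available as a list of dictionaries, such as the log_history list kept on a transformers Trainer's state, and that each entry may carry a "step" key alongside the metric.

```python
def win_rate_history(log_history, key="eval_win_rate"):
    """Return (step, win_rate) tuples for every log entry containing `key`.

    log_history is assumed to be a list of dicts, one per logging call.
    Entries without the key (for example, training-loss logs) are skipped.
    """
    return [(entry.get("step"), entry[key]) for entry in log_history if key in entry]


# Made-up log entries, used only to show the shape of the data.
example_logs = [
    {"loss": 0.69, "step": 10},
    {"eval_win_rate": 0.5, "step": 10},
    {"loss": 0.61, "step": 20},
    {"eval_win_rate": 0.625, "step": 20},
]
print(win_rate_history(example_logs))
```

With a real trainer, the call would look like win_rate_history(trainer.state.log_history), provided the callback's metrics end up in that list as other logged metrics do.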