# Distractor Generation with T5-base
This repository contains a T5-base model fine-tuned for distractor generation. Leveraging T5’s text-to-text framework and a custom separator token, the model generates three plausible distractors for multiple-choice questions by conditioning on a given question, context, and correct answer.
## Model Overview
Built with PyTorch Lightning, this implementation fine-tunes the pre-trained T5-base model to generate distractor options. The model takes a single input sequence formatted with the question, context, and correct answer—separated by a custom token—and generates a target sequence containing three distractors. This approach is particularly useful for multiple-choice question generation tasks.
## Data Processing
### Input Construction
Each input sample is a single string with the following format:
```
question {SEP_TOKEN} context {SEP_TOKEN} correct
```
- question: The question text.
- context: The context passage.
- correct: The correct answer.
- SEP_TOKEN: A special token added to the tokenizer to separate the different fields.
### Target Construction
Each target sample is constructed as follows:
```
incorrect1 {SEP_TOKEN} incorrect2 {SEP_TOKEN} incorrect3
```
This format allows the model to generate three distractors in one pass.
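As a concrete illustration, the snippet below sketches how a single training pair could be assembled from raw fields. The helper name `build_example` and the example values are illustrative, not taken from the original preprocessing code; the separator value `<sep>` matches the one used in the usage example further down.

```python
SEP_TOKEN = "<sep>"  # custom separator; the same value appears in the usage example below

def build_example(question, context, correct, distractors):
    # Source: question <sep> context <sep> correct answer
    source = f"{question} {SEP_TOKEN} {context} {SEP_TOKEN} {correct}"
    # Target: three incorrect options joined by the same separator
    target = f"{distractors[0]} {SEP_TOKEN} {distractors[1]} {SEP_TOKEN} {distractors[2]}"
    return source, target

source, target = build_example(
    "What is the capital of France?",
    "France is a country in Western Europe.",
    "Paris",
    ["Lyon", "Marseille", "Nice"],
)
```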
## Training Details
- Framework: PyTorch Lightning
- Base Model: T5-base
- Optimizer: Adam with a linear learning-rate schedule and warmup
- Batch Size: 32
- Number of Epochs: 5
- Learning Rate: 2e-5
- Tokenization:
  - Input: Maximum length of 512 tokens
  - Target: Maximum length of 64 tokens
- Special Tokens: The custom `SEP_TOKEN` is added to the tokenizer and is used to separate the different parts of the input and target sequences.
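The following is a minimal sketch of what a training setup with these hyperparameters could look like in PyTorch Lightning. The class name, warmup step count, and total step count are assumptions and are not taken from the original training script.

```python
import pytorch_lightning as pl
import torch
from transformers import (
    T5ForConditionalGeneration,
    T5TokenizerFast,
    get_linear_schedule_with_warmup,
)

SEP_TOKEN = "<sep>"

# Register the custom separator with the tokenizer
tokenizer = T5TokenizerFast.from_pretrained("t5-base")
tokenizer.add_tokens([SEP_TOKEN])

class DistractorGenerationModel(pl.LightningModule):
    # num_warmup_steps / num_training_steps are placeholders; the values used
    # for this checkpoint are not documented here.
    def __init__(self, lr=2e-5, num_warmup_steps=0, num_training_steps=10_000):
        super().__init__()
        self.model = T5ForConditionalGeneration.from_pretrained("t5-base")
        self.model.resize_token_embeddings(len(tokenizer))  # account for SEP_TOKEN
        self.lr = lr
        self.num_warmup_steps = num_warmup_steps
        self.num_training_steps = num_training_steps

    def training_step(self, batch, batch_idx):
        outputs = self.model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["labels"],
        )
        self.log("train_loss", outputs.loss)
        return outputs.loss

    def configure_optimizers(self):
        # Adam with a linear warmup-then-decay schedule, stepped every batch
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        scheduler = get_linear_schedule_with_warmup(
            optimizer, self.num_warmup_steps, self.num_training_steps
        )
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
```

Training would then be driven by something like `pl.Trainer(max_epochs=5).fit(model, train_dataloader)`, with the dataloader yielding batches of 32 tokenized examples.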
## Evaluation Metrics
The model is evaluated using BLEU scores for each generated distractor. Below are the BLEU scores obtained on the test set:
| Distractor | BLEU-1 | BLEU-2 | BLEU-3 | BLEU-4 |
|---|---|---|---|---|
| Distractor 1 | 29.59 | 21.55 | 17.86 | 15.75 |
| Distractor 2 | 25.21 | 16.81 | 13.00 | 10.78 |
| Distractor 3 | 23.99 | 15.78 | 12.35 | 10.52 |
These scores indicate that the generated distractors share substantial n‑gram overlap with the reference distractors, with the first distractor scoring highest and quality tapering off for the second and third.
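A per-distractor BLEU evaluation along these lines can be computed with NLTK. This is a hedged sketch rather than the exact evaluation script used for the table above, so tokenization and smoothing choices may differ from the reported setup.

```python
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction

def bleu_report(references, hypotheses):
    """Cumulative BLEU-1..4 for one distractor slot.

    references / hypotheses: parallel lists of whitespace-tokenized strings.
    """
    refs = [[ref.split()] for ref in references]  # one reference per example
    hyps = [hyp.split() for hyp in hypotheses]
    smooth = SmoothingFunction().method1
    weights = {
        "BLEU-1": (1.0,),
        "BLEU-2": (0.5, 0.5),
        "BLEU-3": (1 / 3, 1 / 3, 1 / 3),
        "BLEU-4": (0.25, 0.25, 0.25, 0.25),
    }
    return {
        name: 100 * corpus_bleu(refs, hyps, weights=w, smoothing_function=smooth)
        for name, w in weights.items()
    }

# Example: score the first generated distractor slot against its references
# print(bleu_report(reference_distractor_1, generated_distractor_1))
```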
## How to Use
You can use this model with Hugging Face's Transformers library as follows:
```python
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

model_name = "fares7elsadek/t5-base-distractor-generation"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

SEP_TOKEN = "<sep>"

def generate_distractors(question, context, correct, max_length=64):
    # Build the input in the same "question <sep> context <sep> answer" format used in training
    input_text = f"{question} {SEP_TOKEN} {context} {SEP_TOKEN} {correct}"
    inputs = tokenizer([input_text], return_tensors="pt", truncation=True, padding=True)
    outputs = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=max_length,
    )
    # The generated sequence contains the three distractors separated by SEP_TOKEN
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
    distractors = [d.strip() for d in decoded.split(SEP_TOKEN)]
    return distractors

# Example usage:
question = "What is the capital of France?"
context = "France is a country in Western Europe known for its rich history and cultural heritage."
correct = "Paris"
print(generate_distractors(question, context, correct))
```