---
license: apache-2.0
datasets:
- PKU-ML/Erdos-CoT
language:
- en
metrics:
- accuracy
base_model:
- Qwen/Qwen2.5-3B-Instruct
pipeline_tag: text-generation
tags:
- graph
- chat
library_name: transformers
---

# G1-CoT-SFT-3B

## Introduction

G1 is a series of large language models, based on Qwen2.5-Instruct, trained on our benchmark [Erdős](https://huggingface.co/datasets/PKU-ML/Erdos) to solve graph reasoning tasks.
We first apply supervised fine-tuning (SFT) as a preliminary step and then train with Group Relative Policy Optimization (GRPO) for reinforcement learning.
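
The core idea of GRPO is to sample several responses to the same prompt and normalize each response's reward against the statistics of its own group, rather than learning a separate value function. The Python sketch below illustrates only that advantage computation under stated assumptions. It is not the G1 training code, `group_relative_advantages` is a hypothetical name, and it uses the population standard deviation plus a small `eps` (implementations differ on both). This checkpoint is the SFT stage; the sketch describes the RL stage that builds on it.

```python
import statistics


def group_relative_advantages(rewards: list[float], eps: float = 1e-6) -> list[float]:
    """Hypothetical helper: GRPO-style advantages for one group of sampled responses.

    Each reward is centered on the group mean and scaled by the group's
    population standard deviation; `eps` avoids division by zero when
    every response in the group receives the same reward.
    """
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)
    return [(r - mean) / (std + eps) for r in rewards]


# Example: four sampled answers to one graph question, rewarded 1 if correct else 0.
print(group_relative_advantages([1.0, 0.0, 0.0, 1.0]))  # approximately [1.0, -1.0, -1.0, 1.0]
```

Correct answers receive positive advantages and incorrect ones negative advantages, so the policy is pushed toward responses that beat its own group average.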

G1 brings the following improvements:

- **Significant improvement on graph reasoning**: G1 models achieve up to 46% improvement over baselines on Erdős, with the 7B variant matching OpenAI’s o3-mini and the 3B model surpassing Qwen2.5-72B-Instruct by notable margins.
- **Strong generalization to unseen graph tasks**: G1 generalizes zero-shot to unseen graph tasks, improving performance on *other graph reasoning benchmarks* (GraphWiz, GraphArena) and *real-world graphs* (Cora, PubMed).
- **No compromise on general reasoning**: Crucially, G1 preserves general reasoning ability (GSM8K, MATH, MMLU-Pro), demonstrating its versatility.

**This repo contains the G1-CoT-SFT-3B model**, which has the following features:

- Type: Causal Language Model
- Training Stage: SFT
- Architecture: Same as Qwen2.5-Instruct
- Number of Parameters: 3.09B
- Context Length: 32,768 tokens in total, with up to 8,192 generated tokens
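
The two context figures interact: a long graph description leaves less room for the generated reasoning. A minimal sketch for choosing `max_new_tokens`, assuming the 8,192-token generation figure is treated as a cap on new tokens and that the prompt length is measured with the model's own tokenizer; `max_new_tokens_for` is a hypothetical helper, not part of `transformers`.

```python
FULL_CONTEXT = 32_768   # total context length listed above
MAX_GENERATION = 8_192  # generation length listed above (assumed to cap new tokens)


def max_new_tokens_for(prompt_tokens: int) -> int:
    """Hypothetical helper: largest generation budget that respects both limits."""
    if prompt_tokens >= FULL_CONTEXT:
        raise ValueError("the prompt alone fills the context window")
    return min(MAX_GENERATION, FULL_CONTEXT - prompt_tokens)


print(max_new_tokens_for(1_000))   # 8192
print(max_new_tokens_for(30_000))  # 2768
```

With the Quickstart code, `prompt_tokens` could be taken from `model_inputs.input_ids.shape[1]`.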

For more details, please refer to our [paper](https://arxiv.org/pdf/2505.18499) and [GitHub](https://github.com/PKU-ML/G1/tree/main).

## Requirements

The model is fine-tuned from Qwen/Qwen2.5-3B-Instruct. Support for the Qwen2.5 architecture is included in recent releases of Hugging Face `transformers`, and we advise you to use the latest version.

With `transformers<4.37.0`, you will encounter the following error:

```
KeyError: 'qwen2'
```
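
To check ahead of time whether the installed `transformers` is new enough, you can compare its version with 4.37.0. This is an illustrative sketch, not an official utility: `parse_release` and `supports_qwen2` are hypothetical names, and the parser only reads the leading numeric release components, so pre-release suffixes such as `.dev0` are ignored.

```python
import re
from importlib.metadata import PackageNotFoundError, version

# Threshold taken from the note above: older releases raise KeyError: 'qwen2'.
MIN_VERSION = (4, 37, 0)


def parse_release(ver: str) -> tuple:
    """Hypothetical helper: leading numeric components, e.g. '4.44.2' -> (4, 44, 2)."""
    match = re.match(r"(\d+)(?:\.(\d+))?(?:\.(\d+))?", ver)
    if match is None:
        raise ValueError(f"unrecognized version string: {ver!r}")
    # Missing minor/patch components default to 0.
    return tuple(int(part or 0) for part in match.groups())


def supports_qwen2(ver: str) -> bool:
    """Hypothetical helper: True if `ver` is at least MIN_VERSION."""
    return parse_release(ver) >= MIN_VERSION


if __name__ == "__main__":
    try:
        installed = version("transformers")
    except PackageNotFoundError:
        print("transformers is not installed")
    else:
        status = "OK" if supports_qwen2(installed) else "too old, upgrade with: pip install -U transformers"
        print(installed, status)
```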

## Quickstart

The following code snippet uses `apply_chat_template` to show how to load the tokenizer and model and how to generate content.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

INSTRUCTION_TEMPLATE = """
{instruction}

Solve the above problem efficiently and clearly. The last line of your response should be of the following format: 'Therefore, the final answer is: $\\boxed{{ANSWER}}$. I hope it is correct' (without quotes) where ANSWER is just the final number or expression that solves the problem. Think step by step before answering.
""".strip()

model_name = "PKU-ML/G1-CoT-SFT-3B"

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Adjacent string literals inside parentheses are concatenated into one prompt.
prompt = (
    "The task is to determine the degree centrality of a node in the graph.\n\n"
    "Degree centrality for a node is the fraction of nodes it is connected to.\n\n"
    "Here is an undirected graph containing nodes from 1 to 15. The edges are: (1, 15), (15, 11), (2, 3), (2, 6), (3, 6), (3, 7), (6, 7), (6, 8), (7, 8), (7, 14), (4, 10), (10, 5), (10, 12), (8, 14), (8, 9), (12, 11), (12, 13).\n\n"
    "Question: What is the degree centrality of node 2 in the graph?\n\n"
    "You need to format your answer as a float number."
)
messages = [
    {"role": "user", "content": INSTRUCTION_TEMPLATE.format(instruction=prompt)}
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=4096,
    do_sample=True,  # sampling must be enabled for top_p, top_k and temperature to take effect
    top_p=0.95,
    top_k=30,
    temperature=0.6
)
# Keep only the newly generated tokens by dropping the prompt tokens.
generated_ids = [
    output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]

response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(response)
```
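
Because the prompt asks the model to end with `\boxed{...}`, the final answer can be pulled out of `response` and compared with a reference value. For the example graph, node 2 has two neighbors (3 and 6) among 15 nodes, so its degree centrality is 2 / (15 - 1) = 1/7 ≈ 0.1429, assuming the common normalization by n - 1 (the one used by NetworkX's `degree_centrality`). The helpers `degree_centrality` and `extract_boxed_answer` are hypothetical names written for this sketch, and the regular expression assumes the boxed answer contains no nested braces (so an answer such as `\frac{1}{7}` would not be matched).

```python
import re

# Edge list copied from the Quickstart prompt above.
EDGES = [(1, 15), (15, 11), (2, 3), (2, 6), (3, 6), (3, 7), (6, 7), (6, 8), (7, 8),
         (7, 14), (4, 10), (10, 5), (10, 12), (8, 14), (8, 9), (12, 11), (12, 13)]


def degree_centrality(edges, num_nodes, node):
    """Hypothetical helper: degree of `node` divided by (num_nodes - 1)."""
    degree = sum(1 for u, v in edges if node in (u, v))
    return degree / (num_nodes - 1)


def extract_boxed_answer(text):
    """Hypothetical helper: content of the last \\boxed{...} in `text`, or None."""
    matches = re.findall(r"\\boxed\{([^{}]*)\}", text)
    return matches[-1].strip() if matches else None


expected = degree_centrality(EDGES, num_nodes=15, node=2)
print(expected)  # 0.14285714285714285

# Stand-in for the model output; in practice pass `response` from the Quickstart.
sample_response = "... Therefore, the final answer is: $\\boxed{0.1429}$. I hope it is correct"
answer = extract_boxed_answer(sample_response)
print(answer is not None and abs(float(answer) - expected) < 1e-3)  # True
```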

## Evaluation & Performance

Detailed evaluation results are reported in our [📑 paper](https://arxiv.org/pdf/2505.18499).

## Citation

If you find our work helpful, please consider citing it:

```bibtex
@article{guo2025g1,
  title={G1: Teaching LLMs to Reason on Graphs with Reinforcement Learning},
  author={Guo, Xiaojun and Li, Ang and Wang, Yifei and Jegelka, Stefanie and Wang, Yisen},
  journal={arXiv preprint arXiv:2505.18499},
  year={2025}
}
```