---
base_model: llm-jp/llm-jp-3-13b-instruct2
tags:
- text-generation-inference
- transformers
- unsloth
- llama
- trl
license: apache-2.0
language:
- en
---

# morizon/llm-jp-3-13b-instruct2-grpo-0215_lora
This is a model with a LoRA adapter, optimized for Japanese text-generation tasks.
- Developed by: morizon
- License: apache-2.0
- Finetuned from model: llm-jp/llm-jp-3-13b-instruct2

This Llama model was trained 2x faster with Unsloth and Hugging Face's TRL library.
## Sample Use
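The walkthrough below uses Unsloth and vLLM, mirroring the notebook setup used for this adapter. If you only want to run the adapter with plain `transformers` + `peft`, a minimal sketch is shown first; the dtype, `device_map`, generation settings, and the placeholder prompt are illustrative assumptions, not values verified against this checkpoint.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base_id = "llm-jp/llm-jp-3-13b-instruct2"
adapter_id = "morizon/llm-jp-3-13b-instruct2-grpo-0215_lora"

tokenizer = AutoTokenizer.from_pretrained(base_id)
base_model = AutoModelForCausalLM.from_pretrained(
    base_id,
    torch_dtype=torch.bfloat16,  # assumption: a bf16-capable GPU with enough memory for a 13B model
    device_map="auto",
)
model = PeftModel.from_pretrained(base_model, adapter_id)  # attach the LoRA adapter

messages = [{"role": "user", "content": "Your prompt here."}]  # placeholder prompt
inputs = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)
outputs = model.generate(inputs, max_new_tokens=512)
print(tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True))
```

The rest of this section reproduces the Unsloth + vLLM evaluation flow.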
```python
%%capture
# Skip the restart message in Colab
import sys; modules = list(sys.modules.keys())
for x in modules: sys.modules.pop(x) if "PIL" in x or "google" in x else None
!pip install unsloth vllm
!pip install --upgrade pillow
# If you are running this notebook locally, you also need to install `diffusers`
# !pip install diffusers
# Temporarily install a specific TRL nightly version
!pip install git+https://github.com/huggingface/trl.git@e95f9fb74a3c3647b86f251b7e230ec51c64b72b
```
```python
from unsloth import FastLanguageModel, PatchFastRL
PatchFastRL("GRPO", FastLanguageModel)

import re
import torch
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer
from unsloth import is_bfloat16_supported

model_id = "llm-jp/llm-jp-3-13b-instruct2"
adapter_id = "morizon/llm-jp-3-13b-instruct2-grpo-MATH-lighteval_step1000_lora"

# --- Load the model and apply LoRA ---
max_seq_length = 1024  # maximum length of the reasoning trace
lora_rank = 64  # LoRA rank (recommended: 64)

# Load the model and tokenizer via FastLanguageModel
# Note: adjust the model name to the one you are using
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_id,
    max_seq_length=max_seq_length,
    load_in_4bit=True,  # 4-bit quantization (be careful with this setting when fine-tuning with LoRA)
    fast_inference=True,  # enable fast inference with vLLM
    max_lora_rank=lora_rank,
    gpu_memory_utilization=0.7,
)

# Apply LoRA (PEFT)
model = FastLanguageModel.get_peft_model(
    model,
    r=lora_rank,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_alpha=lora_rank,
    use_gradient_checkpointing="unsloth",
    random_state=3407,
)
```
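As an optional sanity check after applying LoRA, PEFT models expose `print_trainable_parameters()`; this assumes the object returned by Unsloth behaves like a standard PEFT model.

```python
# Optional: confirm that only the LoRA weights are marked trainable
# (assumes the Unsloth-wrapped model is a standard PEFT model).
model.print_trainable_parameters()
```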
```python
# --- Prompt and dataset preparation ---
# Recommended: drop the system prompt and fold all instructions into the user prompt
USER_INSTRUCTION = (
    "Please ensure your response begins with \"<reasoning>\n\". "
    "Please reason step by step, and put your final answer within \\boxed{}. "
)

# Example test data (a list of dicts)
test_data = [
    # Japanese math problem: "How many positive integers x satisfy x^{-1} > x?"
    {"id": 0, "text": "$x^{-1}>x$を満たす正の整数$x$の個数を求めなさい。", "gold": "0", "response": "", "type": "Algebra", "level": "Level 2"},
    # Add the test data you want to evaluate here
]

def extract_boxed_answer_rev(text: str) -> str:
    r"""
    Extract the contents of the first \boxed{...} in the text, taking nesting into account.
    Example: r"\boxed{\frac{\pi}{6}}" -> "\frac{\pi}{6}"
    """
    key = r"\boxed{"
    start_idx = text.find(key)
    if start_idx == -1:
        return ""
    # Start right after "\boxed{"
    start_idx += len(key)
    brace_count = 1  # the opening "{" is already counted
    i = start_idx
    while i < len(text) and brace_count > 0:
        if text[i] == "{":
            brace_count += 1
        elif text[i] == "}":
            brace_count -= 1
        i += 1
    # i - 1 is the index of the matching closing brace
    return text[start_idx:i - 1].strip()
```
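A quick, self-contained check of the brace-matching extractor; the strings below are illustrative only, not taken from the dataset.

```python
# Nested braces are handled: the full payload of the first \boxed{...} is returned.
print(extract_boxed_answer_rev(r"The answer is \boxed{\frac{\pi}{6}}."))  # -> \frac{\pi}{6}
print(extract_boxed_answer_rev("no boxed answer here"))                   # -> "" (empty string)
```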
```python
from vllm import SamplingParams

correct = 0
total = len(test_data)

# Lists that record correct and incorrect cases
correct_cases = []
incorrect_cases = []

for item in test_data:
    # Build the prompt (prepend USER_INSTRUCTION)
    prompt = USER_INSTRUCTION + item["text"]
    text = tokenizer.apply_chat_template([
        {"role": "user", "content": prompt},
    ], tokenize=False, add_generation_prompt=True)

    # Run inference
    sampling_params = SamplingParams(
        temperature=0.6,
        max_tokens=2048,
    )
    output = model.fast_generate(
        text,
        sampling_params=sampling_params,
        lora_request=model.load_lora(adapter_id),
        # lora_request=model.load_lora("grpo_saved_lora"),
    )[0].outputs[0].text

    # Extract the answer from the contents of \boxed{...}
    boxed_answer = extract_boxed_answer_rev(output)

    # Display the result
    print("\n----------Test ID:", item["id"], "----------")
    print("Prompt:")
    print(prompt)
    print("\nLLM Output:")
    print(output)
    print("\nExtracted Answer:")
    print(boxed_answer)
    print("Gold Answer:", item["gold"])

    # Count as correct when the extracted answer matches the gold answer
    if boxed_answer == item["gold"]:
        correct += 1
        correct_cases.append({
            "id": item["id"],
            "prompt": prompt,
            "LLM_output": output,
            "extracted_answer": boxed_answer,
            "gold": item["gold"]
        })
    else:
        incorrect_cases.append({
            "id": item["id"],
            "prompt": prompt,
            "LLM_output": output,
            "extracted_answer": boxed_answer,
            "gold": item["gold"]
        })

# Show the correct cases
print("\n========== Correct cases ==========")
for case in correct_cases:
    print("\nTest ID:", case["id"])
    print("Prompt:")
    print(case["prompt"])
    print("LLM Output:")
    print(case["LLM_output"])
    print("Extracted Answer:", case["extracted_answer"])
    print("Gold Answer:", case["gold"])
    print("-" * 40)

# Show the incorrect cases
print("\n========== Incorrect cases ==========")
for case in incorrect_cases:
    print("\nTest ID:", case["id"])
    print("Prompt:")
    print(case["prompt"])
    print("LLM Output:")
    print(case["LLM_output"])
    print("Extracted Answer:", case["extracted_answer"])
    print("Gold Answer:", case["gold"])
    print("-" * 40)

accuracy = correct / total * 100
print("\nOverall Accuracy: {}/{} ({:.2f}%)".format(correct, total, accuracy))
```