---
base_model: llm-jp/llm-jp-3-13b-instruct2
tags:
  - text-generation-inference
  - transformers
  - unsloth
  - llama
  - trl
license: apache-2.0
language:
  - en
---

# morizon/llm-jp-3-13b-instruct2-grpo-0215_lora

This model provides a LoRA adapter optimized for Japanese text-generation tasks.

- **Developed by:** morizon
- **License:** apache-2.0
- **Finetuned from model:** llm-jp/llm-jp-3-13b-instruct2

This llama model was trained 2x faster with Unsloth and Hugging Face's TRL library.
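
If you are not using the Unsloth stack, the adapter can also be attached to the base model with plain `transformers` + `peft`. The snippet below is a minimal sketch, assuming the repository contains a standard PEFT LoRA adapter; the dtype, device placement, prompt, and generation length are illustrative choices, not values prescribed by this card.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base_id = "llm-jp/llm-jp-3-13b-instruct2"
adapter_id = "morizon/llm-jp-3-13b-instruct2-grpo-0215_lora"

tokenizer = AutoTokenizer.from_pretrained(base_id)
base = AutoModelForCausalLM.from_pretrained(
    base_id,
    torch_dtype=torch.bfloat16,  # illustrative; pick a dtype your GPU supports
    device_map="auto",
)
model = PeftModel.from_pretrained(base, adapter_id)  # attach the LoRA adapter

messages = [{"role": "user", "content": "円周率について簡単に説明してください。"}]
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)
output_ids = model.generate(input_ids, max_new_tokens=256)
print(tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True))
```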

## Sample Use

```python
%%capture
# Skip restarting message in Colab
import sys; modules = list(sys.modules.keys())
for x in modules: sys.modules.pop(x) if "PIL" in x or "google" in x else None

!pip install unsloth vllm
!pip install --upgrade pillow
# If you are running this notebook locally, you also need to install `diffusers`
# !pip install diffusers
# Pin TRL to a specific development commit
!pip install git+https://github.com/huggingface/trl.git@e95f9fb74a3c3647b86f251b7e230ec51c64b72b
from unsloth import FastLanguageModel, PatchFastRL
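# PatchFastRL swaps Unsloth's optimized GRPO implementation into TRL;
# it should be called before the TRL/transformers imports further below.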
PatchFastRL("GRPO", FastLanguageModel)

import re
import torch
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer
from unsloth import is_bfloat16_supported
model_id = "llm-jp/llm-jp-3-13b-instruct2"
adapter_id = "morizon/llm-jp-3-13b-instruct2-grpo-MATH-lighteval_step1000_lora"

# --- Load the model and apply LoRA ---
max_seq_length = 1024  # maximum length of the reasoning trace
lora_rank = 64         # LoRA rank (recommended: 64)

# Load the model and tokenizer via FastLanguageModel
# Note: set the model name to the one you actually want to use
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_id,
    max_seq_length=max_seq_length,
    load_in_4bit=True,      # 4-bit quantization (take care with this setting when LoRA fine-tuning)
    fast_inference=True,    # enable fast inference via vLLM
    max_lora_rank=lora_rank,
    gpu_memory_utilization=0.7,
)

# Apply LoRA (PEFT)
model = FastLanguageModel.get_peft_model(
    model,
    r=lora_rank,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_alpha=lora_rank,
    use_gradient_checkpointing="unsloth",
    random_state=3407,
)
# --- Prepare the prompt and the test data ---
# Recommended: omit the system prompt and put all instructions into the user prompt
USER_INSTRUCTION = (
    "Please ensure your response begins with \"<reasoning>\n\". "
    "Please reason step by step, and put your final answer within \\boxed{}. "
)

# Example test data (a list of dicts)
test_data = [
    {"id": 0, "text": "$x^{-1}>x$を満たす正の整数$x$の個数を求めなさい。", "gold": "0", "response": "", "type": "Algebra", "level": "Level 2"},
    ##評価したいテストデータを入力してください
]
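
# Note: the evaluation loop below only uses "id", "text", and "gold";
# "response", "type", and "level" are informational only.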

def extract_boxed_answer_rev(text: str) -> str:
    """
    テキスト中から最初の \boxed{...} の中身(ネストを考慮)を抽出する。
    例: r"\boxed{\frac{\pi}{6}}" -> "\frac{\pi}{6}"
    """
    key = r"\boxed{"
    start_idx = text.find(key)
    if start_idx == -1:
        return ""
    # Start scanning right after "\boxed{"
    start_idx += len(key)
    brace_count = 1  # the opening "{" has already been counted
    i = start_idx
    while i < len(text) and brace_count > 0:
        if text[i] == "{":
            brace_count += 1
        elif text[i] == "}":
            brace_count -= 1
        i += 1
    # i-1 is the index of the matching closing brace
    return text[start_idx:i-1].strip()
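
# Quick sanity check of the extractor on a nested example
assert extract_boxed_answer_rev(r"The answer is \boxed{\frac{\pi}{6}}.") == r"\frac{\pi}{6}"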

from vllm import SamplingParams

correct = 0
total = len(test_data)

# Lists that record correct and incorrect cases
correct_cases = []
incorrect_cases = []

for item in test_data:
    # Build the prompt (prepend USER_INSTRUCTION)
    prompt = USER_INSTRUCTION + item["text"]
    text = tokenizer.apply_chat_template([
        {"role": "user", "content": prompt},
    ], tokenize=False, add_generation_prompt=True)

    # Run inference
    sampling_params = SamplingParams(
        temperature=0.6,
        max_tokens=2048,
    )
    output = model.fast_generate(
        text,
        sampling_params=sampling_params,
        lora_request=model.load_lora(adapter_id),
        # lora_request = model.load_lora("grpo_saved_lora"),
    )[0].outputs[0].text

    # Extract the answer from \boxed{...}
    boxed_answer = extract_boxed_answer_rev(output)

    # Print the results
    print("\n----------Test ID:", item["id"], "----------")
    print("Prompt:")
    print(prompt)
    print("\nLLM Output:")
    print(output)
    print("\nExtracted Answer:")
    print(boxed_answer)
    print("Gold Answer:", item["gold"])

    # Count as correct when the extracted answer exactly matches the gold answer
    if boxed_answer == item["gold"]:
        correct += 1
        correct_cases.append({
            "id": item["id"],
            "prompt": prompt,
            "LLM_output": output,
            "extracted_answer": boxed_answer,
            "gold": item["gold"]
        })
    else:
        incorrect_cases.append({
            "id": item["id"],
            "prompt": prompt,
            "LLM_output": output,
            "extracted_answer": boxed_answer,
            "gold": item["gold"]
        })

# Print correct cases
print("\n========== Correct cases ==========")
for case in correct_cases:
    print("\nTest ID:", case["id"])
    print("Prompt:")
    print(case["prompt"])
    print("LLM Output:")
    print(case["LLM_output"])
    print("Extracted Answer:", case["extracted_answer"])
    print("Gold Answer:", case["gold"])
    print("-" * 40)

# Print incorrect cases
print("\n========== Incorrect cases ==========")
for case in incorrect_cases:
    print("\nTest ID:", case["id"])
    print("Prompt:")
    print(case["prompt"])
    print("LLM Output:")
    print(case["LLM_output"])
    print("Extracted Answer:", case["extracted_answer"])
    print("Gold Answer:", case["gold"])
    print("-" * 40)

accuracy = correct / total * 100
print("\nOverall Accuracy: {}/{} ({:.2f}%)".format(correct, total, accuracy))