Rename Reinforcement_Learning_model_using_mini_safetesor_model.py to Reinforcement_Learning_model_using_mini_safetesor_model_using_transformer.py
Reinforcement_Learning_model_using_mini_safetesor_model.py
DELETED
File without changes
Reinforcement_Learning_model_using_mini_safetesor_model_using_transformer.py
ADDED
@@ -0,0 +1,164 @@
"""
ppo_train_safetensors.py

Minimal PPO example using TRL (Hugging Face) with a small LM.

This script:
- loads a small causal LM (distilgpt2 / gpt2)
- wraps it with a value head for PPO
- runs a tiny RL loop with a dummy reward function
- saves the policy as .safetensors

Notes:
- Replace `dummy_reward` with a real reward function or reward model for serious experiments.
- For better VRAM usage, consider bitsandbytes + load_in_8bit (see the sketch below).
"""
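
# Optional VRAM saver mentioned in the note above -- a minimal sketch only, not used by
# this script. It assumes `bitsandbytes` and `accelerate` are installed and that the
# installed transformers version accepts a quantization_config on from_pretrained:
#
#   from transformers import BitsAndBytesConfig
#   model = AutoModelForCausalLMWithValueHead.from_pretrained(
#       MODEL_NAME,
#       quantization_config=BitsAndBytesConfig(load_in_8bit=True),
#       device_map="auto",
#   )
#
# With device_map="auto" the explicit .to(DEVICE) call in main() should be skipped.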

import os
import torch
from transformers import AutoTokenizer
from trl import (
    PPOTrainer,
    PPOConfig,
    AutoModelForCausalLMWithValueHead,
    create_reference_model,
)
from datasets import Dataset
import random

# ---------- CONFIG ----------
MODEL_NAME = "distilgpt2"  # small model for toy experiments
OUTPUT_DIR = "./ppo_policy"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 2  # keep tiny for demonstration
EPOCHS = 2
MAX_GEN_TOKENS = 64
SEED = 42

# ---------- HELPERS ----------
def set_seed(seed: int = 42):
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def dummy_reward(prompt: str, response: str) -> float:
    """
    Toy reward: +1 if the response contains the word 'def' (i.e., generates code-like text),
    otherwise -0.1. Replace with a reward model for real work (see the sketch below).
    """
    r = 1.0 if "def " in response or "def\n" in response else -0.1
    # add a small length bonus to encourage non-empty responses
    r += min(len(response.split()), 10) * 0.01
    return float(r)
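
# A model-based alternative to dummy_reward -- a minimal sketch, assuming the public
# sentiment classifier "lvwerra/distilbert-imdb" (used in the TRL docs) is an acceptable
# stand-in for a real reward model. It is not wired into main(); swap it in for
# dummy_reward to score responses with a learned model instead of a keyword check.
def model_based_reward(prompt: str, response: str) -> float:
    from transformers import pipeline  # local import keeps the toy script light

    # build the classifier once and cache it on the function object
    if not hasattr(model_based_reward, "_clf"):
        model_based_reward._clf = pipeline(
            "sentiment-analysis", model="lvwerra/distilbert-imdb", device=-1
        )
    out = model_based_reward._clf(prompt + " " + response, truncation=True)[0]
    # use the positive-class probability as the scalar reward
    return float(out["score"] if out["label"] == "POSITIVE" else 1.0 - out["score"])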

# ---------- MAIN ----------
def main():
    set_seed(SEED)

    # 1) tokenizer + model (with value head)
    print("Loading tokenizer and model...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    # ensure a pad token; pad on the left so batched generation with a decoder-only
    # model appends new tokens right after each prompt
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    # AutoModelForCausalLMWithValueHead integrates the value head used by PPO
    model = AutoModelForCausalLMWithValueHead.from_pretrained(MODEL_NAME).to(DEVICE)

    # 2) PPO config and trainer
    ppo_config = PPOConfig(
        model_name=MODEL_NAME,
        batch_size=BATCH_SIZE,
        mini_batch_size=BATCH_SIZE,  # keep mini-batches no larger than the rollout batch
        ppo_epochs=1,
        learning_rate=1.41e-5,
        log_with=None,  # or "wandb"
    )

    # Create a reference model used by PPO to compute KL / keep the policy close to reference
    reference_model = create_reference_model(model)

    # PPOTrainer (pre-1.0 TRL API) takes the config object directly
    ppo_trainer = PPOTrainer(
        config=ppo_config,
        model=model,
        ref_model=reference_model,
        tokenizer=tokenizer,
    )

    # 3) small prompt dataset (toy)
    prompts = [
        "Write a Python function to compute factorial using recursion:",
        "Write a Python function to check if a number is prime:",
        "Write a Python function to compute Fibonacci numbers:",
        "Explain quicksort in short steps:",
        "Write a Python function to reverse a string:",
    ]
    # duplicate to reach dataset size and shuffle
    prompts = prompts * 20
    random.shuffle(prompts)
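
    # Alternative setup that the (otherwise unused) `Dataset` import supports -- a
    # minimal sketch, assuming pre-1.0 TRL, which can take a datasets.Dataset of
    # queries at construction time and iterate it via ppo_trainer.dataloader:
    #
    #   prompt_ds = Dataset.from_dict({"query": prompts})
    #   ppo_trainer = PPOTrainer(config=ppo_config, model=model,
    #                            ref_model=reference_model, tokenizer=tokenizer,
    #                            dataset=prompt_ds)
    #
    # The plain Python list above is enough for this toy loop, so the sketch stays commented out.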

    # 4) RL loop
    print("Starting PPO loop...")
    for epoch in range(EPOCHS):
        print(f"Epoch {epoch+1}/{EPOCHS}")
        # iterate in batches
        for i in range(0, len(prompts), BATCH_SIZE):
            batch_prompts = prompts[i:i + BATCH_SIZE]
            # encode prompts
            batch_encoding = tokenizer(batch_prompts, return_tensors="pt", padding=True).to(DEVICE)

            # generate responses from the current policy; plain model.generate is used
            # here (a ppo_trainer.generate variant is sketched after this block)
            with torch.no_grad():
                gen_ids = model.generate(
                    **batch_encoding,
                    max_length=batch_encoding["input_ids"].shape[1] + MAX_GEN_TOKENS,
                    do_sample=True,
                    top_p=0.9,
                    temperature=1.0,
                    pad_token_id=tokenizer.eos_token_id,
                )
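
            # Trainer-managed generation mentioned above -- a minimal sketch, assuming a
            # pre-1.0 TRL where PPOTrainer.generate accepts a list of 1-D query tensors
            # and returns only the new tokens when return_prompt=False:
            #
            #   query_tensors = [tokenizer(p, return_tensors="pt").input_ids[0].to(DEVICE)
            #                    for p in batch_prompts]
            #   response_tensors = ppo_trainer.generate(
            #       query_tensors, return_prompt=False,
            #       max_new_tokens=MAX_GEN_TOKENS, do_sample=True, top_p=0.9,
            #       pad_token_id=tokenizer.eos_token_id,
            #   )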

            # extract only the generated part (strip the prompt tokens)
            responses = []
            for idx, g in enumerate(gen_ids):
                # decode the full sequence and remove the prompt text prefix if present
                full_text = tokenizer.decode(g, skip_special_tokens=True)
                prompt_text = batch_prompts[idx]
                if full_text.startswith(prompt_text):
                    response_text = full_text[len(prompt_text):].strip()
                else:
                    # fallback: token-level slicing using the padded prompt length
                    response_text = tokenizer.decode(
                        g[len(batch_encoding["input_ids"][idx]):], skip_special_tokens=True
                    )
                responses.append(response_text)

            # compute rewards (list of floats)
            rewards = [dummy_reward(p, r) for p, r in zip(batch_prompts, responses)]

            # PPO update: pre-1.0 TRL's ppo_trainer.step expects lists of 1-D token
            # tensors for queries and responses, plus a list of scalar reward tensors
            query_tensors = [
                tokenizer(p, return_tensors="pt").input_ids[0].to(DEVICE) for p in batch_prompts
            ]
            response_tensors = [
                # fall back to a single EOS token if the model produced an empty response
                tokenizer(r if r else tokenizer.eos_token, return_tensors="pt").input_ids[0].to(DEVICE)
                for r in responses
            ]
            reward_tensors = [torch.tensor(v, device=DEVICE) for v in rewards]
            ppo_trainer.step(query_tensors, response_tensors, reward_tensors)

            if (i // BATCH_SIZE) % 10 == 0:
                print(f"  processed batch {i // BATCH_SIZE} - avg reward {sum(rewards)/len(rewards):.3f}")

    # 5) Save the final policy as safetensors
    print("Saving final model (policy) to safetensors...")
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    # save_pretrained supports safe_serialization=True to write .safetensors
    model.save_pretrained(OUTPUT_DIR, safe_serialization=True)
    tokenizer.save_pretrained(OUTPUT_DIR)
    print("Saved to", OUTPUT_DIR)
    print("Done.")


if __name__ == "__main__":
    main()
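
# Reloading the saved policy later -- a minimal sketch, assuming the same pre-1.0 TRL
# is installed; save_pretrained above writes the policy weights as .safetensors into
# OUTPUT_DIR, so loading goes through the same value-head class:
#
#   from trl import AutoModelForCausalLMWithValueHead
#   from transformers import AutoTokenizer
#   policy = AutoModelForCausalLMWithValueHead.from_pretrained("./ppo_policy")
#   tok = AutoTokenizer.from_pretrained("./ppo_policy")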