ankitkushwaha90 committed on
Commit
4834c9d
·
verified ·
1 Parent(s): c22365d

Rename Reinforcement_Learning_model_using_mini_safetesor_model.py to Reinforcement_Learning_model_using_mini_safetesor_model_using_transformer.py

Reinforcement_Learning_model_using_mini_safetesor_model.py DELETED
File without changes
Reinforcement_Learning_model_using_mini_safetesor_model_using_transformer.py ADDED
@@ -0,0 +1,164 @@
+ """
+ ppo_train_safetensors.py
+ Minimal PPO example using TRL (Hugging Face) with a small LM.
+ This script:
+ - loads a small causal LM (distilgpt2 / gpt2)
+ - wraps it with a value head for PPO
+ - runs a tiny RL loop with a dummy reward function
+ - saves the policy as .safetensors
+
+ Notes:
+ - Replace `dummy_reward` with a real reward function or reward model for serious experiments.
+ - For better VRAM usage, consider bitsandbytes + load_in_8bit (not shown).
+ - Written against the classic TRL PPOTrainer API (trl <= 0.11); newer TRL releases reworked the PPO interface.
+ """
+
+ import os
+ import random
+
+ import torch
+ from transformers import AutoTokenizer
+ from trl import (
+     PPOTrainer,
+     PPOConfig,
+     AutoModelForCausalLMWithValueHead,
+     create_reference_model,
+ )
+
+ # ---------- CONFIG ----------
+ MODEL_NAME = "distilgpt2"  # small model for toy experiments
+ OUTPUT_DIR = "./ppo_policy"
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ BATCH_SIZE = 2  # keep tiny for demonstration
+ EPOCHS = 2
+ MAX_GEN_TOKENS = 64
+ SEED = 42
+
+ # ---------- HELPERS ----------
+ def set_seed(seed: int = 42):
+     random.seed(seed)
+     torch.manual_seed(seed)
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed_all(seed)
+
+ def dummy_reward(prompt: str, response: str) -> float:
+     """
+     Toy reward: +1 if the response contains the word 'def' (i.e., generates code-like text),
+     otherwise -0.1. Replace with a reward model for real work.
+     """
+     r = 1.0 if "def " in response or "def\n" in response else -0.1
+     # add small length bonus to encourage non-empty responses
+     r += min(len(response.split()), 10) * 0.01
+     return float(r)
+
+ # ---------- MAIN ----------
+ def main():
+     set_seed(SEED)
+
+     # 1) tokenizer + model (with value head)
+     print("Loading tokenizer and model...")
+     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+     # ensure a pad token (GPT-2-style tokenizers ship without one)
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+     # left padding keeps prompts flush against the generated tokens for decoder-only models
+     tokenizer.padding_side = "left"
+
+     # AutoModelForCausalLMWithValueHead adds the value head used by PPO
+     model = AutoModelForCausalLMWithValueHead.from_pretrained(MODEL_NAME).to(DEVICE)
+
+     # 2) PPO config and trainer
+     ppo_config = PPOConfig(
+         model_name=MODEL_NAME,
+         batch_size=BATCH_SIZE,
+         mini_batch_size=BATCH_SIZE,
+         ppo_epochs=1,
+         learning_rate=1.41e-5,
+         log_with=None,  # or "wandb"
+     )
+
+     # Frozen reference model used by PPO to compute the KL penalty that keeps the policy close to the original
+     ref_model = create_reference_model(model)
+
+     ppo_trainer = PPOTrainer(
+         config=ppo_config,
+         model=model,
+         ref_model=ref_model,
+         tokenizer=tokenizer,
+     )
+
+     # 3) small prompt dataset (toy)
+     prompts = [
+         "Write a Python function to compute factorial using recursion:",
+         "Write a Python function to check if a number is prime:",
+         "Write a Python function to compute Fibonacci numbers:",
+         "Explain quicksort in short steps:",
+         "Write a Python function to reverse a string:",
+     ]
+     # duplicate to reach dataset size and shuffle
+     prompts = prompts * 20
+     random.shuffle(prompts)
+
+     # 4) RL loop
+     print("Starting PPO loop...")
+     for epoch in range(EPOCHS):
+         print(f"Epoch {epoch+1}/{EPOCHS}")
+         # iterate in batches
+         for i in range(0, len(prompts), BATCH_SIZE):
+             batch_prompts = prompts[i:i + BATCH_SIZE]
+             # encode prompts as a left-padded batch for generation
+             batch_encoding = tokenizer(batch_prompts, return_tensors="pt", padding=True).to(DEVICE)
+
+             # generate responses from the current policy
+             with torch.no_grad():
+                 gen_ids = model.generate(
+                     **batch_encoding,
+                     max_new_tokens=MAX_GEN_TOKENS,
+                     do_sample=True,
+                     top_p=0.9,
+                     temperature=1.0,
+                     pad_token_id=tokenizer.eos_token_id,
+                 )
+
+             # keep only the generated part (everything after the padded prompt block)
+             prompt_len = batch_encoding["input_ids"].shape[1]
+             response_tensors = [g[prompt_len:] for g in gen_ids]
+             responses = [tokenizer.decode(r, skip_special_tokens=True).strip() for r in response_tensors]
+
+             # compute rewards; PPOTrainer.step expects a list of scalar tensors
+             rewards = [torch.tensor(dummy_reward(p, r)) for p, r in zip(batch_prompts, responses)]
+
+             # unpadded query tensors for the PPO update
+             query_tensors = [
+                 tokenizer(p, return_tensors="pt").input_ids.squeeze(0).to(DEVICE)
+                 for p in batch_prompts
+             ]
+
+             # PPO step takes lists of query tensors, response tensors and reward tensors
+             ppo_trainer.step(query_tensors, list(response_tensors), rewards)
+
+             if (i // BATCH_SIZE) % 10 == 0:
+                 avg_reward = sum(r.item() for r in rewards) / len(rewards)
+                 print(f"  processed batch {i // BATCH_SIZE} - avg reward {avg_reward:.3f}")
+
+     # 5) Save final policy as safetensors
+     print("Saving final model (policy) to safetensors...")
+     os.makedirs(OUTPUT_DIR, exist_ok=True)
+     # save_pretrained supports safe_serialization=True to write .safetensors
+     model.save_pretrained(OUTPUT_DIR, safe_serialization=True)
+     tokenizer.save_pretrained(OUTPUT_DIR)
+     print("Saved to", OUTPUT_DIR)
+     print("Done.")
+
+ if __name__ == "__main__":
+     main()
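
After training, the saved folder can be reloaded like any other Hugging Face checkpoint for a quick inference check. Below is a minimal sketch, assuming the ./ppo_policy directory written above; the file name load_ppo_policy.py and the sample prompt are illustrative and not part of this commit. A plain AutoModelForCausalLM is enough at inference time; the extra value-head weights stored in the checkpoint are simply skipped with a warning.

# load_ppo_policy.py - sketch for reloading the saved .safetensors policy (illustrative, not part of this commit)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

POLICY_DIR = "./ppo_policy"  # OUTPUT_DIR used by the training script
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(POLICY_DIR)
# the value head is only needed for PPO training; a plain causal LM suffices for generation
model = AutoModelForCausalLM.from_pretrained(POLICY_DIR).to(device)

prompt = "Write a Python function to reverse a string:"  # sample prompt, same style as the training set
inputs = tokenizer(prompt, return_tensors="pt").to(device)
with torch.no_grad():
    output_ids = model.generate(
        **inputs,
        max_new_tokens=64,
        do_sample=True,
        top_p=0.9,
        pad_token_id=tokenizer.eos_token_id,
    )
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

To resume PPO training from this checkpoint instead, load it with AutoModelForCausalLMWithValueHead.from_pretrained(POLICY_DIR) so the saved value-head weights are picked up as well.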