Upload 3 files
- mcts_text_gen.py +81 -0
- moe_mcts_new.pt +3 -0
- q_star.py +493 -0
mcts_text_gen.py
ADDED
@@ -0,0 +1,81 @@
import torch
from q_star import GPTWithMoE, GPTConfig, mcts_decode_single


def generate_text_with_mcts(
    model: GPTWithMoE,
    tokenizer,  # Tokenizer to encode and decode text
    prompt: str,
    max_length: int = 50,
    num_simulations: int = 100,
    c_puct: float = 1.0,
    top_k: int = 10,
    device: str = "cuda"
):
    """
    Generate text using the GPTWithMoE model and MCTS-based decoding.

    Args:
        model (GPTWithMoE): The trained model.
        tokenizer: The tokenizer for text encoding and decoding.
        prompt (str): The initial text prompt.
        max_length (int): Maximum length of the generated text.
        num_simulations (int): Number of MCTS simulations for each decoding step.
        c_puct (float): Exploration parameter for MCTS.
        top_k (int): Top-k tokens to consider during MCTS expansion.
        device (str): Device to use for computation.

    Returns:
        str: The generated text.
    """
    model.eval()
    model.to(device)

    # Encode the prompt into input_ids
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    # Use MCTS to decode the sequence
    generated_ids = mcts_decode_single(
        model=model,
        input_ids=input_ids,
        max_length=max_length,
        num_simulations=num_simulations,
        c_puct=c_puct,
        top_k=top_k,
    )

    # Decode the generated IDs back to text
    generated_text = tokenizer.decode(generated_ids.tolist(), skip_special_tokens=True)
    return generated_text


if __name__ == "__main__":
    from transformers import GPT2Tokenizer

    # Define the device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Initialize the tokenizer (adapt as per your model's tokenizer)
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token

    # Load the trained model
    config = GPTConfig(vocab_size=50304, block_size=512, n_layer=6, n_head=4, n_embd=256)
    model = GPTWithMoE(config, num_experts=3, expert_layers=3, block_size_q=32, block_size_kv=32, num_blocks_kv=4, device=device)
    model.load_state_dict(torch.load("C:\\Users\\Admin\\MODELS\\moe_mcts_new.pt", map_location=device))

    # Generate text using a prompt
    prompt = "Once upon a time in a distant galaxy,"
    generated_text = generate_text_with_mcts(
        model=model,
        tokenizer=tokenizer,
        prompt=prompt,
        max_length=100,
        num_simulations=50,
        c_puct=1.5,
        top_k=5,
        device=device,
    )

    print("Generated Text:")
    print(generated_text)
moe_mcts_new.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:542e04c84ae8cdc6d30aa1e0bafba22a6150cc521e8e1ba302b1500aa9f77673
size 241844250
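Note: moe_mcts_new.pt is tracked with Git LFS, so the lines above are only the pointer; the ~242 MB checkpoint itself lives in LFS storage. The following is a minimal sketch of fetching the file with huggingface_hub and loading it into GPTWithMoE, assuming the same GPTConfig and constructor arguments that mcts_text_gen.py uses; the repo_id is a placeholder, not the actual repository name.

# Sketch only: download the LFS-backed checkpoint and load it into the model.
import torch
from huggingface_hub import hf_hub_download
from q_star import GPTWithMoE, GPTConfig

# Placeholder repo id; substitute the repository this commit belongs to.
ckpt_path = hf_hub_download(repo_id="<user>/<repo>", filename="moe_mcts_new.pt")

device = "cuda" if torch.cuda.is_available() else "cpu"
config = GPTConfig(vocab_size=50304, block_size=512, n_layer=6, n_head=4, n_embd=256)
model = GPTWithMoE(config, num_experts=3, expert_layers=3, block_size_q=32, block_size_kv=32, num_blocks_kv=4, device=device)
model.load_state_dict(torch.load(ckpt_path, map_location=device))
model.eval()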
q_star.py
ADDED
@@ -0,0 +1,493 @@
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.nn import functional as F
import math
import os
import numpy as np
import time
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import ReduceLROnPlateau
import random
from collections import defaultdict
from torch.cuda.amp import autocast
from typing import List, Tuple
from torch.nn.utils.rnn import pad_sequence
import inspect

# Define your dataset and dataloader classes
class NpyDataset(Dataset):
    def __init__(self, data_dir, file_prefix):
        self.data_dir = data_dir
        self.file_names = [os.path.join(data_dir, f) for f in sorted(os.listdir(data_dir)) if f.startswith(file_prefix) and f.endswith('.npy')]

    def __len__(self):
        return len(self.file_names)

    def __getitem__(self, idx):
        tokens_np = np.load(self.file_names[idx])
        tokens_tensor = torch.tensor(tokens_np, dtype=torch.long)
        return tokens_tensor

class CustomDataLoaderLite:
    def __init__(self, dataset, batch_size, seq_len):
        self.dataset = dataset
        self.batch_size = batch_size
        self.seq_len = seq_len
        self.current_position = 0

    def __iter__(self):
        self.current_position = 0
        return self

    def __next__(self):
        if self.current_position >= len(self.dataset):
            raise StopIteration

        batch = []
        for _ in range(self.batch_size):
            if self.current_position >= len(self.dataset):
                break
            tokens = self.dataset[self.current_position]
            batch.append(tokens[:self.seq_len])
            self.current_position += 1

        x = torch.stack([tokens[:-1] for tokens in batch])
        y = torch.stack([tokens[1:] for tokens in batch])

        return x, y

    def __len__(self):
        return (len(self.dataset) + self.batch_size - 1) // self.batch_size

# Define the FlashAttention3 module
class FlashAttention3(nn.Module):
    def __init__(self, d_model, n_heads, block_size_q, block_size_kv, num_blocks_kv, device='cuda'):
        super(FlashAttention3, self).__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.block_size_q = block_size_q
        self.block_size_kv = block_size_kv
        self.num_blocks_kv = num_blocks_kv
        self.device = device

        self.q_proj = nn.Linear(d_model, d_model).to(device)
        self.k_proj = nn.Linear(d_model, d_model).to(device)
        self.v_proj = nn.Linear(d_model, d_model).to(device)
        self.out_proj = nn.Linear(d_model, d_model).to(device)

    def forward(self, x):
        B, T, C = x.size()
        Q = self.q_proj(x).view(B, T, self.n_heads, C // self.n_heads).transpose(1, 2)
        K = self.k_proj(x).view(B, T, self.n_heads, C // self.n_heads).transpose(1, 2)
        V = self.v_proj(x).view(B, T, self.n_heads, C // self.n_heads).transpose(1, 2)

        O = torch.zeros(B, self.n_heads, T, C // self.n_heads).to(self.device)
        L = torch.zeros(B, self.n_heads, T).to(self.device)
        M = torch.full((B, self.n_heads, T), -float('inf')).to(self.device)

        for i in range(0, T, self.block_size_q):
            Q_block = Q[:, :, i:i+self.block_size_q]
            O_block = torch.zeros_like(Q_block).to(self.device)
            L_block = torch.zeros(B, self.n_heads, Q_block.size(2)).to(self.device)
            M_block = torch.full((B, self.n_heads, Q_block.size(2)), -float('inf')).to(self.device)

            for j in range(0, T, self.block_size_kv):
                K_block = K[:, :, j:j+self.block_size_kv]
                V_block = V[:, :, j:j+self.block_size_kv]

                S_block = torch.matmul(Q_block, K_block.transpose(-2, -1))
                M_block_old = M_block
                M_block = torch.max(M_block, S_block.max(dim=-1).values)

                exp_S_block = torch.exp(S_block - M_block.unsqueeze(-1))
                L_block = torch.exp(M_block_old - M_block) * L_block + exp_S_block.sum(dim=-1)

                O_block += torch.matmul(exp_S_block, V_block)

            O_block /= L_block.unsqueeze(-1)
            O[:, :, i:i+self.block_size_q] = O_block

        O = O.transpose(1, 2).contiguous().view(B, T, self.n_heads * (C // self.n_heads))
        O = self.out_proj(O)

        return O

# Define the MLP module
class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
        self.gelu = nn.GELU(approximate='tanh')
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)
        self.c_proj.scale_init = 1

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x

# Define the MixtureOfExperts module
class MixtureOfExperts(nn.Module):
    def __init__(self, config, num_experts, expert_layers):
        super().__init__()
        self.num_experts = num_experts
        self.expert_layers = expert_layers

        self.experts = nn.ModuleList([self._create_expert(config) for _ in range(num_experts)])
        self.gate = nn.Linear(config.n_embd, num_experts)

    def _create_expert(self, config):
        layers = []
        for _ in range(self.expert_layers):
            layers.append(FlashAttention3(d_model=config.n_embd, n_heads=config.n_head, block_size_q=32, block_size_kv=32, num_blocks_kv=4))
            layers.append(nn.LayerNorm(config.n_embd))
            layers.append(MLP(config))
        return nn.Sequential(*layers)

    def forward(self, x):
        B, T, C = x.size()

        gate_scores = self.gate(x)
        gate_probs = F.softmax(gate_scores, dim=-1)

        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=1)

        gate_probs = gate_probs.unsqueeze(-1)
        gate_probs = gate_probs.permute(0, 2, 1, 3)

        output = torch.sum(gate_probs * expert_outputs, dim=1)

        return output

# Define the BlockWithMoE module
class BlockWithMoE(nn.Module):
    def __init__(self, config, num_experts=4, expert_layers=2, block_size_q=32, block_size_kv=32, num_blocks_kv=4, device='cuda'):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = FlashAttention3(d_model=config.n_embd, n_heads=config.n_head, block_size_q=block_size_q, block_size_kv=block_size_kv, num_blocks_kv=num_blocks_kv, device=device)
        self.dropout1 = nn.Dropout(config.dropout)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.moe = MixtureOfExperts(config, num_experts, expert_layers)
        self.dropout2 = nn.Dropout(config.dropout)
        self.ln_3 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)
        self.dropout3 = nn.Dropout(config.dropout)

    def forward(self, x):
        B, T, C = x.size()

        attn_output = self.attn(x)
        x = x + attn_output
        x = self.dropout1(x)
        x = x + self.moe(self.ln_2(x))
        x = self.dropout2(x)
        x = x + self.mlp(self.ln_3(x))
        x = self.dropout3(x)
        return x

# Define the GPT configuration dataclass
@dataclass
class GPTConfig:
    block_size: int = 512
    vocab_size: int = 50257
    n_layer: int = 6
    n_head: int = 4
    n_embd: int = 256
    dropout: float = 0.2

# Define the GPTWithMoE model
class GPTWithMoE(nn.Module):
    def __init__(self, config, num_experts=2, expert_layers=2, block_size_q=32, block_size_kv=32, num_blocks_kv=4, device='cuda'):
        super().__init__()
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),
            wpe=nn.Embedding(config.block_size, config.n_embd),
            h=nn.ModuleList([BlockWithMoE(config, num_experts, expert_layers, block_size_q, block_size_kv, num_blocks_kv, device) for _ in range(config.n_layer)]),
            ln_f=nn.LayerNorm(config.n_embd),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.transformer.wte.weight = self.lm_head.weight

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            std = 0.02
            if hasattr(module, 'scale_init'):
                std *= (2 * self.config.n_layer) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.size()
        assert T <= self.config.block_size, f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"
        pos = torch.arange(0, T, dtype=torch.long, device=idx.device)
        pos_emb = self.transformer.wpe(pos)
        tok_emb = self.transformer.wte(idx)
        x = tok_emb + pos_emb
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss

    def configure_optimizers(self, weight_decay, learning_rate, device):
        param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad}
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        non_decay_params = [p for n, p in param_dict.items() if p.dim() < 2]

        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': non_decay_params, 'weight_decay': 0}
        ]

        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and 'cuda' in device
        print(f" Using fused AdamW: {use_fused}")
        optimizer = torch.optim.AdamW(self.parameters(), lr=learning_rate, betas=(0.9, 0.95), eps=1e-8, fused=use_fused)
        return optimizer

# MCTS Implementation
@dataclass
class MCTSNode:
    state: torch.Tensor
    parent: 'MCTSNode' = None
    children: dict = None
    visits: int = 0
    value: float = 0.0

    def __post_init__(self):
        if self.children is None:
            self.children = {}

# Define scriptable functions separately
def select_node(node: MCTSNode, c_puct: float) -> MCTSNode:
    if not node.children:
        return node

    scores = torch.tensor([
        child.value / (child.visits + 1e-8) +
        c_puct * math.sqrt(math.log(node.visits + 1) / (child.visits + 1e-8))
        for child in node.children.values()
    ])

    best_child_idx = torch.argmax(scores).item()
    return list(node.children.values())[best_child_idx]

def expand_node(node: MCTSNode, logits: torch.Tensor, top_k: int) -> None:
    probs = F.softmax(logits, dim=-1)
    top_k_probs, top_k_indices = torch.topk(probs, k=top_k)

    for prob, token in zip(top_k_probs, top_k_indices):
        if token.item() not in node.children:
            node.children[token.item()] = MCTSNode(state=token, parent=node)

def simulate(model: torch.nn.Module, sequence: torch.Tensor, max_length: int) -> torch.Tensor:
    # Ensure sequence is 2D
    if sequence.dim() == 1:
        sequence = sequence.unsqueeze(0)

    with torch.no_grad():
        while sequence.size(1) < max_length:
            with autocast():
                logits, _ = model(sequence)
            probs = F.softmax(logits[0, -1], dim=-1)
            next_token = torch.multinomial(probs, 1)
            sequence = torch.cat([sequence, next_token.unsqueeze(0)], dim=1)
    return sequence.squeeze(0)

def backpropagate(node: MCTSNode, value: float) -> None:
    while node is not None:
        node.visits += 1
        node.value += value
        node = node.parent

def mcts_decode_single(model: torch.nn.Module, input_ids: torch.Tensor, max_length: int, num_simulations: int, c_puct: float, top_k: int) -> torch.Tensor:
    # Ensure input_ids is 2D
    if input_ids.dim() == 1:
        input_ids = input_ids.unsqueeze(0)

    root = MCTSNode(state=input_ids)

    for _ in range(num_simulations):
        node = root
        current_input = input_ids.clone()

        # Selection
        while node.children and current_input.size(1) < max_length:
            node = select_node(node, c_puct)
            current_input = torch.cat([current_input, node.state.unsqueeze(0).unsqueeze(0)], dim=1)

        # Expansion
        if current_input.size(1) < max_length:
            with torch.no_grad():
                with autocast():
                    logits, _ = model(current_input)
            expand_node(node, logits[0, -1], top_k)

        # Simulation
        simulation_sequence = simulate(model, current_input.squeeze(0), max_length)

        # Evaluation
        with torch.no_grad():
            with autocast():
                _, loss = model(simulation_sequence.unsqueeze(0), simulation_sequence.unsqueeze(0))
            value = -loss.item()

        # Backpropagation
        backpropagate(node, value)

    # Choose the best next token
    best_child = max(root.children.values(), key=lambda n: n.visits)
    result = torch.cat([input_ids.squeeze(0), best_child.state.unsqueeze(0)], dim=0)

    # Ensure the result doesn't exceed max_length
    return result[:max_length]

def mcts_decode_batch(model: torch.nn.Module, input_ids_list: List[torch.Tensor], max_length: int, num_simulations: int, c_puct: float, top_k: int) -> List[torch.Tensor]:
    return [mcts_decode_single(model, input_ids.unsqueeze(0) if input_ids.dim() == 1 else input_ids, max_length, num_simulations, c_puct, top_k) for input_ids in input_ids_list]

def validate_with_mcts(model: torch.nn.Module, val_dataloader: CustomDataLoaderLite, device: torch.device, max_length: int, num_simulations: int, c_puct: float, top_k: int) -> float:
    model.eval()
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for x, y in val_dataloader:
            x, y = x.to(device), y.to(device)

            # Use MCTS for decoding
            decoded_sequences = mcts_decode_batch(model, x, max_length, num_simulations, c_puct, top_k)

            # Pad sequences to the same length
            decoded_sequences_padded = pad_sequence(decoded_sequences, batch_first=True, padding_value=0)

            # Trim the decoded sequences to match the target length
            decoded_sequences_trimmed = decoded_sequences_padded[:, :y.size(1)]

            # Calculate loss using the MCTS-decoded sequences
            with autocast():
                logits, loss = model(decoded_sequences_trimmed, y)
            total_loss += loss.item()
            num_batches += 1

    return total_loss / num_batches if num_batches > 0 else 0.0

def train_model():
    device = 'cpu'
    if torch.cuda.is_available():
        device = 'cuda'
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device = 'mps'
    print(f"using device : {device}")

    # Load the dataset and create the data loader
    print("Loading datasets...")
    train_dataset = NpyDataset('edu_fineweb10B', 'edufineweb_train')
    val_dataset = NpyDataset('edu_fineweb10B', 'edufineweb_val')
    train_dataloader = CustomDataLoaderLite(train_dataset, batch_size=12, seq_len=512)
    val_dataloader = CustomDataLoaderLite(val_dataset, batch_size=12, seq_len=512)

    # Training loop
    max_steps = 200
    total_batch_size = 262144
    B = 12
    T = 512
    grad_accum_steps = total_batch_size // (B * T)

    # Set up the configuration
    print("Setting up model configuration...")
    config = GPTConfig(vocab_size=50304, block_size=512, n_layer=6, n_head=4, n_embd=256)

    # Initialize the model
    print("Initializing model...")
    model = GPTWithMoE(config, num_experts=3, expert_layers=3, block_size_q=32, block_size_kv=32, num_blocks_kv=4, device=device)
    model.to(device)

    # Load the saved model weights if they exist
    save_path = "C:\\Users\\Admin\\MODELS\\moe_mcts_new.pt"
    temp_save_path = "C:\\Users\\Admin\\MODELS\\moe_mcts_temp_new.pt"
    if os.path.isfile(save_path):
        print(f"Loading model weights from {save_path}...")
        model.load_state_dict(torch.load(save_path))
        print(f"Loaded model weights from {save_path}")

    print("Configuring optimizer...")
    optimizer = model.configure_optimizers(weight_decay=0.2, learning_rate=3e-3, device=device)

    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=True)

    train_losses = []
    val_losses = []

    scaler = torch.cuda.amp.GradScaler()

    for i in range(max_steps):
        t0 = time.time()
        optimizer.zero_grad()
        train_loss_accum = 0

        model.train()
        print(f"Training step {i + 1}/{max_steps}...")
        for x, y in train_dataloader:
            x, y = x.to(device), y.to(device)
            with torch.cuda.amp.autocast():
                logits, loss = model(x, y)
            loss = loss / grad_accum_steps
            train_loss_accum += loss.detach()
            scaler.scale(loss).backward()

        norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()

        torch.cuda.synchronize()
        t1 = time.time()
        dt = (t1 - t0) * 1000
        tokens_per_sec = (B * T * grad_accum_steps) / (t1 - t0)
        train_losses.append(train_loss_accum.item())

        torch.cuda.empty_cache()

        # Validation with MCTS
        model.eval()
        val_loss = validate_with_mcts(model, val_dataloader, device, max_length=T, num_simulations=100, c_puct=1.0, top_k=10)
        val_losses.append(val_loss)

        scheduler.step(val_loss)

        print(f"step {i} | train loss: {train_loss_accum.item():.6f} | val loss: {val_loss:.6f} | lr: {optimizer.param_groups[0]['lr']:.8f} | norm: {norm:.4f} | dt: {dt:.2f}ms | tok/sec: {tokens_per_sec}")

        # Save model weights
        torch.save(model.state_dict(), temp_save_path)
        os.replace(temp_save_path, save_path)
        print(f"Model saved at step {i+1} to {save_path}")

    # Plotting the training and validation loss
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Steps')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()

if __name__ == "__main__":
    train_model()