RobbiePasquale committed
Commit 8e083dc · verified · 1 Parent(s): 5c634e7

Upload 3 files

Files changed (3):
  1. mcts_text_gen.py +81 -0
  2. moe_mcts_new.pt +3 -0
  3. q_star.py +493 -0
mcts_text_gen.py ADDED
@@ -0,0 +1,81 @@
+ import torch
+ from q_star import GPTWithMoE, GPTConfig, mcts_decode_single
+
+
+ def generate_text_with_mcts(
+     model: GPTWithMoE,
+     tokenizer,  # tokenizer used to encode and decode text
+     prompt: str,
+     max_length: int = 50,
+     num_simulations: int = 100,
+     c_puct: float = 1.0,
+     top_k: int = 10,
+     device: str = "cuda",
+ ):
+     """
+     Generate text using the GPTWithMoE model and MCTS-based decoding.
+
+     Args:
+         model (GPTWithMoE): The trained model.
+         tokenizer: The tokenizer for text encoding and decoding.
+         prompt (str): The initial text prompt.
+         max_length (int): Maximum total sequence length in tokens.
+         num_simulations (int): Number of MCTS simulations per decoding step.
+         c_puct (float): Exploration constant for MCTS.
+         top_k (int): Number of top tokens considered during MCTS expansion.
+         device (str): Device to use for computation.
+
+     Returns:
+         str: The generated text.
+     """
+     model.eval()
+     model.to(device)
+
+     # Encode the prompt into input_ids of shape (1, T)
+     input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
+
+     # Use MCTS to decode the sequence.
+     # NOTE: mcts_decode_single (see q_star.py) returns the prompt extended by
+     # a single MCTS-selected token, truncated to max_length.
+     generated_ids = mcts_decode_single(
+         model=model,
+         input_ids=input_ids,
+         max_length=max_length,
+         num_simulations=num_simulations,
+         c_puct=c_puct,
+         top_k=top_k,
+     )
+
+     # Decode the generated IDs back to text
+     generated_text = tokenizer.decode(generated_ids.tolist(), skip_special_tokens=True)
+     return generated_text
+
+
+ if __name__ == "__main__":
+     from transformers import GPT2Tokenizer
+
+     # Define the device
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+
+     # Initialize the tokenizer (adapt to your model's tokenizer)
+     tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+     tokenizer.pad_token = tokenizer.eos_token
+
+     # Load the trained model
+     config = GPTConfig(vocab_size=50304, block_size=512, n_layer=6, n_head=4, n_embd=256)
+     model = GPTWithMoE(config, num_experts=3, expert_layers=3, block_size_q=32, block_size_kv=32, num_blocks_kv=4, device=device)
+     model.load_state_dict(torch.load("C:\\Users\\Admin\\MODELS\\moe_mcts_new.pt", map_location=device))
+
+     # Generate text using a prompt
+     prompt = "Once upon a time in a distant galaxy,"
+     generated_text = generate_text_with_mcts(
+         model=model,
+         tokenizer=tokenizer,
+         prompt=prompt,
+         max_length=100,
+         num_simulations=50,
+         c_puct=1.5,
+         top_k=5,
+         device=device,
+     )
+
+     print("Generated Text:")
+     print(generated_text)
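
Since mcts_decode_single extends the prompt by one MCTS-selected token per call (see the note above), generating a full continuation means calling it once per new token. A minimal sketch of that loop, assuming the model, tokenizer, and device from the __main__ block above:

# Sketch (illustrative, not part of the upload): token-by-token MCTS generation.
input_ids = tokenizer.encode("Once upon a time", return_tensors="pt").to(device)
while input_ids.size(1) < 100:
    out = mcts_decode_single(
        model=model,
        input_ids=input_ids,
        max_length=input_ids.size(1) + 1,  # allow exactly one new token per call
        num_simulations=8,
        c_puct=1.5,
        top_k=5,
    )
    input_ids = out.unsqueeze(0)  # back to shape (1, T) for the next step
print(tokenizer.decode(input_ids.squeeze(0).tolist(), skip_special_tokens=True))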
moe_mcts_new.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:542e04c84ae8cdc6d30aa1e0bafba22a6150cc521e8e1ba302b1500aa9f77673
+ size 241844250
q_star.py ADDED
@@ -0,0 +1,493 @@
+
+ from dataclasses import dataclass
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+ import math
+ import os
+ import numpy as np
+ import time
+ from torch.utils.data import Dataset, DataLoader
+ import matplotlib.pyplot as plt
+ from torch.optim.lr_scheduler import ReduceLROnPlateau
+ import random
+ from collections import defaultdict
+ from torch.cuda.amp import autocast
+ from typing import List, Tuple
+ from torch.nn.utils.rnn import pad_sequence
+ import inspect
+
+ # Define your dataset and dataloader classes
+ class NpyDataset(Dataset):
+     def __init__(self, data_dir, file_prefix):
+         self.data_dir = data_dir
+         self.file_names = [os.path.join(data_dir, f) for f in sorted(os.listdir(data_dir)) if f.startswith(file_prefix) and f.endswith('.npy')]
+
+     def __len__(self):
+         return len(self.file_names)
+
+     def __getitem__(self, idx):
+         tokens_np = np.load(self.file_names[idx])
+         tokens_tensor = torch.tensor(tokens_np, dtype=torch.long)
+         return tokens_tensor
+
+ class CustomDataLoaderLite:
+     def __init__(self, dataset, batch_size, seq_len):
+         self.dataset = dataset
+         self.batch_size = batch_size
+         self.seq_len = seq_len
+         self.current_position = 0
+
+     def __iter__(self):
+         self.current_position = 0
+         return self
+
+     def __next__(self):
+         if self.current_position >= len(self.dataset):
+             raise StopIteration
+
+         batch = []
+         for _ in range(self.batch_size):
+             if self.current_position >= len(self.dataset):
+                 break
+             tokens = self.dataset[self.current_position]
+             batch.append(tokens[:self.seq_len])
+             self.current_position += 1
+
+         # NOTE: torch.stack assumes every shard holds at least seq_len tokens;
+         # a shorter shard would make the stack below fail.
+         x = torch.stack([tokens[:-1] for tokens in batch])
+         y = torch.stack([tokens[1:] for tokens in batch])
+
+         return x, y
+
+     def __len__(self):
+         return (len(self.dataset) + self.batch_size - 1) // self.batch_size
+
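For reference, a minimal sketch of driving CustomDataLoaderLite with synthetic .npy shards (the directory, shard names, and sizes here are made up for illustration):

# Sketch: exercise CustomDataLoaderLite on synthetic shards (illustrative only).
import numpy as np, os, tempfile, torch

tmp = tempfile.mkdtemp()
for i in range(4):  # four fake shards of 512 tokens each
    np.save(os.path.join(tmp, f"toy_train_{i:03d}.npy"), np.random.randint(0, 50304, size=512))

ds = NpyDataset(tmp, "toy_train")
dl = CustomDataLoaderLite(ds, batch_size=2, seq_len=512)
for x, y in dl:
    print(x.shape, y.shape)  # torch.Size([2, 511]) each: inputs and next-token targets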
+ # Define the FlashAttention3 module
+ class FlashAttention3(nn.Module):
+     def __init__(self, d_model, n_heads, block_size_q, block_size_kv, num_blocks_kv, device='cuda'):
+         super(FlashAttention3, self).__init__()
+         self.d_model = d_model
+         self.n_heads = n_heads
+         self.block_size_q = block_size_q
+         self.block_size_kv = block_size_kv
+         self.num_blocks_kv = num_blocks_kv
+         self.device = device
+
+         self.q_proj = nn.Linear(d_model, d_model).to(device)
+         self.k_proj = nn.Linear(d_model, d_model).to(device)
+         self.v_proj = nn.Linear(d_model, d_model).to(device)
+         self.out_proj = nn.Linear(d_model, d_model).to(device)
+
+     def forward(self, x):
+         # Blockwise attention with an online (streaming) softmax over KV blocks.
+         # NOTE: no causal mask is applied, so every position attends to the full
+         # sequence, and scores are not scaled by 1/sqrt(head_dim).
+         B, T, C = x.size()
+         Q = self.q_proj(x).view(B, T, self.n_heads, C // self.n_heads).transpose(1, 2)
+         K = self.k_proj(x).view(B, T, self.n_heads, C // self.n_heads).transpose(1, 2)
+         V = self.v_proj(x).view(B, T, self.n_heads, C // self.n_heads).transpose(1, 2)
+
+         O = torch.zeros(B, self.n_heads, T, C // self.n_heads).to(self.device)
+         L = torch.zeros(B, self.n_heads, T).to(self.device)
+         M = torch.full((B, self.n_heads, T), -float('inf')).to(self.device)
+
+         for i in range(0, T, self.block_size_q):
+             Q_block = Q[:, :, i:i+self.block_size_q]
+             O_block = torch.zeros_like(Q_block).to(self.device)
+             L_block = torch.zeros(B, self.n_heads, Q_block.size(2)).to(self.device)
+             M_block = torch.full((B, self.n_heads, Q_block.size(2)), -float('inf')).to(self.device)
+
+             for j in range(0, T, self.block_size_kv):
+                 K_block = K[:, :, j:j+self.block_size_kv]
+                 V_block = V[:, :, j:j+self.block_size_kv]
+
+                 S_block = torch.matmul(Q_block, K_block.transpose(-2, -1))
+                 M_block_old = M_block
+                 M_block = torch.max(M_block, S_block.max(dim=-1).values)
+
+                 exp_S_block = torch.exp(S_block - M_block.unsqueeze(-1))
+                 L_block = torch.exp(M_block_old - M_block) * L_block + exp_S_block.sum(dim=-1)
+
+                 # NOTE: as written, O_block is accumulated without the
+                 # exp(M_old - M_new) correction that standard FlashAttention
+                 # applies to the running output when the max updates.
+                 O_block += torch.matmul(exp_S_block, V_block)
+
+             O_block /= L_block.unsqueeze(-1)
+             O[:, :, i:i+self.block_size_q] = O_block
+
+         O = O.transpose(1, 2).contiguous().view(B, T, self.n_heads * (C // self.n_heads))
+         O = self.out_proj(O)
+
+         return O
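
The blockwise loop above is aiming at the standard online-softmax recurrence, in which both the normalizer and the accumulated output are rescaled whenever the running max changes (see the NOTE in the inner loop). A small self-contained sketch of that recurrence on a single attention row, independent of the class above:

# Sketch: online softmax over score blocks for one query row.
import torch

scores = torch.randn(64)            # attention scores for one query
v = torch.randn(64, 8)              # matching value rows
m, l, o = -float('inf'), 0.0, torch.zeros(8)
for blk in range(0, 64, 16):        # stream over KV blocks of size 16
    s = scores[blk:blk+16]
    m_new = max(m, s.max().item())
    p = torch.exp(s - m_new)
    correction = torch.exp(torch.tensor(m - m_new))
    l = correction.item() * l + p.sum().item()
    o = correction * o + p @ v[blk:blk+16]  # rescale old output, add new block
    m = m_new
o = o / l
assert torch.allclose(o, torch.softmax(scores, 0) @ v, atol=1e-5)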
+
+ # Define the MLP module
+ class MLP(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
+         self.gelu = nn.GELU(approximate='tanh')
+         self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
+         self.dropout = nn.Dropout(config.dropout)
+         self.c_proj.scale_init = 1  # flag: scale this projection's init down (see _init_weights)
+
+     def forward(self, x):
+         x = self.c_fc(x)
+         x = self.gelu(x)
+         x = self.c_proj(x)
+         x = self.dropout(x)
+         return x
+
+ # Define the MixtureOfExperts module
+ class MixtureOfExperts(nn.Module):
+     def __init__(self, config, num_experts, expert_layers):
+         super().__init__()
+         self.num_experts = num_experts
+         self.expert_layers = expert_layers
+
+         self.experts = nn.ModuleList([self._create_expert(config) for _ in range(num_experts)])
+         self.gate = nn.Linear(config.n_embd, num_experts)
+
+     def _create_expert(self, config):
+         # NOTE: FlashAttention3 is built here with its default device='cuda';
+         # CPU-only use would require threading a device argument through.
+         layers = []
+         for _ in range(self.expert_layers):
+             layers.append(FlashAttention3(d_model=config.n_embd, n_heads=config.n_head, block_size_q=32, block_size_kv=32, num_blocks_kv=4))
+             layers.append(nn.LayerNorm(config.n_embd))
+             layers.append(MLP(config))
+         return nn.Sequential(*layers)
+
+     def forward(self, x):
+         # Dense (soft) routing: every expert runs on every token, and the
+         # outputs are combined with softmax gate weights.
+         B, T, C = x.size()
+
+         gate_scores = self.gate(x)                  # (B, T, num_experts)
+         gate_probs = F.softmax(gate_scores, dim=-1)
+
+         expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=1)  # (B, E, T, C)
+
+         gate_probs = gate_probs.unsqueeze(-1)        # (B, T, E, 1)
+         gate_probs = gate_probs.permute(0, 2, 1, 3)  # (B, E, T, 1)
+
+         output = torch.sum(gate_probs * expert_outputs, dim=1)  # (B, T, C)
+
+         return output
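
The permute/broadcast dance in forward is just a gate-weighted average over the expert axis; a quick sketch confirming it against an einsum formulation:

# Sketch: the MoE combination step as an einsum (equivalent to the permute/sum above).
import torch

B, E, T, C = 2, 3, 5, 4
expert_outputs = torch.randn(B, E, T, C)
gate_probs = torch.softmax(torch.randn(B, T, E), dim=-1)

out_permute = (gate_probs.unsqueeze(-1).permute(0, 2, 1, 3) * expert_outputs).sum(dim=1)
out_einsum = torch.einsum('bte,betc->btc', gate_probs, expert_outputs)
assert torch.allclose(out_permute, out_einsum, atol=1e-6)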
+
+ # Define the BlockWithMoE module
+ class BlockWithMoE(nn.Module):
+     def __init__(self, config, num_experts=4, expert_layers=2, block_size_q=32, block_size_kv=32, num_blocks_kv=4, device='cuda'):
+         super().__init__()
+         self.ln_1 = nn.LayerNorm(config.n_embd)
+         self.attn = FlashAttention3(d_model=config.n_embd, n_heads=config.n_head, block_size_q=block_size_q, block_size_kv=block_size_kv, num_blocks_kv=num_blocks_kv, device=device)
+         self.dropout1 = nn.Dropout(config.dropout)
+         self.ln_2 = nn.LayerNorm(config.n_embd)
+         self.moe = MixtureOfExperts(config, num_experts, expert_layers)
+         self.dropout2 = nn.Dropout(config.dropout)
+         self.ln_3 = nn.LayerNorm(config.n_embd)
+         self.mlp = MLP(config)
+         self.dropout3 = nn.Dropout(config.dropout)
+
+     def forward(self, x):
+         B, T, C = x.size()
+
+         # NOTE: ln_1 is defined but never applied; attention sees the raw
+         # residual stream, unlike the pre-norm used for the MoE and MLP paths.
+         attn_output = self.attn(x)
+         x = x + attn_output
+         x = self.dropout1(x)
+         x = x + self.moe(self.ln_2(x))
+         x = self.dropout2(x)
+         x = x + self.mlp(self.ln_3(x))
+         x = self.dropout3(x)
+         return x
+
+ # Define the GPT configuration dataclass
+ @dataclass
+ class GPTConfig:
+     block_size: int = 512
+     vocab_size: int = 50257
+     n_layer: int = 6
+     n_head: int = 4
+     n_embd: int = 256
+     dropout: float = 0.2
+
+ # Define the GPTWithMoE model
+ class GPTWithMoE(nn.Module):
+     def __init__(self, config, num_experts=2, expert_layers=2, block_size_q=32, block_size_kv=32, num_blocks_kv=4, device='cuda'):
+         super().__init__()
+         self.config = config
+
+         self.transformer = nn.ModuleDict(dict(
+             wte=nn.Embedding(config.vocab_size, config.n_embd),
+             wpe=nn.Embedding(config.block_size, config.n_embd),
+             h=nn.ModuleList([BlockWithMoE(config, num_experts, expert_layers, block_size_q, block_size_kv, num_blocks_kv, device) for _ in range(config.n_layer)]),
+             ln_f=nn.LayerNorm(config.n_embd),
+         ))
+         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+
+         # Weight tying between token embedding and output head
+         self.transformer.wte.weight = self.lm_head.weight
+
+         self.apply(self._init_weights)
+
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             std = 0.02
+             if hasattr(module, 'scale_init'):
+                 std *= (2 * self.config.n_layer) ** -0.5
+             torch.nn.init.normal_(module.weight, mean=0.0, std=std)
+             if module.bias is not None:
+                 torch.nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+
+     def forward(self, idx, targets=None):
+         B, T = idx.size()
+         assert T <= self.config.block_size, f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"
+         pos = torch.arange(0, T, dtype=torch.long, device=idx.device)
+         pos_emb = self.transformer.wpe(pos)
+         tok_emb = self.transformer.wte(idx)
+         x = tok_emb + pos_emb
+         for block in self.transformer.h:
+             x = block(x)
+         x = self.transformer.ln_f(x)
+         logits = self.lm_head(x)
+         loss = None
+         if targets is not None:
+             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
+         return logits, loss
+
+     def configure_optimizers(self, weight_decay, learning_rate, device):
+         param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad}
+         decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
+         non_decay_params = [p for n, p in param_dict.items() if p.dim() < 2]
+
+         optim_groups = [
+             {'params': decay_params, 'weight_decay': weight_decay},
+             {'params': non_decay_params, 'weight_decay': 0.0}
+         ]
+
+         fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
+         use_fused = fused_available and 'cuda' in device
+         print(f"Using fused AdamW: {use_fused}")
+         # Pass the parameter groups (originally self.parameters() was passed,
+         # which silently dropped the weight-decay split built above).
+         optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=(0.9, 0.95), eps=1e-8, fused=use_fused)
+         return optimizer
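
configure_optimizers splits parameters by dimensionality: matrices (dim >= 2) receive weight decay, while biases, LayerNorm gains, and other vectors do not. A quick illustrative sketch, assuming a constructed GPTWithMoE bound to `model`, verifying the two groups partition the trainable parameters:

# Sketch: confirm decay/no-decay groups cover every trainable parameter exactly once.
params = {pn: p for pn, p in model.named_parameters() if p.requires_grad}
decay = [p for p in params.values() if p.dim() >= 2]
no_decay = [p for p in params.values() if p.dim() < 2]
assert len(decay) + len(no_decay) == len(params)
print(f"{sum(p.numel() for p in decay):,} decayed / "
      f"{sum(p.numel() for p in no_decay):,} undecayed parameters")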
+
+ # MCTS Implementation
+ @dataclass
+ class MCTSNode:
+     state: torch.Tensor
+     parent: 'MCTSNode' = None
+     children: dict = None
+     visits: int = 0
+     value: float = 0.0
+
+     def __post_init__(self):
+         if self.children is None:
+             self.children = {}
+
+ # Define scriptable functions separately
+ def select_node(node: MCTSNode, c_puct: float) -> MCTSNode:
+     # UCT-style selection: exploit mean value, explore rarely visited children.
+     if not node.children:
+         return node
+
+     scores = torch.tensor([
+         child.value / (child.visits + 1e-8) +
+         c_puct * math.sqrt(math.log(node.visits + 1) / (child.visits + 1e-8))
+         for child in node.children.values()
+     ])
+
+     best_child_idx = torch.argmax(scores).item()
+     return list(node.children.values())[best_child_idx]
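
Despite the c_puct name, this score is a UCB1-style rule rather than AlphaZero's prior-weighted PUCT. In LaTeX, with accumulated child value V, child visits n, and parent visits N:

\text{score}(\text{child}) = \frac{V}{n + \epsilon}
  + c_{\text{puct}} \sqrt{\frac{\ln(N + 1)}{n + \epsilon}},
\qquad \epsilon = 10^{-8}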
+
+ def expand_node(node: MCTSNode, logits: torch.Tensor, top_k: int) -> None:
+     # Add one child per top-k next token under the model's distribution.
+     probs = F.softmax(logits, dim=-1)
+     top_k_probs, top_k_indices = torch.topk(probs, k=top_k)
+
+     for prob, token in zip(top_k_probs, top_k_indices):
+         if token.item() not in node.children:
+             node.children[token.item()] = MCTSNode(state=token, parent=node)
+
+ def simulate(model: torch.nn.Module, sequence: torch.Tensor, max_length: int) -> torch.Tensor:
+     # Rollout: sample tokens from the model until max_length is reached.
+     # Ensure sequence is 2D
+     if sequence.dim() == 1:
+         sequence = sequence.unsqueeze(0)
+
+     with torch.no_grad():
+         while sequence.size(1) < max_length:
+             with autocast():
+                 logits, _ = model(sequence)
+             probs = F.softmax(logits[0, -1], dim=-1)
+             next_token = torch.multinomial(probs, 1)
+             sequence = torch.cat([sequence, next_token.unsqueeze(0)], dim=1)
+     return sequence.squeeze(0)
+
+ def backpropagate(node: MCTSNode, value: float) -> None:
+     while node is not None:
+         node.visits += 1
+         node.value += value
+         node = node.parent
+
+ def mcts_decode_single(model: torch.nn.Module, input_ids: torch.Tensor, max_length: int, num_simulations: int, c_puct: float, top_k: int) -> torch.Tensor:
+     # Ensure input_ids is 2D
+     if input_ids.dim() == 1:
+         input_ids = input_ids.unsqueeze(0)
+
+     root = MCTSNode(state=input_ids)
+
+     for _ in range(num_simulations):
+         node = root
+         current_input = input_ids.clone()
+
+         # Selection
+         while node.children and current_input.size(1) < max_length:
+             node = select_node(node, c_puct)
+             current_input = torch.cat([current_input, node.state.unsqueeze(0).unsqueeze(0)], dim=1)
+
+         # Expansion
+         if current_input.size(1) < max_length:
+             with torch.no_grad():
+                 with autocast():
+                     logits, _ = model(current_input)
+             expand_node(node, logits[0, -1], top_k)
+
+         # Simulation
+         simulation_sequence = simulate(model, current_input.squeeze(0), max_length)
+
+         # Evaluation: the negative language-modeling loss of the rollout is the reward
+         with torch.no_grad():
+             with autocast():
+                 _, loss = model(simulation_sequence.unsqueeze(0), simulation_sequence.unsqueeze(0))
+         value = -loss.item()
+
+         # Backpropagation
+         backpropagate(node, value)
+
+     # Choose the most-visited child of the root as the next token.
+     # NOTE: each call therefore extends the prompt by exactly one token;
+     # call repeatedly to generate longer continuations.
+     best_child = max(root.children.values(), key=lambda n: n.visits)
+     result = torch.cat([input_ids.squeeze(0), best_child.state.unsqueeze(0)], dim=0)
+
+     # Ensure the result doesn't exceed max_length
+     return result[:max_length]
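
For shape intuition, a sketch of a single decoding step; the token IDs are arbitrary, and a built model on `device` is assumed:

# Sketch: one MCTS decoding step on a toy prompt.
prompt_ids = torch.tensor([[464, 3290, 318]], device=device)  # shape (1, 3); arbitrary IDs
out = mcts_decode_single(model, prompt_ids, max_length=4,
                         num_simulations=8, c_puct=1.0, top_k=5)
print(out.shape)  # torch.Size([4]): the prompt plus one MCTS-chosen token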
+
+ def mcts_decode_batch(model: torch.nn.Module, input_ids_list: List[torch.Tensor], max_length: int, num_simulations: int, c_puct: float, top_k: int) -> List[torch.Tensor]:
+     return [mcts_decode_single(model, input_ids.unsqueeze(0) if input_ids.dim() == 1 else input_ids, max_length, num_simulations, c_puct, top_k) for input_ids in input_ids_list]
+
+ def validate_with_mcts(model: torch.nn.Module, val_dataloader: CustomDataLoaderLite, device: torch.device, max_length: int, num_simulations: int, c_puct: float, top_k: int) -> float:
+     model.eval()
+     total_loss = 0.0
+     num_batches = 0
+
+     with torch.no_grad():
+         for x, y in val_dataloader:
+             x, y = x.to(device), y.to(device)
+
+             # Use MCTS for decoding
+             decoded_sequences = mcts_decode_batch(model, x, max_length, num_simulations, c_puct, top_k)
+
+             # Pad sequences to the same length
+             decoded_sequences_padded = pad_sequence(decoded_sequences, batch_first=True, padding_value=0)
+
+             # Trim the decoded sequences to match the target length
+             decoded_sequences_trimmed = decoded_sequences_padded[:, :y.size(1)]
+
+             # Calculate loss using the MCTS-decoded sequences
+             with autocast():
+                 logits, loss = model(decoded_sequences_trimmed, y)
+             total_loss += loss.item()
+             num_batches += 1
+
+     return total_loss / num_batches if num_batches > 0 else 0.0
+
+ def train_model():
+     device = 'cpu'
+     if torch.cuda.is_available():
+         device = 'cuda'
+     elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+         device = 'mps'
+     print(f"using device: {device}")
+
+     # Load the dataset and create the data loaders
+     print("Loading datasets...")
+     train_dataset = NpyDataset('edu_fineweb10B', 'edufineweb_train')
+     val_dataset = NpyDataset('edu_fineweb10B', 'edufineweb_val')
+     train_dataloader = CustomDataLoaderLite(train_dataset, batch_size=12, seq_len=512)
+     val_dataloader = CustomDataLoaderLite(val_dataset, batch_size=12, seq_len=512)
+
+     # Training loop
+     max_steps = 200
+     total_batch_size = 262144
+     B = 12
+     T = 512
+     # 262144 // (12 * 512) = 42 micro-batches per nominal step; note that the
+     # inner loop below actually accumulates over the whole dataloader each step.
+     grad_accum_steps = total_batch_size // (B * T)
+
+     # Set up the configuration
+     print("Setting up model configuration...")
+     config = GPTConfig(vocab_size=50304, block_size=512, n_layer=6, n_head=4, n_embd=256)
+
+     # Initialize the model
+     print("Initializing model...")
+     model = GPTWithMoE(config, num_experts=3, expert_layers=3, block_size_q=32, block_size_kv=32, num_blocks_kv=4, device=device)
+     model.to(device)
+
+     # Load the saved model weights if they exist
+     save_path = "C:\\Users\\Admin\\MODELS\\moe_mcts_new.pt"
+     temp_save_path = "C:\\Users\\Admin\\MODELS\\moe_mcts_temp_new.pt"
+     if os.path.isfile(save_path):
+         print(f"Loading model weights from {save_path}...")
+         model.load_state_dict(torch.load(save_path))
+         print(f"Loaded model weights from {save_path}")
+
+     print("Configuring optimizer...")
+     optimizer = model.configure_optimizers(weight_decay=0.2, learning_rate=3e-3, device=device)
+
+     scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=True)
+
+     train_losses = []
+     val_losses = []
+
+     scaler = torch.cuda.amp.GradScaler()
+
+     for i in range(max_steps):
+         t0 = time.time()
+         optimizer.zero_grad()
+         train_loss_accum = 0
+
+         model.train()
+         print(f"Training step {i + 1}/{max_steps}...")
+         for x, y in train_dataloader:
+             x, y = x.to(device), y.to(device)
+             with torch.cuda.amp.autocast():
+                 logits, loss = model(x, y)
+             loss = loss / grad_accum_steps
+             train_loss_accum += loss.detach()
+             scaler.scale(loss).backward()
+
+         # Unscale before clipping so the norm is computed on true gradients
+         # (originally the clip ran on GradScaler-scaled gradients).
+         scaler.unscale_(optimizer)
+         norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+         scaler.step(optimizer)
+         scaler.update()
+
+         if device == 'cuda':
+             torch.cuda.synchronize()
+         t1 = time.time()
+         dt = (t1 - t0) * 1000
+         tokens_per_sec = (B * T * grad_accum_steps) / (t1 - t0)
+         train_losses.append(train_loss_accum.item())
+
+         torch.cuda.empty_cache()
+
+         # Validation with MCTS
+         model.eval()
+         val_loss = validate_with_mcts(model, val_dataloader, device, max_length=T, num_simulations=100, c_puct=1.0, top_k=10)
+         val_losses.append(val_loss)
+
+         scheduler.step(val_loss)
+
+         print(f"step {i} | train loss: {train_loss_accum.item():.6f} | val loss: {val_loss:.6f} | lr: {optimizer.param_groups[0]['lr']:.8f} | norm: {norm:.4f} | dt: {dt:.2f}ms | tok/sec: {tokens_per_sec:.2f}")
+
+         # Save model weights atomically: write to a temp file, then replace
+         torch.save(model.state_dict(), temp_save_path)
+         os.replace(temp_save_path, save_path)
+         print(f"Model saved at step {i+1} to {save_path}")
+
+     # Plot the training and validation loss
+     plt.figure(figsize=(10, 5))
+     plt.plot(train_losses, label='Training Loss')
+     plt.plot(val_losses, label='Validation Loss')
+     plt.xlabel('Steps')
+     plt.ylabel('Loss')
+     plt.legend()
+     plt.show()
+
+ if __name__ == "__main__":
+     train_model()