sigmoidneuron123 committed
Commit a86e2f9 · verified · 1 Parent(s): 9fa5fb2

Create selfchess-colab.py

Files changed (1)
  1. selfchess-colab.py +224 -0
selfchess-colab.py ADDED
@@ -0,0 +1,224 @@
+ import os
+ os.system('pip install chess')  # install python-chess on the Colab runtime
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ import chess
+ import chess.engine as eng
+ import torch.multiprocessing as mp
+
+ # CONFIGURATION
+ CONFIG = {
+     "stockfish_path": "/usr/games/stockfish",
+     "model_path": "NeoChess/chessy_model.pth",
+     "backup_model_path": "NeoChess/chessy_modelt-1.pth",
+     # fall back to CPU so the script still runs on a runtime without a GPU
+     "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
+     "learning_rate": 1e-4,
+     "num_games": 30,               # games generated per epoch
+     "num_epochs": 10,
+     "stockfish_time_limit": 1.0,   # seconds per Stockfish move
+     "search_depth": 1,             # negamax depth for the bot's own moves
+     "epsilon": 4                   # candidate moves sampled per position
+ }
+
+ device = CONFIG["device"]
+ # make sure the checkpoint directory exists before torch.save is called later
+ os.makedirs(os.path.dirname(CONFIG["model_path"]), exist_ok=True)
+
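+ # Note (added, hedged): on a fresh Colab runtime the Stockfish binary is usually
+ # installed with `!apt-get install -y stockfish`, which places it at
+ # /usr/games/stockfish as CONFIG["stockfish_path"] assumes; a quick
+ # `assert os.path.exists(CONFIG["stockfish_path"])` fails fast if it is missing.
+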
+ def board_to_tensor(board):
+     """Encode a board as a (1, 64) LongTensor of piece indices (0 = empty square)."""
+     piece_encoding = {
+         'P': 1, 'N': 2, 'B': 3, 'R': 4, 'Q': 5, 'K': 6,
+         'p': 7, 'n': 8, 'b': 9, 'r': 10, 'q': 11, 'k': 12
+     }
+
+     tensor = torch.zeros(64, dtype=torch.long)
+     for square in chess.SQUARES:
+         piece = board.piece_at(square)
+         if piece:
+             tensor[square] = piece_encoding[piece.symbol()]
+
+     return tensor.unsqueeze(0)
+
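+ # Example (illustrative): board_to_tensor(chess.Board()) yields a (1, 64) LongTensor;
+ # for the starting position squares a1..h1 map to [4, 2, 3, 5, 6, 3, 2, 4] (White's
+ # back rank) and the empty ranks 3-6 stay 0.
+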
+ class NN1(nn.Module):
+     def __init__(self):
+         super().__init__()
+         # 13 piece classes (0 = empty, 1-6 White, 7-12 Black), each embedded in 64 dims
+         self.embedding = nn.Embedding(13, 64)
+         self.attention = nn.MultiheadAttention(embed_dim=64, num_heads=16)
+         self.neu = 512
+         # 64 squares x 64-dim attention output flattened to 4096 features,
+         # followed by a deep MLP that ends in a 4-dim output head
+         self.neurons = nn.Sequential(
+             nn.Linear(4096, self.neu),
+             nn.ReLU(),
+             nn.Linear(self.neu, self.neu),
+             nn.ReLU(),
+             nn.Linear(self.neu, self.neu),
+             nn.ReLU(),
+             nn.Linear(self.neu, self.neu),
+             nn.ReLU(),
+             nn.Linear(self.neu, self.neu),
+             nn.ReLU(),
+             nn.Linear(self.neu, self.neu),
+             nn.ReLU(),
+             nn.Linear(self.neu, self.neu),
+             nn.ReLU(),
+             nn.Linear(self.neu, self.neu),
+             nn.ReLU(),
+             nn.Linear(self.neu, self.neu),
+             nn.ReLU(),
+             nn.Linear(self.neu, self.neu),
+             nn.ReLU(),
+             nn.Linear(self.neu, self.neu),
+             nn.ReLU(),
+             nn.Linear(self.neu, self.neu),
+             nn.ReLU(),
+             nn.Linear(self.neu, 64),
+             nn.ReLU(),
+             nn.Linear(64, 4)
+         )
+
+     def forward(self, x):
+         x = self.embedding(x)                          # (batch, 64) -> (batch, 64, 64)
+         x = x.permute(1, 0, 2)                         # -> (64, batch, 64) for attention
+         attn_output, _ = self.attention(x, x, x)
+         x = attn_output.permute(1, 0, 2).contiguous()  # back to (batch, 64, 64)
+         x = x.view(x.size(0), -1)                      # flatten to (batch, 4096)
+         x = self.neurons(x)
+         return x
+
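+ # Shape walk-through (added comment): input (batch, 64) piece indices -> embedding
+ # (batch, 64, 64) -> permuted to (64, batch, 64) because nn.MultiheadAttention
+ # defaults to batch_first=False -> flattened to (batch, 4096) -> MLP -> (batch, 4).
+ # Only output element [0][0] is used as the position score below.
+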
+ model = NN1().to(device)
+ optimizer = optim.Adam(model.parameters(), lr=CONFIG["learning_rate"])
+
+ # Load the latest checkpoint if it exists, otherwise fall back to the backup,
+ # otherwise start from a freshly initialised network.
+ try:
+     model.load_state_dict(torch.load(CONFIG["model_path"], map_location=device))
+     print(f"Loaded model from {CONFIG['model_path']}")
+ except FileNotFoundError:
+     try:
+         model.load_state_dict(torch.load(CONFIG["backup_model_path"], map_location=device))
+         print(f"Loaded backup model from {CONFIG['backup_model_path']}")
+     except FileNotFoundError:
+         print("No model file found, starting from scratch.")
+
+ model.train()
+ criterion = nn.MSELoss()
+ engine = eng.SimpleEngine.popen_uci(CONFIG["stockfish_path"])
+ lim = eng.Limit(time=CONFIG["stockfish_time_limit"])
+
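+ # Note (added): map_location=device lets a checkpoint saved on GPU be restored on a
+ # CPU-only runtime and vice versa; if neither file exists, the freshly initialised
+ # NN1 weights above are trained from scratch.
+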
+ def get_evaluation(board):
+     """
+     Returns the evaluation of the board from the perspective of the current player.
+     The model's output is from White's perspective.
+     """
+     tensor = board_to_tensor(board).to(device)
+     with torch.no_grad():
+         evaluation = model(tensor)[0][0].item()
+
+     if board.turn == chess.WHITE:
+         return evaluation
+     else:
+         return -evaluation
+
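+ # Example (illustrative): if the raw network output for a position is +0.8 (good for
+ # White) and it is Black to move, get_evaluation returns -0.8, i.e. the score from the
+ # side-to-move's perspective, which is what the negamax search below expects.
+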
+ def search(board, depth, alpha, beta):
+     """
+     Negamax search with alpha-beta pruning; returns the score from the
+     side-to-move's perspective.
+     """
+     if depth == 0 or board.is_game_over():
+         return get_evaluation(board)
+
+     max_eval = float('-inf')
+     for move in board.legal_moves:
+         board.push(move)
+         score = -search(board, depth - 1, -beta, -alpha)
+         board.pop()
+         max_eval = max(max_eval, score)
+         alpha = max(alpha, score)
+         if alpha >= beta:
+             break
+     return max_eval
+
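+ # Usage sketch (added, hedged): from the root position each candidate move is scored
+ # as -search(board_after_move, depth, -inf, +inf); the negation converts the
+ # opponent's best reply back into the mover's perspective, exactly as game_gen does.
+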
+ def game_gen(engine_side):
+     """Play one game; the bot plays every side that is not engine_side (None = self-play)."""
+     data = []
+     mc = 0
+     board = chess.Board()
+     while not board.is_game_over():
+         is_bot_turn = board.turn != engine_side
+
+         if is_bot_turn:
+             # Score every legal move with a shallow negamax search.
+             evaling = {}
+             for move in board.legal_moves:
+                 board.push(move)
+                 evaling[move] = -search(board, depth=CONFIG["search_depth"], alpha=float('-inf'), beta=float('inf'))
+                 board.pop()
+
+             if not evaling:
+                 break
+
+             # Exploration: sample `epsilon` candidate moves from a softmax over the
+             # scores, then play the highest-scoring move within that sample.
+             keys = list(evaling.keys())
+             logits = torch.tensor(list(evaling.values())).to(device)
+             probs = torch.softmax(logits, dim=0)
+             epsilon = min(CONFIG["epsilon"], len(keys))
+             bests = torch.multinomial(probs, num_samples=epsilon, replacement=False)
+             best_idx = bests[torch.argmax(logits[bests])]
+             move = keys[best_idx.item()]
+
+         else:
+             result = engine.play(board, lim)
+             move = result.move
+
+         if is_bot_turn:
+             # Record the position the bot moved from, for training later.
+             data.append({
+                 'fen': board.fen(),
+                 'move_number': mc,
+             })
+
+         board.push(move)
+         mc += 1
+
+     # Map the game result to a scalar reward: +10 White win, -10 Black win, 0 draw.
+     result = board.result()
+     c = 0.0
+     if result == '1-0':
+         c = 10.0
+     elif result == '0-1':
+         c = -10.0
+     return data, c, mc
+
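+ # Note (added): game_gen returns (data, c, mc) where data holds one FEN per bot move,
+ # c is +10.0 / -10.0 / 0.0 for a White win / Black win / draw, and mc is the total
+ # number of half-moves played, so later positions get targets nearer the final result.
+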
+ def train(data, c, mc):
+     """Regress the model's first output toward the scaled game result for each recorded position."""
+     for entry in data:
+         tensor = board_to_tensor(chess.Board(entry['fen'])).to(device)
+         # Later positions get targets closer to the final result c.
+         target = torch.tensor(c * entry['move_number'] / mc, dtype=torch.float32).to(device)
+         output = model(tensor)[0][0]
+         loss = criterion(output, target)
+         optimizer.zero_grad()
+         loss.backward()
+         optimizer.step()
+
+     print(f"Saving model to {CONFIG['model_path']}")
+     torch.save(model.state_dict(), CONFIG["model_path"])
+
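+ # Worked example (illustrative): with c = 10.0, mc = 60 and a position recorded at
+ # move_number 30, the target is 10.0 * 30 / 60 = 5.0, so positions near the end of a
+ # won game are pushed toward the full +10 score.
+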
+ def main():
+     # Each epoch: save a backup checkpoint, generate games in parallel worker
+     # processes (self-play, bot as Black vs Stockfish, bot as White vs Stockfish),
+     # then train on the collected positions.
+     mp.set_start_method('spawn', force=True)
+     for i in range(CONFIG["num_epochs"]):
+         num_games = CONFIG['num_games']
+         num_instances = mp.cpu_count()
+         print(f"Saving backup model to {CONFIG['backup_model_path']}")
+         torch.save(model.state_dict(), CONFIG["backup_model_path"])
+         # Spawned workers re-import this module, so each loads the latest checkpoint
+         # and opens its own Stockfish process.
+         with mp.Pool(processes=num_instances) as pool:
+             results_self = pool.starmap(game_gen, [(None,) for _ in range(num_games // 3)])
+             results_white = pool.starmap(game_gen, [(chess.WHITE,) for _ in range(num_games // 3)])
+             results_black = pool.starmap(game_gen, [(chess.BLACK,) for _ in range(num_games // 3)])
+         results = []
+         for s, w, b in zip(results_self, results_white, results_black):
+             results.extend([s, w, b])
+         for batch in results:
+             data, c, mc = batch
+             print(f"Saving backup model to {CONFIG['backup_model_path']}")
+             torch.save(model.state_dict(), CONFIG["backup_model_path"])
+             if data:
+                 train(data, c, mc)
+     print("Training complete.")
+     engine.quit()
+
+ if __name__ == "__main__":
+     main()
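+
+ # Usage (added, hedged): on Colab this script is typically run as
+ # `!python selfchess-colab.py` after installing Stockfish; with the CONFIG above it
+ # generates 30 games per epoch for 10 epochs and checkpoints to NeoChess/chessy_model.pth.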