Commit
·
cb658f7
1
Parent(s):
d78e144
Update Coach.py
Browse files
Coach.py
CHANGED
@@ -2,7 +2,8 @@ from collections import deque
|
|
2 |
from Arena import Arena
|
3 |
from MCTS import MCTS
|
4 |
import numpy as np
|
5 |
-
from
|
|
|
6 |
import time, os, sys
|
7 |
from pickle import Pickler, Unpickler
|
8 |
from random import shuffle
|
@@ -58,14 +59,15 @@ class Coach():
|
|
58 |
#self.game.print_board(canonicalBoard)
|
59 |
|
60 |
action = np.random.choice(len(pi), p=pi)
|
61 |
-
trainExamples.append([canonicalBoard, self.curPlayer, pi, None, valids])
|
62 |
board, self.curPlayer = self.game.getNextState(board, self.curPlayer, action)
|
63 |
|
64 |
r = self.game.getGameEnded(board, self.curPlayer)
|
65 |
|
66 |
if r!=0:
|
67 |
-
return [(x[0],x[2],r*x[1], x[4]) for x in trainExamples]
|
68 |
#return [(x[0],x[2],0) for x in trainExamples]
|
|
|
69 |
return []
|
70 |
|
71 |
def learn(self):
|
@@ -103,10 +105,9 @@ class Coach():
|
|
103 |
# save the iteration examples to the history
|
104 |
self.trainExamplesHistory.append(iterationTrainExamples)
|
105 |
trainStats = [0,0,0]
|
106 |
-
for
|
107 |
-
trainStats[res] += 1
|
108 |
print(trainStats)
|
109 |
-
|
110 |
if len(self.trainExamplesHistory) > self.args.numItersForTrainExamplesHistory:
|
111 |
print("len(trainExamplesHistory) =", len(self.trainExamplesHistory), " => remove the oldest trainExamples")
|
112 |
self.trainExamplesHistory.pop(0)
|
@@ -119,12 +120,12 @@ class Coach():
|
|
119 |
for e in self.trainExamplesHistory:
|
120 |
trainExamples.extend(e)
|
121 |
shuffle(trainExamples)
|
122 |
-
|
123 |
# training new network, keeping a copy of the old one
|
124 |
self.nnet.save_checkpoint(folder=self.args.checkpoint, filename='temp.pth.tar')
|
125 |
self.pnet.load_checkpoint(folder=self.args.checkpoint, filename='temp.pth.tar')
|
126 |
pmcts = MCTS(self.game, self.pnet, self.args)
|
127 |
-
|
128 |
self.nnet.train(trainExamples)
|
129 |
nmcts = MCTS(self.game, self.nnet, self.args)
|
130 |
|
|
|
2 |
from Arena import Arena
|
3 |
from MCTS import MCTS
|
4 |
import numpy as np
|
5 |
+
from progress.bar import Bar
|
6 |
+
from quoridor.pytorch.NNet import AverageMeter
|
7 |
import time, os, sys
|
8 |
from pickle import Pickler, Unpickler
|
9 |
from random import shuffle
|
|
|
59 |
#self.game.print_board(canonicalBoard)
|
60 |
|
61 |
action = np.random.choice(len(pi), p=pi)
|
62 |
+
trainExamples.append([canonicalBoard, self.curPlayer, pi, None, valids, episodeStep])
|
63 |
board, self.curPlayer = self.game.getNextState(board, self.curPlayer, action)
|
64 |
|
65 |
r = self.game.getGameEnded(board, self.curPlayer)
|
66 |
|
67 |
if r!=0:
|
68 |
+
return [(x[0],x[2],r*x[1], x[4], x[5], episodeStep) for x in trainExamples]
|
69 |
#return [(x[0],x[2],0) for x in trainExamples]
|
70 |
+
print("the game's not ended")
|
71 |
return []
|
72 |
|
73 |
def learn(self):
|
|
|
105 |
# save the iteration examples to the history
|
106 |
self.trainExamplesHistory.append(iterationTrainExamples)
|
107 |
trainStats = [0,0,0]
|
108 |
+
for res in iterationTrainExamples:
|
109 |
+
trainStats[res[2]] += 1
|
110 |
print(trainStats)
|
|
|
111 |
if len(self.trainExamplesHistory) > self.args.numItersForTrainExamplesHistory:
|
112 |
print("len(trainExamplesHistory) =", len(self.trainExamplesHistory), " => remove the oldest trainExamples")
|
113 |
self.trainExamplesHistory.pop(0)
|
|
|
120 |
for e in self.trainExamplesHistory:
|
121 |
trainExamples.extend(e)
|
122 |
shuffle(trainExamples)
|
123 |
+
|
124 |
# training new network, keeping a copy of the old one
|
125 |
self.nnet.save_checkpoint(folder=self.args.checkpoint, filename='temp.pth.tar')
|
126 |
self.pnet.load_checkpoint(folder=self.args.checkpoint, filename='temp.pth.tar')
|
127 |
pmcts = MCTS(self.game, self.pnet, self.args)
|
128 |
+
|
129 |
self.nnet.train(trainExamples)
|
130 |
nmcts = MCTS(self.game, self.nnet, self.args)
|
131 |
|