|
from collections import deque |
|
from Arena import Arena |
|
from MCTS import MCTS |
|
import numpy as np |
|
from progress.bar import Bar |
|
from quoridor.pytorch.NNet import AverageMeter |
|
import time, os, sys |
|
from pickle import Pickler, Unpickler |
|
from random import shuffle |
|
|
|
|
|
class Coach(): |
|
""" |
|
This class executes the self-play + learning. It uses the functions defined |
|
in Game and NeuralNet. args are specified in main.py. |
|
""" |
|
def __init__(self, game, nnet, args): |
|
self.game = game |
|
self.nnet = nnet |
|
        self.pnet = self.nnet.__class__(self.game)  # competitor network used for arena comparison
|
self.args = args |
|
self.mcts = MCTS(self.game, self.nnet, self.args) |
|
self.trainExamplesHistory = [] |
|
        self.skipFirstSelfPlay = False  # can be overridden in loadTrainExamples()
|
|
|
def executeEpisode(self): |
|
""" |
|
This function executes one episode of self-play, starting with player 1. |
|
As the game is played, each turn is added as a training example to |
|
trainExamples. The game is played till the game ends. After the game |
|
ends, the outcome of the game is used to assign values to each example |
|
in trainExamples. |
|
|
|
It uses a temp=1 if episodeStep < tempThreshold, and thereafter |
|
uses temp=0. |
|
|
|
        Episodes are capped at 200 moves; if the cap is reached, or MCTS
        returns an all-zero policy, the episode is discarded and an empty
        list is returned.

        Returns:
            trainExamples: a list of examples of the form
                           (canonicalBoard, pi, v, valids, step, episodeLength),
                           where pi is the MCTS-informed policy vector and v is
                           +1 if the player to move eventually won the game,
                           else -1.
        """
|
trainExamples = [] |
|
board = self.game.getInitBoard() |
|
self.curPlayer = 1 |
|
episodeStep = 0 |
|
        # Hard cap at 200 moves so a non-terminating game cannot stall self-play.
        while episodeStep < 200:
|
episodeStep += 1 |
|
            canonicalBoard = self.game.getCanonicalForm(board, self.curPlayer)
|
valids = self.game.getValidMoves(canonicalBoard, 1) |
|
            # temp=1 samples moves in proportion to MCTS visit counts (exploration);
            # temp=0 plays greedily once episodeStep reaches tempThreshold.
            temp = int(episodeStep < self.args.tempThreshold)
|
|
|
pi = self.mcts.getActionProb(canonicalBoard, temp=temp) |
|
|
|
            # Abort the episode if MCTS returns an all-zero policy (no playable action).
            if np.sum(pi) == 0:
                break
|
action = np.random.choice(len(pi), p=pi) |
|
            # Store (canonicalBoard, player, pi, value placeholder, valid-move mask,
            # step index); the value slot is filled in once the game ends.
            trainExamples.append([canonicalBoard, self.curPlayer, pi, None, valids, episodeStep])
|
board, self.curPlayer = self.game.getNextState(board, self.curPlayer, action) |
|
|
|
r = self.game.getGameEnded(board, self.curPlayer) |
|
|
|
        if r != 0:
            # Fill each example's value slot with the final result, signed for
            # the player who was to move at that step, and append the episode length.
            return [(x[0], x[2], r * x[1], x[4], x[5], episodeStep) for x in trainExamples]
|
|
|
print("the game's not ended") |
|
return [] |
|
|
|
def learn(self): |
|
""" |
|
Performs numIters iterations with numEps episodes of self-play in each |
|
iteration. After every iteration, it retrains neural network with |
|
        examples in trainExamples (which has a maximum length of maxlenOfQueue).
|
It then pits the new neural network against the old one and accepts it |
|
only if it wins >= updateThreshold fraction of games. |
|
""" |
|
|
|
for i in range(1, self.args.numIters+1): |
|
|
|
print('------ITER ' + str(i) + '------') |
|
|
|
            # Skip self-play on the first iteration when examples were loaded from disk.
            if not self.skipFirstSelfPlay or i > 1:
|
iterationTrainExamples = deque([], maxlen=self.args.maxlenOfQueue) |
|
|
|
eps_time = AverageMeter() |
|
bar = Bar('Self Play', max=self.args.numEps) |
|
end = time.time() |
|
|
|
for eps in range(self.args.numEps): |
|
                    self.mcts = MCTS(self.game, self.nnet, self.args)  # reset the search tree for each episode
|
iterationTrainExamples += self.executeEpisode() |
|
|
|
|
|
eps_time.update(time.time() - end) |
|
end = time.time() |
|
                    bar.suffix = '({eps}/{maxeps}) Eps Time: {et:.3f}s | Total: {total:} | ETA: {eta:}'.format(
                        eps=eps + 1, maxeps=self.args.numEps, et=eps_time.avg,
                        total=bar.elapsed_td, eta=bar.eta_td)
|
bar.next() |
|
bar.finish() |
|
|
|
|
|
                # Save this iteration's examples to the history.
                self.trainExamplesHistory.append(iterationTrainExamples)

                # Tally outcomes from the mover's perspective: trainStats[1]
                # counts wins, trainStats[-1] (last slot) counts losses, and
                # trainStats[0] counts examples with no decisive result.
                trainStats = [0, 0, 0]
                for res in iterationTrainExamples:
                    trainStats[int(res[2])] += 1
                print(trainStats)
|
if len(self.trainExamplesHistory) > self.args.numItersForTrainExamplesHistory: |
|
print("len(trainExamplesHistory) =", len(self.trainExamplesHistory), " => remove the oldest trainExamples") |
|
self.trainExamplesHistory.pop(0) |
|
|
|
|
|
            # Back up the history to a file. NB: these examples were collected
            # using the model from the previous iteration, hence i-1.
            self.saveTrainExamples(i-1)
|
|
|
|
|
trainExamples = [] |
|
for e in self.trainExamplesHistory: |
|
trainExamples.extend(e) |
|
            shuffle(trainExamples)  # decorrelate examples before training
|
|
|
|
|
            # Train a new network, keeping a copy of the current one for comparison.
            self.nnet.save_checkpoint(folder=self.args.checkpoint, filename='temp.pth.tar')
            self.pnet.load_checkpoint(folder=self.args.checkpoint, filename='temp.pth.tar')
            pmcts = MCTS(self.game, self.pnet, self.args)
|
|
|
self.nnet.train(trainExamples) |
|
nmcts = MCTS(self.game, self.nnet, self.args) |
|
|
|
print('PITTING AGAINST PREVIOUS VERSION') |
|
arena = Arena(lambda x: np.argmax(pmcts.getActionProb(x, temp=0)), |
|
lambda x: np.argmax(nmcts.getActionProb(x, temp=0)), self.game) |
|
pwins, nwins, draws = arena.playGames(self.args.arenaCompare) |
|
|
|
print('NEW/PREV WINS : %d / %d ; DRAWS : %d' % (nwins, pwins, draws)) |
|
            # Accept the new net only if it wins at least updateThreshold of the
            # decisive (non-drawn) games; an all-draw match counts as acceptance.
            if pwins+nwins > 0 and float(nwins)/(pwins+nwins) < self.args.updateThreshold:
|
print('REJECTING NEW MODEL') |
|
self.nnet.load_checkpoint(folder=self.args.checkpoint, filename='temp.pth.tar') |
|
else: |
|
print('ACCEPTING NEW MODEL') |
|
self.nnet.save_checkpoint(folder=self.args.checkpoint, filename=self.getCheckpointFile(i)) |
|
self.nnet.save_checkpoint(folder=self.args.checkpoint, filename='best.pth.tar') |
|
|
|
def getCheckpointFile(self, iteration): |
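        # Produces e.g. 'checkpoint_3.pth.tar' for iteration 3.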
|
return 'checkpoint_' + str(iteration) + '.pth.tar' |
|
|
|
def saveTrainExamples(self, iteration): |
|
folder = self.args.checkpoint |
|
if not os.path.exists(folder): |
|
os.makedirs(folder) |
|
filename = os.path.join(folder, self.getCheckpointFile(iteration)+".examples") |
|
with open(filename, "wb+") as f: |
|
Pickler(f).dump(self.trainExamplesHistory) |
|
f.closed |
|
|
|
def loadTrainExamples(self): |
|
modelFile = os.path.join(self.args.load_folder_examples_file[0], self.args.load_folder_examples_file[1]) |
|
examplesFile = modelFile+".examples" |
|
if not os.path.isfile(examplesFile): |
|
print(examplesFile) |
|
            r = input("File with trainExamples not found. Continue? [y|n] ")
|
if r != "y": |
|
sys.exit() |
|
else: |
|
print("File with trainExamples found. Read it.") |
|
            with open(examplesFile, "rb") as f:
                self.trainExamplesHistory = Unpickler(f).load()
|
|
|
            # Examples for the loaded model were already collected, so the
            # first learn() iteration can skip self-play.
            self.skipFirstSelfPlay = True
|
|