mishasamin committed
Commit cb658f7 · 1 Parent(s): d78e144

Update Coach.py

Files changed (1)
  1. Coach.py +9 -8
Coach.py CHANGED
@@ -2,7 +2,8 @@ from collections import deque
 from Arena import Arena
 from MCTS import MCTS
 import numpy as np
-from pytorch_classification.utils import Bar, AverageMeter
+from progress.bar import Bar
+from quoridor.pytorch.NNet import AverageMeter
 import time, os, sys
 from pickle import Pickler, Unpickler
 from random import shuffle
@@ -58,14 +59,15 @@ class Coach():
             #self.game.print_board(canonicalBoard)

             action = np.random.choice(len(pi), p=pi)
-            trainExamples.append([canonicalBoard, self.curPlayer, pi, None, valids])
+            trainExamples.append([canonicalBoard, self.curPlayer, pi, None, valids, episodeStep])
             board, self.curPlayer = self.game.getNextState(board, self.curPlayer, action)

             r = self.game.getGameEnded(board, self.curPlayer)

             if r!=0:
-                return [(x[0],x[2],r*x[1], x[4]) for x in trainExamples]
+                return [(x[0],x[2],r*x[1], x[4], x[5], episodeStep) for x in trainExamples]
                 #return [(x[0],x[2],0) for x in trainExamples]
+            print("the game's not ended")
             return []

     def learn(self):
@@ -103,10 +105,9 @@ class Coach():
                 # save the iteration examples to the history
                 self.trainExamplesHistory.append(iterationTrainExamples)
                 trainStats = [0,0,0]
-                for _,_,res, _ in iterationTrainExamples:
-                    trainStats[res] += 1
+                for res in iterationTrainExamples:
+                    trainStats[res[2]] += 1
                 print(trainStats)
-
             if len(self.trainExamplesHistory) > self.args.numItersForTrainExamplesHistory:
                 print("len(trainExamplesHistory) =", len(self.trainExamplesHistory), " => remove the oldest trainExamples")
                 self.trainExamplesHistory.pop(0)
@@ -119,12 +120,12 @@ class Coach():
             for e in self.trainExamplesHistory:
                 trainExamples.extend(e)
             shuffle(trainExamples)
-
+
             # training new network, keeping a copy of the old one
             self.nnet.save_checkpoint(folder=self.args.checkpoint, filename='temp.pth.tar')
             self.pnet.load_checkpoint(folder=self.args.checkpoint, filename='temp.pth.tar')
             pmcts = MCTS(self.game, self.pnet, self.args)
-
+
             self.nnet.train(trainExamples)
             nmcts = MCTS(self.game, self.nnet, self.args)

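
Note on the import change: the Bar progress bar that previously came from the vendored pytorch_classification.utils is now taken from the standalone progress package, and AverageMeter is pulled from quoridor.pytorch.NNet. A rough usage sketch of progress.bar.Bar (assuming the standard PyPI "progress" package; the loop below is only a stand-in, not code from this repository) looks like:

# Rough sketch of the progress.bar.Bar interface the new import relies on;
# the sleep is a placeholder for one self-play episode.
import time
from progress.bar import Bar

bar = Bar('Self Play', max=10, suffix='%(index)d/%(max)d')
for _ in range(10):
    time.sleep(0.01)  # stand-in for one self-play episode
    bar.next()        # advance the bar by one step
bar.finish()

This appears to offer the same next()/finish()/suffix interface as the previously bundled copy, so the surrounding self-play loop can stay unchanged.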
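
Note on the training-example change: each record returned by the updated executeEpisode() now carries six fields, (canonicalBoard, pi, r*player, valids, step, episodeStep), and learn() tallies the third field into trainStats. The sketch below is not part of the commit; the helper name tally_results and the toy records are made up for illustration, and it assumes the game value r*player is an integer +1 or -1 for a decided game, as in alpha-zero-general.

# Minimal sketch of the trainStats tally over the new 6-field records.
def tally_results(examples):
    """examples: iterable of (canonicalBoard, pi, r*player, valids, step, episodeStep)."""
    stats = [0, 0, 0]
    for rec in examples:
        # rec[2] is +1 or -1; Python's negative indexing maps -1 to the last
        # slot, so stats ends up as [count of 0s, count of +1s, count of -1s].
        stats[rec[2]] += 1
    return stats

# Two hypothetical records, one with value +1 and one with value -1.
examples = [
    ("board_a", [0.7, 0.3], 1, [1, 1], 5, 12),
    ("board_b", [0.2, 0.8], -1, [1, 0], 6, 12),
]
print(tally_results(examples))  # -> [0, 1, 1]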