mishasamin commited on
Commit
d78e144
·
1 Parent(s): b914774

Update quoridor/pytorch/NNet.py

Browse files
Files changed (1) hide show
  1. quoridor/pytorch/NNet.py +23 -3
quoridor/pytorch/NNet.py CHANGED
@@ -8,7 +8,7 @@ import math
8
  import sys
9
  sys.path.append('../../')
10
  from utils import *
11
- from pytorch_classification.utils import Bar, AverageMeter
12
  from NeuralNet import NeuralNet
13
 
14
  import argparse
@@ -24,7 +24,7 @@ from .QuoridorNNet import QuoridorNNet as qnnet
24
  args = dotdict({
25
  'lr': 0.00025,
26
  'dropout': 0.3,
27
- 'epochs': 8,
28
  'batch_size': 64,
29
  'cuda': torch.cuda.is_available(),
30
  'num_channels': 256,
@@ -69,7 +69,8 @@ class NNetWrapper(NeuralNet):
69
  while batch_idx < int(len(examples)/args.batch_size):
70
  sample_ids = np.random.randint(len(examples), size=args.batch_size)
71
  if withValids:
72
- boards, pis, vs, valids = list(zip(*[examples[i] for i in sample_ids]))
 
73
  else:
74
  boards, pis, vs = list(zip(*[examples[i] for i in sample_ids]))
75
  boards = torch.FloatTensor(np.array(boards).astype(np.uint8))
@@ -185,3 +186,22 @@ class NNetWrapper(NeuralNet):
185
  self.nnet = self.nnet.to('cpu')
186
  self.nnet.load_state_dict(checkpoint['state_dict'])
187
  self.nnet.cuda()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  import sys
9
  sys.path.append('../../')
10
  from utils import *
11
+ from progress.bar import Bar
12
  from NeuralNet import NeuralNet
13
 
14
  import argparse
 
24
  args = dotdict({
25
  'lr': 0.00025,
26
  'dropout': 0.3,
27
+ 'epochs': 4,
28
  'batch_size': 64,
29
  'cuda': torch.cuda.is_available(),
30
  'num_channels': 256,
 
69
  while batch_idx < int(len(examples)/args.batch_size):
70
  sample_ids = np.random.randint(len(examples), size=args.batch_size)
71
  if withValids:
72
+ res = list(zip(*[examples[i] for i in sample_ids]))
73
+ boards, pis, vs, valids = res[0], res[1], res[2], res[3]
74
  else:
75
  boards, pis, vs = list(zip(*[examples[i] for i in sample_ids]))
76
  boards = torch.FloatTensor(np.array(boards).astype(np.uint8))
 
186
  self.nnet = self.nnet.to('cpu')
187
  self.nnet.load_state_dict(checkpoint['state_dict'])
188
  self.nnet.cuda()
189
+
190
+ class AverageMeter(object):
191
+ """Computes and stores the average and current value
192
+ Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
193
+ """
194
+ def __init__(self):
195
+ self.reset()
196
+
197
+ def reset(self):
198
+ self.val = 0
199
+ self.avg = 0
200
+ self.sum = 0
201
+ self.count = 0
202
+
203
+ def update(self, val, n=1):
204
+ self.val = val
205
+ self.sum += val * n
206
+ self.count += n
207
+ self.avg = self.sum / self.count