File size: 5,674 Bytes
fe56a4f b914774 fe56a4f b914774 fe56a4f b914774 fe56a4f b914774 fe56a4f b914774 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
import numpy as np
from progress.bar import Bar
from quoridor.pytorch.NNet import AverageMeter
import time
#import os
class Arena():
    """
    An Arena class where any 2 agents can be pit against each other.
    """

    # Upper bound on the number of turns played in one episode. A game that
    # is still undecided after this many moves is stopped and scored as-is.
    # Override on a subclass or an instance to change the cap.
    MAX_TURNS = 200

    def __init__(self, player1, player2, game, display=None):
        """
        Input:
            player 1,2: two functions that takes board as input, return action
            game: Game object
            display: a function called as display(canonicalBoard, player)
                     that prints the board. Is necessary for verbose mode.

        See pit.py for pitting human players/other baselines with each other.
        """
        self.player1 = player1
        self.player2 = player2
        self.game = game
        self.display = display

    @staticmethod
    def _isLambda(player):
        """Return True when `player` is an anonymous (lambda) callable."""
        # getattr with a default keeps callables without __name__
        # (functools.partial, callable instances) from raising in verbose mode.
        # NOTE(review): presumably lambdas are the automated agents and named
        # callables are interactive players -- confirm against the callers.
        return getattr(player, '__name__', None) == '<lambda>'

    def playGame(self, verbose=False):
        """
        Executes one episode of a game.

        Input:
            verbose: when true, print every turn and show the board through
                     self.display (which must then be set).

        Returns:
            either
                -curPlayer as soon as the player to move picks an action that
                game.getValidMoves marks with 0 (the move is reported and the
                episode stops immediately)
            or
                game.getGameEnded(board, 1) for the last board reached:
                1 if player1 won, -1 if player2 won, the game's draw value,
                or 0 when the MAX_TURNS cap stopped an undecided game.
        """
        players = [self.player2, None, self.player1]
        curPlayer = 1
        board = self.game.getInitBoard()
        it = 0
        if verbose:
            # Checked once up front instead of on every turn.
            assert self.display

        while self.game.getGameEnded(board, curPlayer) == 0 and it < self.MAX_TURNS:
            it += 1
            mover = players[curPlayer + 1]
            if verbose:
                print("Turn ", str(it), "Player ", str(curPlayer))
                if not self._isLambda(mover):
                    # Show the position to a named player before it moves.
                    self.display(self.game.getCanonicalForm(board, curPlayer), curPlayer)

            # Each consumer gets its own canonical board so that a player
            # mutating its argument cannot influence the validity check.
            action = mover(self.game.getCanonicalForm(board, curPlayer))
            valids = self.game.getValidMoves(self.game.getCanonicalForm(board, curPlayer), 1)
            if valids[action] == 0:
                print("invalid action", action)
                return -curPlayer

            board, curPlayer = self.game.getNextState(board, curPlayer, action)
            if verbose and self._isLambda(players[curPlayer + 1]):
                # A lambda player never gets the pre-move display above, so
                # show the board from the side that has just moved.
                self.display(self.game.getCanonicalForm(board, -curPlayer), -curPlayer)

        result = self.game.getGameEnded(board, 1)
        if verbose:
            print("Game over: Turn ", str(it), "Result ", str(result))
            viewer = -curPlayer if self._isLambda(players[curPlayer + 1]) else curPlayer
            self.display(self.game.getCanonicalForm(board, viewer), viewer)
        return result

    def playGames(self, num, verbose=False):
        """
        Plays num games in which player1 starts num/2 games and player2 starts
        num/2 games. Each half is int(num/2) games, so an odd num plays one
        game fewer than requested.

        The two players are swapped internally for the second half and put
        back before returning, so the Arena can be reused.

        Returns:
            oneWon: games won by player1
            twoWon: games won by player2
            draws:  games won by nobody
        """
        half = int(num / 2)
        total = 2 * half  # games actually played; drives the progress display
        suffix = '({eps}/{maxeps}) Eps Time: {et:.3f}s | Total: {total:} | ETA: {eta:}'

        eps_time = AverageMeter()
        bar = Bar('Arena.playGames', max=total)
        end = time.time()
        eps = 0

        oneWon = 0
        twoWon = 0
        draws = 0

        first, second = self.player1, self.player2
        try:
            for swapped in (False, True):
                if swapped:
                    self.player1, self.player2 = second, first
                # After the swap the original player1 moves second, so a
                # result of -1 is a win for it.
                oneWins = -1 if swapped else 1

                for _ in range(half):
                    gameResult = self.playGame(verbose=verbose)
                    if gameResult == oneWins:
                        oneWon += 1
                    elif gameResult == -oneWins:
                        twoWon += 1
                    else:
                        draws += 1

                    # bookkeeping + plot progress
                    eps += 1
                    eps_time.update(time.time() - end)
                    end = time.time()
                    bar.suffix = suffix.format(eps=eps, maxeps=total, et=eps_time.avg,
                                               total=bar.elapsed_td, eta=bar.eta_td)
                    bar.next()
        finally:
            # Leave the Arena as the caller built it, even if a game raised.
            self.player1, self.player2 = first, second

        bar.finish()
        return oneWon, twoWon, draws
class AverageMeter(object):
    """Computes and stores the average and current value
    Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear every statistic back to zero."""
        self.val = 0    # most recent value passed to update()
        self.avg = 0    # running weighted mean: sum / count
        self.sum = 0    # weighted total of all values seen
        self.count = 0  # total weight (number of samples) seen

    def update(self, val, n=1):
        """Record `val` as the latest value, weighted as `n` samples."""
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero total weight (e.g. the first update uses n=0) would divide
        # by zero; report an average of 0 until a real sample arrives.
        self.avg = self.sum / self.count if self.count else 0