Modify game.py for multi-player support

This commit is contained in:
rtz19970824 2018-01-09 20:09:48 +08:00
parent eb0ce95919
commit c2775df8e6
3 changed files with 22 additions and 11 deletions

View File

@ -12,6 +12,7 @@ import numpy as np
import sys, os
import model
from collections import deque
sys.path.append(os.path.join(os.path.dirname(__file__), os.path.pardir))
from tianshou.core.mcts.mcts import MCTS
@ -19,6 +20,7 @@ import go
import reversi
import time
class Game:
'''
Load the real game and trained weights.
@ -26,7 +28,8 @@ class Game:
TODO: Maybe merge with the engine class in the future;
currently left untouched so that it can interact with the Go UI.
'''
def __init__(self, name=None, role=None, debug=False, checkpoint_path=None):
def __init__(self, name=None, role=None, debug=False, black_checkpoint_path=None, white_checkpoint_path=None):
self.name = name
if role is None:
raise ValueError("Need a role!")
@ -49,8 +52,9 @@ class Game:
else:
raise ValueError(name + " is an unknown game...")
self.evaluator = model.ResNet(self.size, self.size ** 2 + 1, history_length=self.history_length,
checkpoint_path=checkpoint_path)
self.model = model.ResNet(self.size, self.size ** 2 + 1, history_length=self.history_length,
black_checkpoint_path=black_checkpoint_path,
white_checkpoint_path=white_checkpoint_path)
self.latest_boards = deque(maxlen=self.history_length)
for _ in range(self.history_length):
self.latest_boards.append(self.board)
@ -72,7 +76,12 @@ class Game:
self.komi = k
def think(self, latest_boards, color):
mcts = MCTS(self.game_engine, self.evaluator, [latest_boards, color],
if color == +1:
role = 'black'
if color == -1:
role = 'white'
evaluator = lambda state:self.model(role, state)
mcts = MCTS(self.game_engine, evaluator, [latest_boards, color],
self.size ** 2 + 1, role=self.role, debug=self.debug, inverse=True)
mcts.search(max_step=100)
if self.debug:
@ -98,7 +107,8 @@ class Game:
if self.name == "reversi":
res = self.game_engine.executor_do_move(self.history, self.latest_boards, self.board, color, vertex)
if self.name == "go":
res = self.game_engine.executor_do_move(self.history, self.history_set, self.latest_boards, self.board, color, vertex)
res = self.game_engine.executor_do_move(self.history, self.history_set, self.latest_boards, self.board,
color, vertex)
return res
def think_play_move(self, color):
@ -128,8 +138,8 @@ class Game:
print('')
sys.stdout.flush()
if __name__ == "__main__":
game = Game(name="reversi", role="black", checkpoint_path=None)
game.debug = True
game.think_play_move(utils.BLACK)

View File

@ -168,7 +168,7 @@ class ResNet(object):
self.__setattr__(scope + '_saver',
tf.train.Saver(max_to_keep=0, var_list=self.__getattribute__(scope + '_var_list')))
def __call__(self, state):
def __call__(self, role, state):
"""
:param history: a list, the history
@ -184,10 +184,10 @@ class ResNet(object):
'The length of history cannot meet the need of the model, given {}, need {}'.format(len(history),
self.history_length))
eval_state = self._history2state(history, color)
if color == +1:
if role == 'black':
return self.sess.run([self.black_prob, self.black_v],
feed_dict={self.x: eval_state, self.is_training: False})
if color == -1:
if role == 'white':
return self.sess.run([self.white_prob, self.white_v],
feed_dict={self.x: eval_state, self.is_training: False})

View File

@ -24,6 +24,7 @@ class Data(object):
def reset(self):
self.__init__()
if __name__ == '__main__':
"""
Start two different players that load network weights in order to evaluate the winning ratio.