modify game.py for multi-player

This commit is contained in:
rtz19970824 2018-01-09 20:09:48 +08:00
parent eb0ce95919
commit c2775df8e6
3 changed files with 22 additions and 11 deletions

View File

@ -12,6 +12,7 @@ import numpy as np
import sys, os import sys, os
import model import model
from collections import deque from collections import deque
sys.path.append(os.path.join(os.path.dirname(__file__), os.path.pardir)) sys.path.append(os.path.join(os.path.dirname(__file__), os.path.pardir))
from tianshou.core.mcts.mcts import MCTS from tianshou.core.mcts.mcts import MCTS
@ -19,6 +20,7 @@ import go
import reversi import reversi
import time import time
class Game: class Game:
''' '''
Load the real game and trained weights. Load the real game and trained weights.
@ -26,10 +28,11 @@ class Game:
TODO : Maybe merge with the engine class in future, TODO : Maybe merge with the engine class in future,
currently leave it untouched for interacting with Go UI. currently leave it untouched for interacting with Go UI.
''' '''
def __init__(self, name=None, role=None, debug=False, checkpoint_path=None):
def __init__(self, name=None, role=None, debug=False, black_checkpoint_path=None, white_checkpoint_path=None):
self.name = name self.name = name
if role is None: if role is None:
raise ValueError("Need a role!") raise ValueError("Need a role!")
self.role = role self.role = role
self.debug = debug self.debug = debug
if self.name == "go": if self.name == "go":
@ -49,8 +52,9 @@ class Game:
else: else:
raise ValueError(name + " is an unknown game...") raise ValueError(name + " is an unknown game...")
self.evaluator = model.ResNet(self.size, self.size ** 2 + 1, history_length=self.history_length, self.model = model.ResNet(self.size, self.size ** 2 + 1, history_length=self.history_length,
checkpoint_path=checkpoint_path) black_checkpoint_path=black_checkpoint_path,
white_checkpoint_path=white_checkpoint_path)
self.latest_boards = deque(maxlen=self.history_length) self.latest_boards = deque(maxlen=self.history_length)
for _ in range(self.history_length): for _ in range(self.history_length):
self.latest_boards.append(self.board) self.latest_boards.append(self.board)
@ -72,7 +76,12 @@ class Game:
self.komi = k self.komi = k
def think(self, latest_boards, color): def think(self, latest_boards, color):
mcts = MCTS(self.game_engine, self.evaluator, [latest_boards, color], if color == +1:
role = 'black'
if color == -1:
role = 'white'
evaluator = lambda state:self.model(role, state)
mcts = MCTS(self.game_engine, evaluator, [latest_boards, color],
self.size ** 2 + 1, role=self.role, debug=self.debug, inverse=True) self.size ** 2 + 1, role=self.role, debug=self.debug, inverse=True)
mcts.search(max_step=100) mcts.search(max_step=100)
if self.debug: if self.debug:
@ -98,7 +107,8 @@ class Game:
if self.name == "reversi": if self.name == "reversi":
res = self.game_engine.executor_do_move(self.history, self.latest_boards, self.board, color, vertex) res = self.game_engine.executor_do_move(self.history, self.latest_boards, self.board, color, vertex)
if self.name == "go": if self.name == "go":
res = self.game_engine.executor_do_move(self.history, self.history_set, self.latest_boards, self.board, color, vertex) res = self.game_engine.executor_do_move(self.history, self.history_set, self.latest_boards, self.board,
color, vertex)
return res return res
def think_play_move(self, color): def think_play_move(self, color):
@ -128,8 +138,8 @@ class Game:
print('') print('')
sys.stdout.flush() sys.stdout.flush()
if __name__ == "__main__": if __name__ == "__main__":
game = Game(name="reversi", role="black", checkpoint_path=None) game = Game(name="reversi", role="black", checkpoint_path=None)
game.debug = True game.debug = True
game.think_play_move(utils.BLACK) game.think_play_move(utils.BLACK)

View File

@ -168,7 +168,7 @@ class ResNet(object):
self.__setattr__(scope + '_saver', self.__setattr__(scope + '_saver',
tf.train.Saver(max_to_keep=0, var_list=self.__getattribute__(scope + '_var_list'))) tf.train.Saver(max_to_keep=0, var_list=self.__getattribute__(scope + '_var_list')))
def __call__(self, state): def __call__(self, role, state):
""" """
:param history: a list, the history :param history: a list, the history
@ -184,10 +184,10 @@ class ResNet(object):
'The length of history cannot meet the need of the model, given {}, need {}'.format(len(history), 'The length of history cannot meet the need of the model, given {}, need {}'.format(len(history),
self.history_length)) self.history_length))
eval_state = self._history2state(history, color) eval_state = self._history2state(history, color)
if color == +1: if role == 'black':
return self.sess.run([self.black_prob, self.black_v], return self.sess.run([self.black_prob, self.black_v],
feed_dict={self.x: eval_state, self.is_training: False}) feed_dict={self.x: eval_state, self.is_training: False})
if color == -1: if role == 'white':
return self.sess.run([self.white_prob, self.white_v], return self.sess.run([self.white_prob, self.white_v],
feed_dict={self.x: eval_state, self.is_training: False}) feed_dict={self.x: eval_state, self.is_training: False})

View File

@ -24,6 +24,7 @@ class Data(object):
def reset(self): def reset(self):
self.__init__() self.__init__()
if __name__ == '__main__': if __name__ == '__main__':
""" """
Starting two different players which load network weights to evaluate the winning ratio. Starting two different players which load network weights to evaluate the winning ratio.