From c2775df8e676ad4a5e1fea76bc781324b5625d54 Mon Sep 17 00:00:00 2001
From: rtz19970824
Date: Tue, 9 Jan 2018 20:09:48 +0800
Subject: [PATCH] modify game.py for multi-player

---
 AlphaGo/game.py  | 28 +++++++++++++++++++---------
 AlphaGo/model.py |  6 +++---
 AlphaGo/play.py  |  1 +
 3 files changed, 23 insertions(+), 12 deletions(-)

diff --git a/AlphaGo/game.py b/AlphaGo/game.py
index c105522..abb6331 100644
--- a/AlphaGo/game.py
+++ b/AlphaGo/game.py
@@ -12,6 +12,7 @@ import numpy as np
 import sys, os
 import model
 from collections import deque
+
 sys.path.append(os.path.join(os.path.dirname(__file__), os.path.pardir))
 from tianshou.core.mcts.mcts import MCTS
 
@@ -19,6 +20,7 @@ import go
 import reversi
 import time
 
+
 class Game:
     '''
     Load the real game and trained weights.
@@ -26,10 +28,11 @@ class Game:
     TODO : Maybe merge with the engine class in future,
     currently leave it untouched for interacting with Go UI.
     '''
-    def __init__(self, name=None, role=None, debug=False, checkpoint_path=None):
+
+    def __init__(self, name=None, role=None, debug=False, black_checkpoint_path=None, white_checkpoint_path=None):
         self.name = name
-        if role is None:
-            raise ValueError("Need a role!")
+        if role is None:
+            raise ValueError("Need a role!")
         self.role = role
         self.debug = debug
         if self.name == "go":
@@ -49,8 +52,9 @@ class Game:
         else:
             raise ValueError(name + " is an unknown game...")
 
-        self.evaluator = model.ResNet(self.size, self.size ** 2 + 1, history_length=self.history_length,
-                                      checkpoint_path=checkpoint_path)
+        self.model = model.ResNet(self.size, self.size ** 2 + 1, history_length=self.history_length,
+                                   black_checkpoint_path=black_checkpoint_path,
+                                   white_checkpoint_path=white_checkpoint_path)
         self.latest_boards = deque(maxlen=self.history_length)
         for _ in range(self.history_length):
             self.latest_boards.append(self.board)
@@ -72,7 +76,12 @@ class Game:
         self.komi = k
 
     def think(self, latest_boards, color):
-        mcts = MCTS(self.game_engine, self.evaluator, [latest_boards, color],
+        if color == +1:
+            role = 'black'
+        if color == -1:
+            role = 'white'
+        evaluator = lambda state: self.model(role, state)
+        mcts = MCTS(self.game_engine, evaluator, [latest_boards, color],
                     self.size ** 2 + 1, role=self.role, debug=self.debug, inverse=True)
         mcts.search(max_step=100)
         if self.debug:
@@ -98,7 +107,8 @@ class Game:
         if self.name == "reversi":
             res = self.game_engine.executor_do_move(self.history, self.latest_boards, self.board, color, vertex)
         if self.name == "go":
-            res = self.game_engine.executor_do_move(self.history, self.history_set, self.latest_boards, self.board, color, vertex)
+            res = self.game_engine.executor_do_move(self.history, self.history_set, self.latest_boards, self.board,
+                                                    color, vertex)
         return res
 
     def think_play_move(self, color):
@@ -128,8 +138,8 @@ class Game:
             print('')
         sys.stdout.flush()
 
+
 if __name__ == "__main__":
-    game = Game(name="reversi", role="black", checkpoint_path=None)
+    game = Game(name="reversi", role="black", black_checkpoint_path=None, white_checkpoint_path=None)
     game.debug = True
     game.think_play_move(utils.BLACK)
-
diff --git a/AlphaGo/model.py b/AlphaGo/model.py
index 88fd199..7741eb6 100644
--- a/AlphaGo/model.py
+++ b/AlphaGo/model.py
@@ -168,7 +168,7 @@ class ResNet(object):
         self.__setattr__(scope + '_saver',
                          tf.train.Saver(max_to_keep=0, var_list=self.__getattribute__(scope + '_var_list')))
 
-    def __call__(self, state):
+    def __call__(self, role, state):
         """
 
         :param history: a list, the history
@@ -184,10 +184,10 @@ class ResNet(object):
             'The length of history cannot meet the need of the model, given {}, need {}'.format(len(history),
                                                                                                  self.history_length))
         eval_state = self._history2state(history, color)
-        if color == +1:
+        if role == 'black':
             return self.sess.run([self.black_prob, self.black_v],
                                  feed_dict={self.x: eval_state, self.is_training: False})
-        if color == -1:
+        if role == 'white':
             return self.sess.run([self.white_prob, self.white_v],
                                  feed_dict={self.x: eval_state, self.is_training: False})
 
diff --git a/AlphaGo/play.py b/AlphaGo/play.py
index 5aaa6a2..e419f5b 100644
--- a/AlphaGo/play.py
+++ b/AlphaGo/play.py
@@ -24,6 +24,7 @@ class Data(object):
     def reset(self):
         self.__init__()
 
+
 if __name__ == '__main__':
     """
     Starting two different players which load network weights to evaluate the winning ratio.
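
The net effect of the patch: one `Game` instance now carries two sets of weights, one per colour, and `ResNet.__call__` dispatches on a `role` string instead of a colour integer, with `Game.think` wrapping the two-headed network in a single-argument `evaluator` for MCTS. A minimal usage sketch of the patched interface follows; the checkpoint paths are illustrative placeholders (not files shipped with the repository), and `utils.WHITE` is assumed to exist alongside the `utils.BLACK` constant already used in game.py's `__main__`:

    from game import Game
    import utils

    # Hypothetical checkpoint locations; substitute real training output.
    game = Game(name="reversi", role="black",
                black_checkpoint_path="./checkpoints/black",
                white_checkpoint_path="./checkpoints/white")

    game.debug = True
    game.think_play_move(utils.BLACK)  # evaluated by the 'black' network head
    game.think_play_move(utils.WHITE)  # evaluated by the 'white' network head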