From aa6b5434c673c8d7c83c290bacd4a92b1ac0832b Mon Sep 17 00:00:00 2001
From: Dong Yan
Date: Tue, 26 Dec 2017 14:46:14 +0800
Subject: [PATCH] add debug info for mcts and add softmax for the prior

---
 AlphaGo/game.py  | 15 +++++++++++----
 AlphaGo/model.py |  4 ++--
 2 files changed, 13 insertions(+), 6 deletions(-)

diff --git a/AlphaGo/game.py b/AlphaGo/game.py
index 72ae2e0..ec39f94 100644
--- a/AlphaGo/game.py
+++ b/AlphaGo/game.py
@@ -71,6 +71,13 @@ class Game:
         mcts = MCTS(self.game_engine, self.evaluator, [latest_boards, color], self.size ** 2 + 1,
                     role=self.role, debug=self.debug, inverse=True)
         mcts.search(max_step=100)
+        if self.debug:
+            file = open("mcts_debug.log", 'ab')
+            np.savetxt(file, mcts.root.Q, header="\nQ value : ", fmt='%.4f', newline=", ")
+            np.savetxt(file, mcts.root.W, header="\nW value : ", fmt='%.4f', newline=", ")
+            np.savetxt(file, mcts.root.N, header="\nN value : ", fmt="%d", newline=", ")
+            np.savetxt(file, mcts.root.prior, header="\nprior : ", fmt='%.4f', newline=", ")
+            file.close()
         temp = 1
         prob = mcts.root.N ** temp / np.sum(mcts.root.N ** temp)
         choice = np.random.choice(self.size ** 2 + 1, 1, p=prob).tolist()[0]
@@ -119,7 +126,7 @@ class Game:
         sys.stdout.flush()
 
 if __name__ == "__main__":
-    print("test game.py")
-    #file = open("debug.txt", "a")
-    #file.write("mcts check\n")
-    #file.close()
+    game = Game(name="go", checkpoint_path="./checkpoint")
+    game.debug = True
+    game.think_play_move(utils.BLACK)
+
diff --git a/AlphaGo/model.py b/AlphaGo/model.py
index 0549f41..704a034 100644
--- a/AlphaGo/model.py
+++ b/AlphaGo/model.py
@@ -80,7 +80,7 @@ class Data(object):
 
 
 class ResNet(object):
-    def __init__(self, board_size, action_num, history_length=1, residual_block_num=20, checkpoint_path=None):
+    def __init__(self, board_size, action_num, history_length=1, residual_block_num=10, checkpoint_path=None):
         """
         the resnet model
 
@@ -161,7 +161,7 @@
                 'The length of history cannot meet the need of the model, given {}, need {}'.format(len(history), self.history_length))
         state = self._history2state(history, color)
-        return self.sess.run([self.p, self.v], feed_dict={self.x: state, self.is_training: False})
+        return self.sess.run([tf.nn.softmax(self.p), self.v], feed_dict={self.x: state, self.is_training: False})
 
     def _history2state(self, history, color):
         """
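
Note on the game.py hunk: each np.savetxt call appends one labeled, comma-separated row of node statistics to mcts_debug.log. The header string is written first, prefixed by numpy's default '# ' comment marker, and newline=", " joins the values onto a single row. A standalone sketch of the same call pattern, using a made-up visit-count array rather than the real MCTS statistics:

    import numpy as np

    # Made-up visit counts standing in for mcts.root.N.
    N = np.array([3, 0, 12, 5])
    with open("mcts_debug.log", "ab") as f:
        # The header is emitted once per call with numpy's '# ' comment prefix;
        # newline=", " strings the values out on one comma-separated row.
        np.savetxt(f, N, header="\nN value : ", fmt="%d", newline=", ")

Exact spacing of the output depends on the numpy version, but each call yields one "# N value :"-style label followed by the values, so successive moves can be compared by scanning the log.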
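
For reference, the context lines below the new debug block convert visit counts to move probabilities with temp = 1, where the patch's N ** temp and the AlphaGo Zero form N ** (1 / tau) coincide. A small self-contained version of that step, with the tau parameterization made explicit:

    import numpy as np

    def visit_probs(N, tau=1.0):
        # AlphaGo Zero samples moves with pi(a) proportional to N(a) ** (1 / tau);
        # at tau == 1 this is identical to N ** temp / np.sum(N ** temp) above.
        pi = N ** (1.0 / tau)
        return pi / pi.sum()

    # Example: visit_probs(np.array([3., 0., 12., 5.])) -> [0.15, 0., 0.6, 0.25]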
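
One caveat about the model.py change: calling tf.nn.softmax(self.p) inside sess.run adds a new node to the TensorFlow graph on every evaluator call. A minimal sketch, assuming a hypothetical Net class (not the repository's ResNet), of building the softmax op once at construction time and reusing it:

    import numpy as np
    import tensorflow as tf  # TensorFlow 1.x graph API, as used by model.py

    class Net(object):
        def __init__(self):
            self.x = tf.placeholder(tf.float32, [None, 4])
            self.p = tf.layers.dense(self.x, 3)   # policy logits
            self.prob = tf.nn.softmax(self.p)     # built once, not per call
            self.sess = tf.Session()
            self.sess.run(tf.global_variables_initializer())

        def forward(self, state):
            # Running the cached op keeps the graph size constant across calls.
            return self.sess.run(self.prob, feed_dict={self.x: state})

    # net = Net(); net.forward(np.zeros((1, 4), dtype=np.float32))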