From 7f0565a5f65b7784ba7145bcce237a09aff8f632 Mon Sep 17 00:00:00 2001 From: Dong Yan Date: Tue, 26 Dec 2017 22:19:10 +0800 Subject: [PATCH] variable rename and delete redundant code --- AlphaGo/game.py | 9 +++------ tianshou/core/mcts/mcts.py | 13 +++++++------ 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/AlphaGo/game.py b/AlphaGo/game.py index d123a92..f17c7af 100644 --- a/AlphaGo/game.py +++ b/AlphaGo/game.py @@ -46,7 +46,8 @@ class Game: else: raise ValueError(name + " is an unknown game...") - self.evaluator = model.ResNet(self.size, self.size ** 2 + 1, history_length=self.history_length, checkpoint_path=checkpoint_path) + self.evaluator = model.ResNet(self.size, self.size ** 2 + 1, history_length=self.history_length, + checkpoint_path=checkpoint_path) self.latest_boards = deque(maxlen=self.history_length) for _ in range(self.history_length): self.latest_boards.append(self.board) @@ -91,11 +92,7 @@ class Game: # this function can be called directly to play the opponent's move if vertex == utils.PASS: return True - # TODO this implementation is not very elegant - if self.name == "go": - res = self.game_engine.executor_do_move(self.history, self.latest_boards, self.board, color, vertex) - elif self.name == "reversi": - res = self.game_engine.executor_do_move(self.history, self.latest_boards, self.board, color, vertex) + res = self.game_engine.executor_do_move(self.history, self.latest_boards, self.board, color, vertex) return res def think_play_move(self, color): diff --git a/tianshou/core/mcts/mcts.py b/tianshou/core/mcts/mcts.py index 5c96d38..3d547c6 100644 --- a/tianshou/core/mcts/mcts.py +++ b/tianshou/core/mcts/mcts.py @@ -129,6 +129,7 @@ class ActionNode(object): self.mcts.action_selection_time += time.time() - head return self.parent, self.action else: + # self.next_state is None means we have reach the terminate state self.mcts.action_selection_time += time.time() - head return self.parent, self.action @@ -147,20 +148,20 @@ class ActionNode(object): class MCTS(object): - def __init__(self, simulator, evaluator, root, action_num, method="UCT", + def __init__(self, simulator, evaluator, start_state, action_num, method="UCT", role="unknown", debug=False, inverse=False): self.simulator = simulator self.evaluator = evaluator self.role = role self.debug = debug - prior, _ = self.evaluator(root) + prior, _ = self.evaluator(start_state) self.action_num = action_num if method == "": - self.root = root + self.root = start_state if method == "UCT": - self.root = UCTNode(None, None, root, action_num, prior, mcts=self, inverse=inverse) + self.root = UCTNode(None, None, start_state, action_num, prior, mcts=self, inverse=inverse) if method == "TS": - self.root = TSNode(None, None, root, action_num, prior, inverse=inverse) + self.root = TSNode(None, None, start_state, action_num, prior, inverse=inverse) self.inverse = inverse # time spend on each step @@ -191,7 +192,7 @@ class MCTS(object): self.expansion_time += exp_time self.backpropagation_time += back_time step += 1 - if (self.debug): + if self.debug: file = open("mcts_profiling.txt", "a") file.write("[" + str(self.role) + "]" + " sel " + '%.3f' % self.selection_time + " "