From 3f238864fbfe20843900de12513aec75b8a59943 Mon Sep 17 00:00:00 2001 From: rtz19970824 <1289226405@qq.com> Date: Sat, 23 Dec 2017 15:58:06 +0800 Subject: [PATCH] minor fixed for mcts, check finish for go --- AlphaGo/go.py | 13 ++++++++----- tianshou/core/mcts/mcts.py | 12 ++++++++---- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/AlphaGo/go.py b/AlphaGo/go.py index b819c08..fe2ab74 100644 --- a/AlphaGo/go.py +++ b/AlphaGo/go.py @@ -212,11 +212,14 @@ class Go: def simulate_step_forward(self, state, action): # initialize the simulate_board from state history_boards, color = state - vertex = self._action2vertex(action) - new_board = self._do_move(copy.copy(history_boards[-1]), color, vertex) - history_boards.append(new_board) - new_color = -color - return [history_boards, new_color], 0 + if history_boards[-1] == history_boards[-2] and action is utils.PASS: + return None, 2 * (float(self.executor_get_score(history_boards[-1]) > 0)-0.5) * color + else: + vertex = self._action2vertex(action) + new_board = self._do_move(copy.copy(history_boards[-1]), color, vertex) + history_boards.append(new_board) + new_color = -color + return [history_boards, new_color], 0 def executor_do_move(self, history, latest_boards, current_board, color, vertex): if not self._rule_check(history, current_board, color, vertex): diff --git a/tianshou/core/mcts/mcts.py b/tianshou/core/mcts/mcts.py index e8f3709..e99373c 100644 --- a/tianshou/core/mcts/mcts.py +++ b/tianshou/core/mcts/mcts.py @@ -38,6 +38,7 @@ class MCTSNode(object): def valid_mask(self, simulator): pass + class UCTNode(MCTSNode): def __init__(self, parent, action, state, action_num, prior, inverse=False): super(UCTNode, self).__init__(parent, action, state, action_num, prior, inverse) @@ -71,10 +72,13 @@ class UCTNode(MCTSNode): self.parent.backpropagation(self.children[action].reward) def valid_mask(self, simulator): - # let all invalid actions be illeagel in mcts - if self.mask is None: - self.mask = simulator.simulate_get_mask(self.state, range(self.action_num)) - self.ucb[self.mask] = -float("Inf") + # let all invalid actions be illegal in mcts + if not hasattr(simulator, 'simulate_get_mask'): + pass + else: + if self.mask is None: + self.mask = simulator.simulate_get_mask(self.state, range(self.action_num)) + self.ucb[self.mask] = -float("Inf") class TSNode(MCTSNode):