From 48e95a21eaeec6495a1bc5985c434d64d7447baf Mon Sep 17 00:00:00 2001 From: Wenbo Hu Date: Wed, 20 Dec 2017 21:35:35 +0800 Subject: [PATCH] simulator process a valid set, instead of a single action --- AlphaGo/go.py | 18 +++++++++++++++--- tianshou/core/mcts/mcts.py | 9 ++------- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/AlphaGo/go.py b/AlphaGo/go.py index 7196533..559b375 100644 --- a/AlphaGo/go.py +++ b/AlphaGo/go.py @@ -121,9 +121,9 @@ class Go: if self._is_eye(current_board, color, vertex): return False # forbid position on its own eye. - if self._is_game_finish(current_board, color) and vertex == utils.PASS - return False - # forbid pass if the game is not finished. + #if self._is_game_finish(current_board, color) and vertex == utils.PASS + # return False + # forbid pass if the game is not finished. return True @@ -183,6 +183,18 @@ class Go: return True + def simulate_is_valid_list(self, state, action_set): + ## find all the valid actions + ## if no action is valid, then pass + valid_action_set = [] + for action_candidate in action_set: + if self.simulate_is_valid(self, state, action_candidate) + valid_action_set.append(action_candidate) + if not valid_action_set: + valid_action_set.append(utils.PASS) + # if valid_action_set is a empty set, add pass + return valid_action_set + def _do_move(self, board, color, vertex): if vertex == utils.PASS: return board diff --git a/tianshou/core/mcts/mcts.py b/tianshou/core/mcts/mcts.py index fac00fb..c14496d 100644 --- a/tianshou/core/mcts/mcts.py +++ b/tianshou/core/mcts/mcts.py @@ -72,13 +72,8 @@ class UCTNode(MCTSNode): def valid_mask(self, simulator): if self.mask is None: - self.mask = [] - for act in range(self.action_num - 1): - if not simulator.simulate_is_valid(self.state, act): - self.mask.append(act) - self.ucb[act] = -float("Inf") - else: - self.ucb[self.mask] = -float("Inf") + self.mask = simulator.simulate_is_valid_list(self.state, range(self.action_num - 1)) + self.ucb[self.mask] = -float("Inf") class TSNode(MCTSNode):