From 40909fa994c7fccc8391123a8463807968219d26 Mon Sep 17 00:00:00 2001 From: Wenbo Hu Date: Wed, 20 Dec 2017 22:10:47 +0800 Subject: [PATCH] forbid pass, if we have other choices --- AlphaGo/go.py | 18 +++++++++--------- tianshou/core/mcts/mcts.py | 2 +- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/AlphaGo/go.py b/AlphaGo/go.py index cbbe07c..1dfbb29 100644 --- a/AlphaGo/go.py +++ b/AlphaGo/go.py @@ -183,16 +183,16 @@ class Go: return True def simulate_is_valid_list(self, state, action_set): - ## find all the valid actions - ## if no action is valid, then pass - valid_action_list = [] - for action_candidate in action_set: + # find all the invalid actions + invalid_action_list = [] + for action_candidate in action_set[:-1]: + # go through all the actions excluding pass if not self.simulate_is_valid(state, action_candidate): - valid_action_list.append(action_candidate) - if not valid_action_list: - valid_action_list.append(utils.PASS) - # if valid_action_set is a empty set, add pass - return valid_action_list + invalid_action_list.append(action_candidate) + if len(invalid_action_list) < len(action_set) - 1: + invalid_action_list.append(action_set[-1]) + # forbid pass, if we have other choices + return invalid_action_list def _do_move(self, board, color, vertex): if vertex == utils.PASS: diff --git a/tianshou/core/mcts/mcts.py b/tianshou/core/mcts/mcts.py index 5aca06a..7edac97 100644 --- a/tianshou/core/mcts/mcts.py +++ b/tianshou/core/mcts/mcts.py @@ -71,7 +71,7 @@ class UCTNode(MCTSNode): self.parent.backpropagation(self.children[action].reward) def valid_mask(self, simulator): - # let all invalid actions illeagel in mcts + # let all invalid actions be illegal in mcts if self.mask is None: self.mask = simulator.simulate_is_valid_list(self.state, range(self.action_num - 1)) self.ucb[self.mask] = -float("Inf")