diff --git a/tianshou/core/mcts/mcts.py b/tianshou/core/mcts/mcts.py index 7edac97..8bb5f06 100644 --- a/tianshou/core/mcts/mcts.py +++ b/tianshou/core/mcts/mcts.py @@ -73,7 +73,7 @@ class UCTNode(MCTSNode): def valid_mask(self, simulator): # let all invalid actions be illeagel in mcts if self.mask is None: - self.mask = simulator.simulate_is_valid_list(self.state, range(self.action_num - 1)) + self.mask = simulator.simulate_is_valid_list(self.state, range(self.action_num)) self.ucb[self.mask] = -float("Inf")