forbid pass, if we have other choices

This commit is contained in:
Wenbo Hu 2017-12-20 22:10:47 +08:00
parent 0ab38743aa
commit 40909fa994
2 changed files with 10 additions and 10 deletions

View File

@ -183,16 +183,16 @@ class Go:
return True return True
def simulate_is_valid_list(self, state, action_set): def simulate_is_valid_list(self, state, action_set):
## find all the valid actions # find all the invalid actions
## if no action is valid, then pass invalid_action_list = []
valid_action_list = [] for action_candidate in action_set[:-1]:
for action_candidate in action_set: # go through all the actions excluding pass
if not self.simulate_is_valid(state, action_candidate): if not self.simulate_is_valid(state, action_candidate):
valid_action_list.append(action_candidate) invalid_action_list.append(action_candidate)
if not valid_action_list: if len(invalid_action_list) < len(action_set) - 1:
valid_action_list.append(utils.PASS) invalid_action_list.append(action_set[-1])
# if valid_action_set is a empty set, add pass # forbid pass, if we have other choices
return valid_action_list return invalid_action_list
def _do_move(self, board, color, vertex): def _do_move(self, board, color, vertex):
if vertex == utils.PASS: if vertex == utils.PASS:

View File

@ -71,7 +71,7 @@ class UCTNode(MCTSNode):
self.parent.backpropagation(self.children[action].reward) self.parent.backpropagation(self.children[action].reward)
def valid_mask(self, simulator): def valid_mask(self, simulator):
# let all invalid actions illeagel in mcts # let all invalid actions be illeagel in mcts
if self.mask is None: if self.mask is None:
self.mask = simulator.simulate_is_valid_list(self.state, range(self.action_num - 1)) self.mask = simulator.simulate_is_valid_list(self.state, range(self.action_num - 1))
self.ucb[self.mask] = -float("Inf") self.ucb[self.mask] = -float("Inf")