minor fixed for mcts, check finish for go

This commit is contained in:
rtz19970824 2017-12-23 15:58:06 +08:00
parent 430b78abf5
commit 3f238864fb
2 changed files with 16 additions and 9 deletions

View File

@ -212,6 +212,9 @@ class Go:
def simulate_step_forward(self, state, action):
# initialize the simulate_board from state
history_boards, color = state
if history_boards[-1] == history_boards[-2] and action is utils.PASS:
return None, 2 * (float(self.executor_get_score(history_boards[-1]) > 0)-0.5) * color
else:
vertex = self._action2vertex(action)
new_board = self._do_move(copy.copy(history_boards[-1]), color, vertex)
history_boards.append(new_board)

View File

@ -38,6 +38,7 @@ class MCTSNode(object):
def valid_mask(self, simulator):
pass
class UCTNode(MCTSNode):
def __init__(self, parent, action, state, action_num, prior, inverse=False):
super(UCTNode, self).__init__(parent, action, state, action_num, prior, inverse)
@ -71,7 +72,10 @@ class UCTNode(MCTSNode):
self.parent.backpropagation(self.children[action].reward)
def valid_mask(self, simulator):
# let all invalid actions be illeagel in mcts
# let all invalid actions be illegal in mcts
if not hasattr(simulator, 'simulate_get_mask'):
pass
else:
if self.mask is None:
self.mask = simulator.simulate_get_mask(self.state, range(self.action_num))
self.ucb[self.mask] = -float("Inf")