minor fixed for mcts, check finish for go

This commit is contained in:
rtz19970824 2017-12-23 15:58:06 +08:00
parent 430b78abf5
commit 3f238864fb
2 changed files with 16 additions and 9 deletions

View File

@ -212,11 +212,14 @@ class Go:
def simulate_step_forward(self, state, action):
# initialize the simulate_board from state
history_boards, color = state
vertex = self._action2vertex(action)
new_board = self._do_move(copy.copy(history_boards[-1]), color, vertex)
history_boards.append(new_board)
new_color = -color
return [history_boards, new_color], 0
if history_boards[-1] == history_boards[-2] and action is utils.PASS:
return None, 2 * (float(self.executor_get_score(history_boards[-1]) > 0)-0.5) * color
else:
vertex = self._action2vertex(action)
new_board = self._do_move(copy.copy(history_boards[-1]), color, vertex)
history_boards.append(new_board)
new_color = -color
return [history_boards, new_color], 0
def executor_do_move(self, history, latest_boards, current_board, color, vertex):
if not self._rule_check(history, current_board, color, vertex):

View File

@ -38,6 +38,7 @@ class MCTSNode(object):
def valid_mask(self, simulator):
pass
class UCTNode(MCTSNode):
def __init__(self, parent, action, state, action_num, prior, inverse=False):
super(UCTNode, self).__init__(parent, action, state, action_num, prior, inverse)
@ -71,10 +72,13 @@ class UCTNode(MCTSNode):
self.parent.backpropagation(self.children[action].reward)
def valid_mask(self, simulator):
# let all invalid actions be illeagel in mcts
if self.mask is None:
self.mask = simulator.simulate_get_mask(self.state, range(self.action_num))
self.ucb[self.mask] = -float("Inf")
# let all invalid actions be illegal in mcts
if not hasattr(simulator, 'simulate_get_mask'):
pass
else:
if self.mask is None:
self.mask = simulator.simulate_get_mask(self.state, range(self.action_num))
self.ucb[self.mask] = -float("Inf")
class TSNode(MCTSNode):