minor fixed for mcts, check finish for go
This commit is contained in:
parent
430b78abf5
commit
3f238864fb
@ -212,11 +212,14 @@ class Go:
|
||||
def simulate_step_forward(self, state, action):
|
||||
# initialize the simulate_board from state
|
||||
history_boards, color = state
|
||||
vertex = self._action2vertex(action)
|
||||
new_board = self._do_move(copy.copy(history_boards[-1]), color, vertex)
|
||||
history_boards.append(new_board)
|
||||
new_color = -color
|
||||
return [history_boards, new_color], 0
|
||||
if history_boards[-1] == history_boards[-2] and action is utils.PASS:
|
||||
return None, 2 * (float(self.executor_get_score(history_boards[-1]) > 0)-0.5) * color
|
||||
else:
|
||||
vertex = self._action2vertex(action)
|
||||
new_board = self._do_move(copy.copy(history_boards[-1]), color, vertex)
|
||||
history_boards.append(new_board)
|
||||
new_color = -color
|
||||
return [history_boards, new_color], 0
|
||||
|
||||
def executor_do_move(self, history, latest_boards, current_board, color, vertex):
|
||||
if not self._rule_check(history, current_board, color, vertex):
|
||||
|
@ -38,6 +38,7 @@ class MCTSNode(object):
|
||||
def valid_mask(self, simulator):
|
||||
pass
|
||||
|
||||
|
||||
class UCTNode(MCTSNode):
|
||||
def __init__(self, parent, action, state, action_num, prior, inverse=False):
|
||||
super(UCTNode, self).__init__(parent, action, state, action_num, prior, inverse)
|
||||
@ -71,10 +72,13 @@ class UCTNode(MCTSNode):
|
||||
self.parent.backpropagation(self.children[action].reward)
|
||||
|
||||
def valid_mask(self, simulator):
|
||||
# let all invalid actions be illeagel in mcts
|
||||
if self.mask is None:
|
||||
self.mask = simulator.simulate_get_mask(self.state, range(self.action_num))
|
||||
self.ucb[self.mask] = -float("Inf")
|
||||
# let all invalid actions be illegal in mcts
|
||||
if not hasattr(simulator, 'simulate_get_mask'):
|
||||
pass
|
||||
else:
|
||||
if self.mask is None:
|
||||
self.mask = simulator.simulate_get_mask(self.state, range(self.action_num))
|
||||
self.ucb[self.mask] = -float("Inf")
|
||||
|
||||
|
||||
class TSNode(MCTSNode):
|
||||
|
Loading…
x
Reference in New Issue
Block a user