minor fixed for mcts, check finish for go
This commit is contained in:
parent
d01f8cd210
commit
c50ee8f029
@ -212,6 +212,9 @@ class Go:
|
|||||||
def simulate_step_forward(self, state, action):
|
def simulate_step_forward(self, state, action):
|
||||||
# initialize the simulate_board from state
|
# initialize the simulate_board from state
|
||||||
history_boards, color = state
|
history_boards, color = state
|
||||||
|
if history_boards[-1] == history_boards[-2] and action is utils.PASS:
|
||||||
|
return None, 2 * (float(self.executor_get_score(history_boards[-1]) > 0)-0.5) * color
|
||||||
|
else:
|
||||||
vertex = self._action2vertex(action)
|
vertex = self._action2vertex(action)
|
||||||
new_board = self._do_move(copy.copy(history_boards[-1]), color, vertex)
|
new_board = self._do_move(copy.copy(history_boards[-1]), color, vertex)
|
||||||
history_boards.append(new_board)
|
history_boards.append(new_board)
|
||||||
|
@ -38,6 +38,7 @@ class MCTSNode(object):
|
|||||||
def valid_mask(self, simulator):
|
def valid_mask(self, simulator):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class UCTNode(MCTSNode):
|
class UCTNode(MCTSNode):
|
||||||
def __init__(self, parent, action, state, action_num, prior, inverse=False):
|
def __init__(self, parent, action, state, action_num, prior, inverse=False):
|
||||||
super(UCTNode, self).__init__(parent, action, state, action_num, prior, inverse)
|
super(UCTNode, self).__init__(parent, action, state, action_num, prior, inverse)
|
||||||
@ -71,7 +72,10 @@ class UCTNode(MCTSNode):
|
|||||||
self.parent.backpropagation(self.children[action].reward)
|
self.parent.backpropagation(self.children[action].reward)
|
||||||
|
|
||||||
def valid_mask(self, simulator):
|
def valid_mask(self, simulator):
|
||||||
# let all invalid actions be illeagel in mcts
|
# let all invalid actions be illegal in mcts
|
||||||
|
if not hasattr(simulator, 'simulate_get_mask'):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
if self.mask is None:
|
if self.mask is None:
|
||||||
self.mask = simulator.simulate_get_mask(self.state, range(self.action_num))
|
self.mask = simulator.simulate_get_mask(self.state, range(self.action_num))
|
||||||
self.ucb[self.mask] = -float("Inf")
|
self.ucb[self.mask] = -float("Inf")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user