Minor fixes for MCTS; check for game finish in Go
This commit is contained in:
parent
430b78abf5
commit
3f238864fb
@ -212,11 +212,14 @@ class Go:
|
|||||||
def simulate_step_forward(self, state, action):
|
def simulate_step_forward(self, state, action):
|
||||||
# initialize the simulate_board from state
|
# initialize the simulate_board from state
|
||||||
history_boards, color = state
|
history_boards, color = state
|
||||||
vertex = self._action2vertex(action)
|
if history_boards[-1] == history_boards[-2] and action is utils.PASS:
|
||||||
new_board = self._do_move(copy.copy(history_boards[-1]), color, vertex)
|
return None, 2 * (float(self.executor_get_score(history_boards[-1]) > 0)-0.5) * color
|
||||||
history_boards.append(new_board)
|
else:
|
||||||
new_color = -color
|
vertex = self._action2vertex(action)
|
||||||
return [history_boards, new_color], 0
|
new_board = self._do_move(copy.copy(history_boards[-1]), color, vertex)
|
||||||
|
history_boards.append(new_board)
|
||||||
|
new_color = -color
|
||||||
|
return [history_boards, new_color], 0
|
||||||
|
|
||||||
def executor_do_move(self, history, latest_boards, current_board, color, vertex):
|
def executor_do_move(self, history, latest_boards, current_board, color, vertex):
|
||||||
if not self._rule_check(history, current_board, color, vertex):
|
if not self._rule_check(history, current_board, color, vertex):
|
||||||
|
@ -38,6 +38,7 @@ class MCTSNode(object):
|
|||||||
def valid_mask(self, simulator):
|
def valid_mask(self, simulator):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class UCTNode(MCTSNode):
|
class UCTNode(MCTSNode):
|
||||||
def __init__(self, parent, action, state, action_num, prior, inverse=False):
|
def __init__(self, parent, action, state, action_num, prior, inverse=False):
|
||||||
super(UCTNode, self).__init__(parent, action, state, action_num, prior, inverse)
|
super(UCTNode, self).__init__(parent, action, state, action_num, prior, inverse)
|
||||||
@ -71,10 +72,13 @@ class UCTNode(MCTSNode):
|
|||||||
self.parent.backpropagation(self.children[action].reward)
|
self.parent.backpropagation(self.children[action].reward)
|
||||||
|
|
||||||
def valid_mask(self, simulator):
|
def valid_mask(self, simulator):
|
||||||
# let all invalid actions be illeagel in mcts
|
# let all invalid actions be illegal in mcts
|
||||||
if self.mask is None:
|
if not hasattr(simulator, 'simulate_get_mask'):
|
||||||
self.mask = simulator.simulate_get_mask(self.state, range(self.action_num))
|
pass
|
||||||
self.ucb[self.mask] = -float("Inf")
|
else:
|
||||||
|
if self.mask is None:
|
||||||
|
self.mask = simulator.simulate_get_mask(self.state, range(self.action_num))
|
||||||
|
self.ucb[self.mask] = -float("Inf")
|
||||||
|
|
||||||
|
|
||||||
class TSNode(MCTSNode):
|
class TSNode(MCTSNode):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user