test next_action.next_state in MCTS

This commit is contained in:
Dong Yan 2017-12-28 15:52:31 +08:00
parent 47676993fd
commit 08b6649fea
2 changed files with 8 additions and 9 deletions

View File

@ -156,8 +156,6 @@ class ResNet(object):
# Note : maybe we can use it for isolating test of MCTS
#prob = [1.0 / self.action_num] * self.action_num
#return [prob, np.random.uniform(-1, 1)]
if state is None:
return [[0.0] * self.action_num, 0]
history, color = state
if len(history) != self.history_length:
raise ValueError(

View File

@ -112,11 +112,8 @@ class ActionNode(object):
return self
def expansion(self, prior, action_num):
if self.next_state is not None:
# note that self.next_state was assigned already at the selection function
# self.next_state is None means MCTS selected a terminate node
self.children[self.next_state_hashable] = UCTNode(self, self.action, self.next_state, action_num, prior,
mcts=self.mcts, inverse=self.parent.inverse)
self.children[self.next_state_hashable] = UCTNode(self, self.action, self.next_state, action_num, prior,
mcts=self.mcts, inverse=self.parent.inverse)
def backpropagation(self, value):
self.reward += value
@ -183,8 +180,12 @@ class MCTS(object):
t0 = time.time()
next_action = self.root.selection(self.simulator)
t1 = time.time()
prior, value = self.evaluator(next_action.next_state)
next_action.expansion(prior, self.action_num)
# next_action.next_state is None means the parent state node of next_action is a terminate node
if next_action.next_state is not None:
prior, value = self.evaluator(next_action.next_state)
next_action.expansion(prior, self.action_num)
else:
value = 0
t2 = time.time()
if self.inverse:
next_action.backpropagation(-value + 0.)