rewrite the selection fuction of UCTNode to return the action node instead of return the state node and next action

This commit is contained in:
Dong Yan 2017-12-27 21:11:40 +08:00
parent d48982d59e
commit affd0319e2

View File

@ -108,14 +108,14 @@ class ActionNode(object):
self.mcts.simulate_sf_time += time.time() - head
if self.next_state is None: # next_state is None means that self.parent.state is the terminate state
self.mcts.action_selection_time += time.time() - head
return self.parent, self.action
return self
self.next_state_hashable = hashable_conversion(self.next_state)
if self.next_state_hashable in self.children.keys(): # next state has already visited before
self.mcts.action_selection_time += time.time() - head
return self.children[self.next_state_hashable].selection(simulator)
else: # next state is a new state never seen before
self.mcts.action_selection_time += time.time() - head
return self.parent, self.action
return self
def expansion(self, prior, action_num):
if self.next_state is not None:
@ -191,15 +191,15 @@ class MCTS(object):
def _expand(self):
t0 = time.time()
node, new_action = self.root.selection(self.simulator)
next_action = self.root.selection(self.simulator)
t1 = time.time()
prior, value = self.evaluator(node.children[new_action].next_state)
node.children[new_action].expansion(prior, self.action_num)
prior, value = self.evaluator(next_action.next_state)
next_action.expansion(prior, self.action_num)
t2 = time.time()
if self.inverse:
node.children[new_action].backpropagation(-value + 0.)
next_action.backpropagation(-value + 0.)
else:
node.children[new_action].backpropagation(value + 0.)
next_action.backpropagation(value + 0.)
t3 = time.time()
return t1 - t0, t2 - t1, t3 - t2