rewrite the selection fuction of UCTNode to return the action node instead of return the state node and next action

This commit is contained in:
Dong Yan 2017-12-27 21:11:40 +08:00
parent d48982d59e
commit affd0319e2

View File

@ -108,14 +108,14 @@ class ActionNode(object):
self.mcts.simulate_sf_time += time.time() - head self.mcts.simulate_sf_time += time.time() - head
if self.next_state is None: # next_state is None means that self.parent.state is the terminate state if self.next_state is None: # next_state is None means that self.parent.state is the terminate state
self.mcts.action_selection_time += time.time() - head self.mcts.action_selection_time += time.time() - head
return self.parent, self.action return self
self.next_state_hashable = hashable_conversion(self.next_state) self.next_state_hashable = hashable_conversion(self.next_state)
if self.next_state_hashable in self.children.keys(): # next state has already visited before if self.next_state_hashable in self.children.keys(): # next state has already visited before
self.mcts.action_selection_time += time.time() - head self.mcts.action_selection_time += time.time() - head
return self.children[self.next_state_hashable].selection(simulator) return self.children[self.next_state_hashable].selection(simulator)
else: # next state is a new state never seen before else: # next state is a new state never seen before
self.mcts.action_selection_time += time.time() - head self.mcts.action_selection_time += time.time() - head
return self.parent, self.action return self
def expansion(self, prior, action_num): def expansion(self, prior, action_num):
if self.next_state is not None: if self.next_state is not None:
@ -191,15 +191,15 @@ class MCTS(object):
def _expand(self): def _expand(self):
t0 = time.time() t0 = time.time()
node, new_action = self.root.selection(self.simulator) next_action = self.root.selection(self.simulator)
t1 = time.time() t1 = time.time()
prior, value = self.evaluator(node.children[new_action].next_state) prior, value = self.evaluator(next_action.next_state)
node.children[new_action].expansion(prior, self.action_num) next_action.expansion(prior, self.action_num)
t2 = time.time() t2 = time.time()
if self.inverse: if self.inverse:
node.children[new_action].backpropagation(-value + 0.) next_action.backpropagation(-value + 0.)
else: else:
node.children[new_action].backpropagation(value + 0.) next_action.backpropagation(value + 0.)
t3 = time.time() t3 = time.time()
return t1 - t0, t2 - t1, t3 - t2 return t1 - t0, t2 - t1, t3 - t2