From affd0319e283a26276e44c1359bcd72172751da5 Mon Sep 17 00:00:00 2001 From: Dong Yan Date: Wed, 27 Dec 2017 21:11:40 +0800 Subject: [PATCH] rewrite the selection function of UCTNode to return the action node instead of returning the state node and next action --- tianshou/core/mcts/mcts.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tianshou/core/mcts/mcts.py b/tianshou/core/mcts/mcts.py index a1b0b3d..4c23809 100644 --- a/tianshou/core/mcts/mcts.py +++ b/tianshou/core/mcts/mcts.py @@ -108,14 +108,14 @@ class ActionNode(object): self.mcts.simulate_sf_time += time.time() - head if self.next_state is None: # next_state is None means that self.parent.state is the terminate state self.mcts.action_selection_time += time.time() - head - return self.parent, self.action + return self self.next_state_hashable = hashable_conversion(self.next_state) if self.next_state_hashable in self.children.keys(): # next state has already visited before self.mcts.action_selection_time += time.time() - head return self.children[self.next_state_hashable].selection(simulator) else: # next state is a new state never seen before self.mcts.action_selection_time += time.time() - head - return self.parent, self.action + return self def expansion(self, prior, action_num): if self.next_state is not None: @@ -191,15 +191,15 @@ class MCTS(object): def _expand(self): t0 = time.time() - node, new_action = self.root.selection(self.simulator) + next_action = self.root.selection(self.simulator) t1 = time.time() - prior, value = self.evaluator(node.children[new_action].next_state) - node.children[new_action].expansion(prior, self.action_num) + prior, value = self.evaluator(next_action.next_state) + next_action.expansion(prior, self.action_num) t2 = time.time() if self.inverse: - node.children[new_action].backpropagation(-value + 0.) + next_action.backpropagation(-value + 0.) else: - node.children[new_action].backpropagation(value + 0.) 
+ next_action.backpropagation(value + 0.) t3 = time.time() return t1 - t0, t2 - t1, t3 - t2