Rewrite the selection function of UCTNode to return the action node instead of returning the state node and the next action
This commit is contained in:
parent
d48982d59e
commit
affd0319e2
@ -108,14 +108,14 @@ class ActionNode(object):
|
||||
self.mcts.simulate_sf_time += time.time() - head
|
||||
if self.next_state is None: # next_state is None means that self.parent.state is the terminate state
|
||||
self.mcts.action_selection_time += time.time() - head
|
||||
return self.parent, self.action
|
||||
return self
|
||||
self.next_state_hashable = hashable_conversion(self.next_state)
|
||||
if self.next_state_hashable in self.children.keys(): # next state has already visited before
|
||||
self.mcts.action_selection_time += time.time() - head
|
||||
return self.children[self.next_state_hashable].selection(simulator)
|
||||
else: # next state is a new state never seen before
|
||||
self.mcts.action_selection_time += time.time() - head
|
||||
return self.parent, self.action
|
||||
return self
|
||||
|
||||
def expansion(self, prior, action_num):
|
||||
if self.next_state is not None:
|
||||
@ -191,15 +191,15 @@ class MCTS(object):
|
||||
|
||||
def _expand(self):
|
||||
t0 = time.time()
|
||||
node, new_action = self.root.selection(self.simulator)
|
||||
next_action = self.root.selection(self.simulator)
|
||||
t1 = time.time()
|
||||
prior, value = self.evaluator(node.children[new_action].next_state)
|
||||
node.children[new_action].expansion(prior, self.action_num)
|
||||
prior, value = self.evaluator(next_action.next_state)
|
||||
next_action.expansion(prior, self.action_num)
|
||||
t2 = time.time()
|
||||
if self.inverse:
|
||||
node.children[new_action].backpropagation(-value + 0.)
|
||||
next_action.backpropagation(-value + 0.)
|
||||
else:
|
||||
node.children[new_action].backpropagation(value + 0.)
|
||||
next_action.backpropagation(value + 0.)
|
||||
t3 = time.time()
|
||||
return t1 - t0, t2 - t1, t3 - t2
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user