rewrite the selection fuction of UCTNode to return the action node instead of return the state node and next action
This commit is contained in:
parent
d48982d59e
commit
affd0319e2
@ -108,14 +108,14 @@ class ActionNode(object):
|
|||||||
self.mcts.simulate_sf_time += time.time() - head
|
self.mcts.simulate_sf_time += time.time() - head
|
||||||
if self.next_state is None: # next_state is None means that self.parent.state is the terminate state
|
if self.next_state is None: # next_state is None means that self.parent.state is the terminate state
|
||||||
self.mcts.action_selection_time += time.time() - head
|
self.mcts.action_selection_time += time.time() - head
|
||||||
return self.parent, self.action
|
return self
|
||||||
self.next_state_hashable = hashable_conversion(self.next_state)
|
self.next_state_hashable = hashable_conversion(self.next_state)
|
||||||
if self.next_state_hashable in self.children.keys(): # next state has already visited before
|
if self.next_state_hashable in self.children.keys(): # next state has already visited before
|
||||||
self.mcts.action_selection_time += time.time() - head
|
self.mcts.action_selection_time += time.time() - head
|
||||||
return self.children[self.next_state_hashable].selection(simulator)
|
return self.children[self.next_state_hashable].selection(simulator)
|
||||||
else: # next state is a new state never seen before
|
else: # next state is a new state never seen before
|
||||||
self.mcts.action_selection_time += time.time() - head
|
self.mcts.action_selection_time += time.time() - head
|
||||||
return self.parent, self.action
|
return self
|
||||||
|
|
||||||
def expansion(self, prior, action_num):
|
def expansion(self, prior, action_num):
|
||||||
if self.next_state is not None:
|
if self.next_state is not None:
|
||||||
@ -191,15 +191,15 @@ class MCTS(object):
|
|||||||
|
|
||||||
def _expand(self):
|
def _expand(self):
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
node, new_action = self.root.selection(self.simulator)
|
next_action = self.root.selection(self.simulator)
|
||||||
t1 = time.time()
|
t1 = time.time()
|
||||||
prior, value = self.evaluator(node.children[new_action].next_state)
|
prior, value = self.evaluator(next_action.next_state)
|
||||||
node.children[new_action].expansion(prior, self.action_num)
|
next_action.expansion(prior, self.action_num)
|
||||||
t2 = time.time()
|
t2 = time.time()
|
||||||
if self.inverse:
|
if self.inverse:
|
||||||
node.children[new_action].backpropagation(-value + 0.)
|
next_action.backpropagation(-value + 0.)
|
||||||
else:
|
else:
|
||||||
node.children[new_action].backpropagation(value + 0.)
|
next_action.backpropagation(value + 0.)
|
||||||
t3 = time.time()
|
t3 = time.time()
|
||||||
return t1 - t0, t2 - t1, t3 - t2
|
return t1 - t0, t2 - t1, t3 - t2
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user