Merge branch 'master' of https://github.com/sproblvem/tianshou
This commit is contained in:
commit
0a160065aa
@ -198,7 +198,10 @@ class MCTS(object):
|
||||
t1 = time.time()
|
||||
value = node.children[new_action].expansion(self.evaluator, self.action_num)
|
||||
t2 = time.time()
|
||||
node.children[new_action].backpropagation(value + 0.)
|
||||
if self.inverse:
|
||||
node.children[new_action].backpropagation(-value + 0.)
|
||||
else:
|
||||
node.children[new_action].backpropagation(value + 0.)
|
||||
t3 = time.time()
|
||||
return t1 - t0, t2 - t1, t3 - t2
|
||||
|
||||
|
@ -278,8 +278,12 @@ class MCTSVirtualLoss(object):
|
||||
priors[i],
|
||||
nodes[i].inverse)
|
||||
|
||||
for i in range(self.batch_size):
|
||||
nodes[i].children[new_actions[i]].backpropagation(values[i] + 0.)
|
||||
if self.inverse:
|
||||
for i in range(self.batch_size):
|
||||
nodes[i].children[new_actions[i]].backpropagation(-values[i] + 0.)
|
||||
else:
|
||||
for i in range(self.batch_size):
|
||||
nodes[i].children[new_actions[i]].backpropagation(values[i] + 0.)
|
||||
|
||||
|
||||
##### TODO
|
||||
|
Loading…
x
Reference in New Issue
Block a user