Modification for backpropagation process
This commit is contained in:
parent
9f60984973
commit
8d102d249f
@ -198,7 +198,10 @@ class MCTS(object):
|
|||||||
t1 = time.time()
|
t1 = time.time()
|
||||||
value = node.children[new_action].expansion(self.evaluator, self.action_num)
|
value = node.children[new_action].expansion(self.evaluator, self.action_num)
|
||||||
t2 = time.time()
|
t2 = time.time()
|
||||||
node.children[new_action].backpropagation(value + 0.)
|
if self.inverse:
|
||||||
|
node.children[new_action].backpropagation(-value + 0.)
|
||||||
|
else:
|
||||||
|
node.children[new_action].backpropagation(value + 0.)
|
||||||
t3 = time.time()
|
t3 = time.time()
|
||||||
return t1 - t0, t2 - t1, t3 - t2
|
return t1 - t0, t2 - t1, t3 - t2
|
||||||
|
|
||||||
|
@ -278,8 +278,12 @@ class MCTSVirtualLoss(object):
|
|||||||
priors[i],
|
priors[i],
|
||||||
nodes[i].inverse)
|
nodes[i].inverse)
|
||||||
|
|
||||||
for i in range(self.batch_size):
|
if self.inverse:
|
||||||
nodes[i].children[new_actions[i]].backpropagation(values[i] + 0.)
|
for i in range(self.batch_size):
|
||||||
|
nodes[i].children[new_actions[i]].backpropagation(-values[i] + 0.)
|
||||||
|
else:
|
||||||
|
for i in range(self.batch_size):
|
||||||
|
nodes[i].children[new_actions[i]].backpropagation(values[i] + 0.)
|
||||||
|
|
||||||
|
|
||||||
##### TODO
|
##### TODO
|
||||||
|
Loading…
x
Reference in New Issue
Block a user