Modification for backpropagation process

This commit is contained in:
JialianLee 2017-12-27 18:55:00 +08:00
parent 9f60984973
commit 8d102d249f
2 changed files with 10 additions and 3 deletions

View File

@ -198,7 +198,10 @@ class MCTS(object):
t1 = time.time()
value = node.children[new_action].expansion(self.evaluator, self.action_num)
t2 = time.time()
node.children[new_action].backpropagation(value + 0.)
if self.inverse:
node.children[new_action].backpropagation(-value + 0.)
else:
node.children[new_action].backpropagation(value + 0.)
t3 = time.time()
return t1 - t0, t2 - t1, t3 - t2

View File

@ -278,8 +278,12 @@ class MCTSVirtualLoss(object):
priors[i],
nodes[i].inverse)
for i in range(self.batch_size):
nodes[i].children[new_actions[i]].backpropagation(values[i] + 0.)
if self.inverse:
for i in range(self.batch_size):
nodes[i].children[new_actions[i]].backpropagation(-values[i] + 0.)
else:
for i in range(self.batch_size):
nodes[i].children[new_actions[i]].backpropagation(values[i] + 0.)
##### TODO