This commit is contained in:
rtz19970824 2017-12-27 19:54:52 +08:00
commit 0a160065aa
2 changed files with 10 additions and 3 deletions

View File

@ -198,6 +198,9 @@ class MCTS(object):
t1 = time.time() t1 = time.time()
value = node.children[new_action].expansion(self.evaluator, self.action_num) value = node.children[new_action].expansion(self.evaluator, self.action_num)
t2 = time.time() t2 = time.time()
if self.inverse:
node.children[new_action].backpropagation(-value + 0.)
else:
node.children[new_action].backpropagation(value + 0.) node.children[new_action].backpropagation(value + 0.)
t3 = time.time() t3 = time.time()
return t1 - t0, t2 - t1, t3 - t2 return t1 - t0, t2 - t1, t3 - t2

View File

@ -278,6 +278,10 @@ class MCTSVirtualLoss(object):
priors[i], priors[i],
nodes[i].inverse) nodes[i].inverse)
if self.inverse:
for i in range(self.batch_size):
nodes[i].children[new_actions[i]].backpropagation(-values[i] + 0.)
else:
for i in range(self.batch_size): for i in range(self.batch_size):
nodes[i].children[new_actions[i]].backpropagation(values[i] + 0.) nodes[i].children[new_actions[i]].backpropagation(values[i] + 0.)