diff --git a/tianshou/core/mcts/mcts.py b/tianshou/core/mcts/mcts.py index 98ab056..f733f83 100644 --- a/tianshou/core/mcts/mcts.py +++ b/tianshou/core/mcts/mcts.py @@ -198,7 +198,10 @@ class MCTS(object): t1 = time.time() value = node.children[new_action].expansion(self.evaluator, self.action_num) t2 = time.time() - node.children[new_action].backpropagation(value + 0.) + if self.inverse: + node.children[new_action].backpropagation(-value + 0.) + else: + node.children[new_action].backpropagation(value + 0.) t3 = time.time() return t1 - t0, t2 - t1, t3 - t2 diff --git a/tianshou/core/mcts/mcts_virtual_loss.py b/tianshou/core/mcts/mcts_virtual_loss.py index f27d8a3..5826bd5 100644 --- a/tianshou/core/mcts/mcts_virtual_loss.py +++ b/tianshou/core/mcts/mcts_virtual_loss.py @@ -278,8 +278,12 @@ class MCTSVirtualLoss(object): priors[i], nodes[i].inverse) - for i in range(self.batch_size): - nodes[i].children[new_actions[i]].backpropagation(values[i] + 0.) + if self.inverse: + for i in range(self.batch_size): + nodes[i].children[new_actions[i]].backpropagation(-values[i] + 0.) + else: + for i in range(self.batch_size): + nodes[i].children[new_actions[i]].backpropagation(values[i] + 0.) ##### TODO