From 8d102d249fd05a274f8d4174d061b0fc046181cb Mon Sep 17 00:00:00 2001 From: JialianLee Date: Wed, 27 Dec 2017 18:55:00 +0800 Subject: [PATCH] Modification for backpropagation process --- tianshou/core/mcts/mcts.py | 5 ++++- tianshou/core/mcts/mcts_virtual_loss.py | 8 ++++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/tianshou/core/mcts/mcts.py b/tianshou/core/mcts/mcts.py index 98ab056..f733f83 100644 --- a/tianshou/core/mcts/mcts.py +++ b/tianshou/core/mcts/mcts.py @@ -198,7 +198,10 @@ class MCTS(object): t1 = time.time() value = node.children[new_action].expansion(self.evaluator, self.action_num) t2 = time.time() - node.children[new_action].backpropagation(value + 0.) + if self.inverse: + node.children[new_action].backpropagation(-value + 0.) + else: + node.children[new_action].backpropagation(value + 0.) t3 = time.time() return t1 - t0, t2 - t1, t3 - t2 diff --git a/tianshou/core/mcts/mcts_virtual_loss.py b/tianshou/core/mcts/mcts_virtual_loss.py index f27d8a3..5826bd5 100644 --- a/tianshou/core/mcts/mcts_virtual_loss.py +++ b/tianshou/core/mcts/mcts_virtual_loss.py @@ -278,8 +278,12 @@ class MCTSVirtualLoss(object): priors[i], nodes[i].inverse) - for i in range(self.batch_size): - nodes[i].children[new_actions[i]].backpropagation(values[i] + 0.) + if self.inverse: + for i in range(self.batch_size): + nodes[i].children[new_actions[i]].backpropagation(-values[i] + 0.) + else: + for i in range(self.batch_size): + nodes[i].children[new_actions[i]].backpropagation(values[i] + 0.) ##### TODO