From d48982d59ed2ca797d07ef6afee5f211a7e22aed Mon Sep 17 00:00:00 2001 From: Dong Yan Date: Wed, 27 Dec 2017 20:49:54 +0800 Subject: [PATCH] move evaluator from action node to mcts --- AlphaGo/model.py | 2 ++ tianshou/core/mcts/mcts.py | 10 ++++------ 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/AlphaGo/model.py b/AlphaGo/model.py index 6fde6e5..c3bb9f0 100644 --- a/AlphaGo/model.py +++ b/AlphaGo/model.py @@ -156,6 +156,8 @@ class ResNet(object): # Note : maybe we can use it for isolating test of MCTS #prob = [1.0 / self.action_num] * self.action_num #return [prob, np.random.uniform(-1, 1)] + if state is None: + return [[0.0] * self.action_num, 0] history, color = state if len(history) != self.history_length: raise ValueError( diff --git a/tianshou/core/mcts/mcts.py b/tianshou/core/mcts/mcts.py index f733f83..a1b0b3d 100644 --- a/tianshou/core/mcts/mcts.py +++ b/tianshou/core/mcts/mcts.py @@ -117,15 +117,12 @@ class ActionNode(object): self.mcts.action_selection_time += time.time() - head return self.parent, self.action - def expansion(self, evaluator, action_num): + def expansion(self, prior, action_num): if self.next_state is not None: # note that self.next_state was assigned already at the selection function - prior, value = evaluator(self.next_state) + # self.next_state is None means MCTS selected a terminate node self.children[self.next_state_hashable] = UCTNode(self, self.action, self.next_state, action_num, prior, mcts=self.mcts, inverse=self.parent.inverse) - return value - else: # self.next_state is None means MCTS selected a terminate node - return 0. def backpropagation(self, value): self.reward += value @@ -196,7 +193,8 @@ class MCTS(object): t0 = time.time() node, new_action = self.root.selection(self.simulator) t1 = time.time() - value = node.children[new_action].expansion(self.evaluator, self.action_num) + prior, value = self.evaluator(node.children[new_action].next_state) + node.children[new_action].expansion(prior, self.action_num) t2 = time.time() if self.inverse: node.children[new_action].backpropagation(-value + 0.)