diff --git a/AlphaGo/model.py b/AlphaGo/model.py index c3bb9f0..6fde6e5 100644 --- a/AlphaGo/model.py +++ b/AlphaGo/model.py @@ -156,8 +156,6 @@ class ResNet(object): # Note : maybe we can use it for isolating test of MCTS #prob = [1.0 / self.action_num] * self.action_num #return [prob, np.random.uniform(-1, 1)] - if state is None: - return [[0.0] * self.action_num, 0] history, color = state if len(history) != self.history_length: raise ValueError( diff --git a/tianshou/core/mcts/mcts.py b/tianshou/core/mcts/mcts.py index 9625261..1251d05 100644 --- a/tianshou/core/mcts/mcts.py +++ b/tianshou/core/mcts/mcts.py @@ -112,11 +112,8 @@ class ActionNode(object): return self def expansion(self, prior, action_num): - if self.next_state is not None: - # note that self.next_state was assigned already at the selection function - # self.next_state is None means MCTS selected a terminate node - self.children[self.next_state_hashable] = UCTNode(self, self.action, self.next_state, action_num, prior, - mcts=self.mcts, inverse=self.parent.inverse) + self.children[self.next_state_hashable] = UCTNode(self, self.action, self.next_state, action_num, prior, + mcts=self.mcts, inverse=self.parent.inverse) def backpropagation(self, value): self.reward += value @@ -183,8 +180,12 @@ class MCTS(object): t0 = time.time() next_action = self.root.selection(self.simulator) t1 = time.time() - prior, value = self.evaluator(next_action.next_state) - next_action.expansion(prior, self.action_num) + # next_action.next_state is None means the parent state node of next_action is a terminal node + if next_action.next_state is not None: + prior, value = self.evaluator(next_action.next_state) + next_action.expansion(prior, self.action_num) + else: + value = 0 t2 = time.time() if self.inverse: next_action.backpropagation(-value + 0.)