move evaluator from action node to mcts

Dong Yan 2017-12-27 20:49:54 +08:00
parent 0a160065aa
commit d48982d59e
2 changed files with 6 additions and 6 deletions


@@ -156,6 +156,8 @@ class ResNet(object):
         # Note : maybe we can use it for isolating test of MCTS
         #prob = [1.0 / self.action_num] * self.action_num
         #return [prob, np.random.uniform(-1, 1)]
+        if state is None:
+            return [[0.0] * self.action_num, 0]
         history, color = state
         if len(history) != self.history_length:
             raise ValueError(
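The new guard lets the evaluator be called on a terminal state (passed as None) without running the network: it returns an all-zero prior and a value of 0, instead of crashing on the history/color unpacking below. A minimal stub sketch of that contract (the StubEvaluator class and the action_num value are illustrative, not part of the repository; the non-terminal branch mirrors the commented-out test code above):

import numpy as np

class StubEvaluator(object):
    """Illustrative stand-in for the ResNet evaluation call after this commit."""
    def __init__(self, action_num):
        self.action_num = action_num

    def __call__(self, state):
        if state is None:
            # terminal node selected by MCTS: zero prior, zero value, no forward pass
            return [[0.0] * self.action_num, 0]
        # placeholder for the real network output (mirrors the commented-out test code)
        prob = [1.0 / self.action_num] * self.action_num
        return [prob, np.random.uniform(-1, 1)]

prior, value = StubEvaluator(action_num=362)(None)  # -> all-zero prior, value 0

This matters because, after the second file's change, MCTS calls the evaluator on next_state unconditionally, including when the selected node is terminal.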


@@ -117,15 +117,12 @@ class ActionNode(object):
         self.mcts.action_selection_time += time.time() - head
         return self.parent, self.action
 
-    def expansion(self, evaluator, action_num):
+    def expansion(self, prior, action_num):
         if self.next_state is not None:
             # note that self.next_state was assigned already at the selection function
-            prior, value = evaluator(self.next_state)
+            # self.next_state is None means MCTS selected a terminate node
             self.children[self.next_state_hashable] = UCTNode(self, self.action, self.next_state, action_num, prior,
                                                               mcts=self.mcts, inverse=self.parent.inverse)
-            return value
-        else:  # self.next_state is None means MCTS selected a terminate node
-            return 0.
 
     def backpropagation(self, value):
         self.reward += value
@@ -196,7 +193,8 @@ class MCTS(object):
             t0 = time.time()
             node, new_action = self.root.selection(self.simulator)
             t1 = time.time()
-            value = node.children[new_action].expansion(self.evaluator, self.action_num)
+            prior, value = self.evaluator(node.children[new_action].next_state)
+            node.children[new_action].expansion(prior, self.action_num)
             t2 = time.time()
             if self.inverse:
                 node.children[new_action].backpropagation(-value + 0.)
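Taken together, the two hunks move the network call out of ActionNode.expansion and into the MCTS search loop: the evaluator is queried once per simulation on the selected child's next_state (which may be None for a terminal node, hence the guard above), expansion only builds the UCT child from the supplied prior, and MCTS backpropagates the value itself. A condensed sketch of one simulation step after the change (a stand-in function, simplified from the real loop; the branch for inverse == False is not shown in the diff and is assumed to pass value through unchanged):

def simulation_step(mcts):
    """Condensed sketch of one MCTS simulation after this commit (not the actual method)."""
    node, new_action = mcts.root.selection(mcts.simulator)
    child = node.children[new_action]
    # the evaluator now lives on MCTS; it also handles next_state is None (terminal node)
    prior, value = mcts.evaluator(child.next_state)
    # expansion receives the prior instead of the evaluator and no longer returns a value
    child.expansion(prior, mcts.action_num)
    # assumption: when mcts.inverse is False, value is backpropagated unchanged
    child.backpropagation(-value + 0. if mcts.inverse else value)

The practical effect is that expansion becomes a pure tree operation, and the only place that touches the network is the MCTS loop, which is also what makes the isolated-test stub hinted at in the ResNet hunk possible.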