move evaluator from action node to mcts
This commit is contained in:
parent
0a160065aa
commit
d48982d59e
@ -156,6 +156,8 @@ class ResNet(object):
|
|||||||
# Note : maybe we can use it for isolating test of MCTS
|
# Note : maybe we can use it for isolating test of MCTS
|
||||||
#prob = [1.0 / self.action_num] * self.action_num
|
#prob = [1.0 / self.action_num] * self.action_num
|
||||||
#return [prob, np.random.uniform(-1, 1)]
|
#return [prob, np.random.uniform(-1, 1)]
|
||||||
|
if state is None:
|
||||||
|
return [[0.0] * self.action_num, 0]
|
||||||
history, color = state
|
history, color = state
|
||||||
if len(history) != self.history_length:
|
if len(history) != self.history_length:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -117,15 +117,12 @@ class ActionNode(object):
|
|||||||
self.mcts.action_selection_time += time.time() - head
|
self.mcts.action_selection_time += time.time() - head
|
||||||
return self.parent, self.action
|
return self.parent, self.action
|
||||||
|
|
||||||
def expansion(self, evaluator, action_num):
|
def expansion(self, prior, action_num):
|
||||||
if self.next_state is not None:
|
if self.next_state is not None:
|
||||||
# note that self.next_state was assigned already at the selection function
|
# note that self.next_state was assigned already at the selection function
|
||||||
prior, value = evaluator(self.next_state)
|
# self.next_state is None means MCTS selected a terminate node
|
||||||
self.children[self.next_state_hashable] = UCTNode(self, self.action, self.next_state, action_num, prior,
|
self.children[self.next_state_hashable] = UCTNode(self, self.action, self.next_state, action_num, prior,
|
||||||
mcts=self.mcts, inverse=self.parent.inverse)
|
mcts=self.mcts, inverse=self.parent.inverse)
|
||||||
return value
|
|
||||||
else: # self.next_state is None means MCTS selected a terminate node
|
|
||||||
return 0.
|
|
||||||
|
|
||||||
def backpropagation(self, value):
|
def backpropagation(self, value):
|
||||||
self.reward += value
|
self.reward += value
|
||||||
@ -196,7 +193,8 @@ class MCTS(object):
|
|||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
node, new_action = self.root.selection(self.simulator)
|
node, new_action = self.root.selection(self.simulator)
|
||||||
t1 = time.time()
|
t1 = time.time()
|
||||||
value = node.children[new_action].expansion(self.evaluator, self.action_num)
|
prior, value = self.evaluator(node.children[new_action].next_state)
|
||||||
|
node.children[new_action].expansion(prior, self.action_num)
|
||||||
t2 = time.time()
|
t2 = time.time()
|
||||||
if self.inverse:
|
if self.inverse:
|
||||||
node.children[new_action].backpropagation(-value + 0.)
|
node.children[new_action].backpropagation(-value + 0.)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user