Add softmax for MCTS root node
This commit is contained in:
parent
8f508c790b
commit
029ab199f4
@ -124,6 +124,7 @@ class ResNet(object):
|
|||||||
h = residual_block(h, self.is_training)
|
h = residual_block(h, self.is_training)
|
||||||
self.v = value_head(h, self.is_training)
|
self.v = value_head(h, self.is_training)
|
||||||
self.p = policy_head(h, self.is_training, self.action_num)
|
self.p = policy_head(h, self.is_training, self.action_num)
|
||||||
|
self.prob = tf.nn.softmax(self.p)
|
||||||
self.value_loss = tf.reduce_mean(tf.square(self.z - self.v))
|
self.value_loss = tf.reduce_mean(tf.square(self.z - self.v))
|
||||||
self.policy_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=self.pi, logits=self.p))
|
self.policy_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=self.pi, logits=self.p))
|
||||||
|
|
||||||
@ -161,7 +162,7 @@ class ResNet(object):
|
|||||||
'The length of history cannot meet the need of the model, given {}, need {}'.format(len(history),
|
'The length of history cannot meet the need of the model, given {}, need {}'.format(len(history),
|
||||||
self.history_length))
|
self.history_length))
|
||||||
state = self._history2state(history, color)
|
state = self._history2state(history, color)
|
||||||
return self.sess.run([tf.nn.softmax(self.p), self.v], feed_dict={self.x: state, self.is_training: False})
|
return self.sess.run([self.prob, self.v], feed_dict={self.x: state, self.is_training: False})
|
||||||
|
|
||||||
def _history2state(self, history, color):
|
def _history2state(self, history, color):
|
||||||
"""
|
"""
|
||||||
|
Loading…
x
Reference in New Issue
Block a user