add softmax for mcts root node

This commit is contained in:
Dong Yan 2017-12-26 16:47:24 +08:00
parent 8f508c790b
commit 029ab199f4

View File

@ -124,6 +124,7 @@ class ResNet(object):
h = residual_block(h, self.is_training) h = residual_block(h, self.is_training)
self.v = value_head(h, self.is_training) self.v = value_head(h, self.is_training)
self.p = policy_head(h, self.is_training, self.action_num) self.p = policy_head(h, self.is_training, self.action_num)
self.prob = tf.nn.softmax(self.p)
self.value_loss = tf.reduce_mean(tf.square(self.z - self.v)) self.value_loss = tf.reduce_mean(tf.square(self.z - self.v))
self.policy_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=self.pi, logits=self.p)) self.policy_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=self.pi, logits=self.p))
@ -161,7 +162,7 @@ class ResNet(object):
'The length of history cannot meet the need of the model, given {}, need {}'.format(len(history), 'The length of history cannot meet the need of the model, given {}, need {}'.format(len(history),
self.history_length)) self.history_length))
state = self._history2state(history, color) state = self._history2state(history, color)
return self.sess.run([tf.nn.softmax(self.p), self.v], feed_dict={self.x: state, self.is_training: False}) return self.sess.run([self.prob, self.v], feed_dict={self.x: state, self.is_training: False})
def _history2state(self, history, color): def _history2state(self, history, color):
""" """