From 029ab199f4a8da3fd15897cd9f3ef830467ad578 Mon Sep 17 00:00:00 2001 From: Dong Yan Date: Tue, 26 Dec 2017 16:47:24 +0800 Subject: [PATCH] add softmax for mcts root node --- AlphaGo/model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/AlphaGo/model.py b/AlphaGo/model.py index 704a034..dbfc5ca 100644 --- a/AlphaGo/model.py +++ b/AlphaGo/model.py @@ -124,6 +124,7 @@ class ResNet(object): h = residual_block(h, self.is_training) self.v = value_head(h, self.is_training) self.p = policy_head(h, self.is_training, self.action_num) + self.prob = tf.nn.softmax(self.p) self.value_loss = tf.reduce_mean(tf.square(self.z - self.v)) self.policy_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=self.pi, logits=self.p)) @@ -161,7 +162,7 @@ class ResNet(object): 'The length of history cannot meet the need of the model, given {}, need {}'.format(len(history), self.history_length)) state = self._history2state(history, color) - return self.sess.run([tf.nn.softmax(self.p), self.v], feed_dict={self.x: state, self.is_training: False}) + return self.sess.run([self.prob, self.v], feed_dict={self.x: state, self.is_training: False}) def _history2state(self, history, color): """