minor modification

This commit is contained in:
Tongzheng Ren 2017-11-05 16:37:26 +08:00
parent 937ee9f1e3
commit f09ebc2124
3 changed files with 5 additions and 2 deletions

View File

@ -50,7 +50,7 @@ p = policy_heads(h, is_training)
loss = tf.reduce_mean(tf.square(z-v)) - tf.reduce_mean(tf.multiply(pi, tf.log(tf.nn.softmax(p, 1)))) loss = tf.reduce_mean(tf.square(z-v)) - tf.reduce_mean(tf.multiply(pi, tf.log(tf.nn.softmax(p, 1))))
reg = tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) reg = tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
total_loss = loss + reg total_loss = loss + reg
train_op = tf.train.RMSPropOptimizer(1e-2).minimize(total_loss) train_op = tf.train.RMSPropOptimizer(1e-4).minimize(total_loss)
var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
saver = tf.train.Saver(max_to_keep=10, var_list=var_list) saver = tf.train.Saver(max_to_keep=10, var_list=var_list)
@ -105,3 +105,6 @@ def forward(board):
else: else:
raise ValueError("No model loaded") raise ValueError("No model loaded")
return sess.run([p,v], feed_dict={x:board}) return sess.run([p,v], feed_dict={x:board})
if __name__ == "__main__":
train()

BIN
AlphaGo/Network.pyc Normal file

Binary file not shown.

BIN
AlphaGo/multi_gpu.pyc Normal file

Binary file not shown.