From 2734af85302e9330bdb29f1face1ce624c2354f9 Mon Sep 17 00:00:00 2001 From: Tongzheng Ren Date: Sun, 5 Nov 2017 16:47:01 +0800 Subject: [PATCH] modify the network --- .gitignore | 1 + AlphaGo/Network.py | 7 +++++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index b9ae745..85c32a8 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ .idea leela-zero +.pyc diff --git a/AlphaGo/Network.py b/AlphaGo/Network.py index f594be2..ef77e21 100644 --- a/AlphaGo/Network.py +++ b/AlphaGo/Network.py @@ -50,7 +50,7 @@ p = policy_heads(h, is_training) loss = tf.reduce_mean(tf.square(z-v)) - tf.reduce_mean(tf.multiply(pi, tf.log(tf.nn.softmax(p, 1)))) reg = tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) total_loss = loss + reg -train_op = tf.train.RMSPropOptimizer(1e-2).minimize(total_loss) +train_op = tf.train.RMSPropOptimizer(1e-4).minimize(total_loss) var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) saver = tf.train.Saver(max_to_keep=10, var_list=var_list) @@ -104,4 +104,7 @@ def forward(board): saver.restore(sess, ckpt_file) else: raise ValueError("No model loaded") - return sess.run([p,v], feed_dict={x:board}) \ No newline at end of file + return sess.run([p,v], feed_dict={x:board}) + +if __name__='main': + train()