diff --git a/.gitignore b/.gitignore
index b9ae745..85c32a8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,3 @@
 .idea
 leela-zero
+*.pyc
diff --git a/AlphaGo/Network.py b/AlphaGo/Network.py
index f594be2..ef77e21 100644
--- a/AlphaGo/Network.py
+++ b/AlphaGo/Network.py
@@ -50,7 +50,7 @@
 p = policy_heads(h, is_training)
 loss = tf.reduce_mean(tf.square(z-v)) - tf.reduce_mean(tf.multiply(pi, tf.log(tf.nn.softmax(p, 1))))
 reg = tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
 total_loss = loss + reg
-train_op = tf.train.RMSPropOptimizer(1e-2).minimize(total_loss)
+train_op = tf.train.RMSPropOptimizer(1e-4).minimize(total_loss)
 var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
 saver = tf.train.Saver(max_to_keep=10, var_list=var_list)
@@ -104,4 +104,7 @@ def forward(board):
         saver.restore(sess, ckpt_file)
     else:
         raise ValueError("No model loaded")
-    return sess.run([p,v], feed_dict={x:board})
\ No newline at end of file
+    return sess.run([p,v], feed_dict={x:board})
+
+if __name__ == '__main__':
+    train()