modify the network

This commit is contained in:
Tongzheng Ren 2017-11-05 16:47:01 +08:00
parent 937ee9f1e3
commit 2734af8530
2 changed files with 6 additions and 2 deletions

1
.gitignore vendored
View File

@ -1,2 +1,3 @@
.idea
leela-zero
.pyc

View File

@ -50,7 +50,7 @@ p = policy_heads(h, is_training)
loss = tf.reduce_mean(tf.square(z-v)) - tf.reduce_mean(tf.multiply(pi, tf.log(tf.nn.softmax(p, 1))))
reg = tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
total_loss = loss + reg
train_op = tf.train.RMSPropOptimizer(1e-2).minimize(total_loss)
train_op = tf.train.RMSPropOptimizer(1e-4).minimize(total_loss)
var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
saver = tf.train.Saver(max_to_keep=10, var_list=var_list)
@ -105,3 +105,6 @@ def forward(board):
else:
raise ValueError("No model loaded")
return sess.run([p,v], feed_dict={x:board})
if __name__='main':
train()