modify the network
This commit is contained in:
parent
937ee9f1e3
commit
2734af8530
1
.gitignore
vendored
1
.gitignore
vendored
@ -1,2 +1,3 @@
|
|||||||
.idea
|
.idea
|
||||||
leela-zero
|
leela-zero
|
||||||
|
.pyc
|
||||||
|
@ -50,7 +50,7 @@ p = policy_heads(h, is_training)
|
|||||||
loss = tf.reduce_mean(tf.square(z-v)) - tf.reduce_mean(tf.multiply(pi, tf.log(tf.nn.softmax(p, 1))))
|
loss = tf.reduce_mean(tf.square(z-v)) - tf.reduce_mean(tf.multiply(pi, tf.log(tf.nn.softmax(p, 1))))
|
||||||
reg = tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
|
reg = tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
|
||||||
total_loss = loss + reg
|
total_loss = loss + reg
|
||||||
train_op = tf.train.RMSPropOptimizer(1e-2).minimize(total_loss)
|
train_op = tf.train.RMSPropOptimizer(1e-4).minimize(total_loss)
|
||||||
|
|
||||||
var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
|
var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
|
||||||
saver = tf.train.Saver(max_to_keep=10, var_list=var_list)
|
saver = tf.train.Saver(max_to_keep=10, var_list=var_list)
|
||||||
@ -105,3 +105,6 @@ def forward(board):
|
|||||||
else:
|
else:
|
||||||
raise ValueError("No model loaded")
|
raise ValueError("No model loaded")
|
||||||
return sess.run([p,v], feed_dict={x:board})
|
return sess.run([p,v], feed_dict={x:board})
|
||||||
|
|
||||||
|
if __name__='main':
|
||||||
|
train()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user