minor modification
This commit is contained in:
parent
937ee9f1e3
commit
f09ebc2124
@ -50,7 +50,7 @@ p = policy_heads(h, is_training)
|
|||||||
loss = tf.reduce_mean(tf.square(z-v)) - tf.reduce_mean(tf.multiply(pi, tf.log(tf.nn.softmax(p, 1))))
|
loss = tf.reduce_mean(tf.square(z-v)) - tf.reduce_mean(tf.multiply(pi, tf.log(tf.nn.softmax(p, 1))))
|
||||||
reg = tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
|
reg = tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
|
||||||
total_loss = loss + reg
|
total_loss = loss + reg
|
||||||
train_op = tf.train.RMSPropOptimizer(1e-2).minimize(total_loss)
|
train_op = tf.train.RMSPropOptimizer(1e-4).minimize(total_loss)
|
||||||
|
|
||||||
var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
|
var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
|
||||||
saver = tf.train.Saver(max_to_keep=10, var_list=var_list)
|
saver = tf.train.Saver(max_to_keep=10, var_list=var_list)
|
||||||
@ -105,3 +105,6 @@ def forward(board):
|
|||||||
else:
|
else:
|
||||||
raise ValueError("No model loaded")
|
raise ValueError("No model loaded")
|
||||||
return sess.run([p,v], feed_dict={x:board})
|
return sess.run([p,v], feed_dict={x:board})
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
train()
|
BIN
AlphaGo/Network.pyc
Normal file
BIN
AlphaGo/Network.pyc
Normal file
Binary file not shown.
BIN
AlphaGo/multi_gpu.pyc
Normal file
BIN
AlphaGo/multi_gpu.pyc
Normal file
Binary file not shown.
Loading…
x
Reference in New Issue
Block a user