add test interface for Network

This commit is contained in:
Dong Yan 2017-11-06 23:13:11 +08:00
parent a8030c95f2
commit 2fc87f7020
2 changed files with 28 additions and 12 deletions

2
.gitignore vendored
View File

@ -1,4 +1,4 @@
.idea
leela-zero
.pyc
*.pyc
parameters

View File

@ -3,6 +3,7 @@ import numpy as np
import time
import multi_gpu
import tensorflow.contrib.layers as layers
import sys
def residual_block(input, is_training):
normalizer_params = {'is_training': is_training,
@ -96,6 +97,13 @@ def train():
def forward(board):
result_path = "./results/"
itflag = False
res = None
if board is None:
board = np.load("/home/yama/tongzheng/AG/self_play_204/d7d7d552b7be4b51883de99d74a8e51b.npz")
board = board["boards"][100].reshape(-1, 19, 19, 17)
result_path = "../parameters/checkpoints"
itflag = True
with multi_gpu.create_session() as sess:
sess.run(tf.global_variables_initializer())
ckpt_file = tf.train.latest_checkpoint(result_path)
@ -104,7 +112,15 @@ def forward(board):
saver.restore(sess, ckpt_file)
else:
raise ValueError("No model loaded")
return sess.run([p,v], feed_dict={x:board})
res = sess.run([tf.nn.softmax(p),v], feed_dict={x:board, is_training:itflag})
#res = sess.run([tf.nn.softmax(p),v], feed_dict={x:fix_board["boards"][300].reshape(-1, 19, 19, 17), is_training:False})
#res = sess.run([tf.nn.softmax(p),v], feed_dict={x:fix_board["boards"][50].reshape(-1, 19, 19, 17), is_training:True})
print(res)
#print(res[0].tolist()[0])
#print(np.argmax(res[0]))
return res
if __name__='main':
train()
if __name__=='__main__':
#train()
if sys.argv[1] == "test":
forward(None)