connect our own network to mcts

Dong Yan 2017-11-12 22:40:58 +08:00
parent ab1b3775e7
commit d9368c9a78
2 changed files with 46 additions and 30 deletions

AlphaGo/.gitignore (new file)

@@ -0,0 +1,2 @@
+data
+checkpoints

AlphaGo network script (changed file; path not shown in this view)

@@ -1,14 +1,16 @@
 import os
 import time
+import sys
 import numpy as np
 import tensorflow as tf
 import tensorflow.contrib.layers as layers
 import multi_gpu
-os.environ["CUDA_VISIBLE_DEVICES"] = "1"
+import time
+#os.environ["CUDA_VISIBLE_DEVICES"] = "1"
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

 def residual_block(input, is_training):
     normalizer_params = {'is_training': is_training,
@@ -127,34 +129,46 @@ def train():
         saver.save(sess, result_path + save_path)
         del data, boards, wins, ps

-def forward(board):
-    result_path = "./results/"
-    itflag = False
-    res = None
-    if board is None:
-        board = np.load("/home/yama/tongzheng/AG/self_play_204/d7d7d552b7be4b51883de99d74a8e51b.npz")
-        board = board["boards"][100].reshape(-1, 19, 19, 17)
-        result_path = "../parameters/checkpoints"
-        itflag = True
+def forward(call_number):
+    #checkpoint_path = "/home/yama/rl/tianshou/AlphaGo/checkpoints"
+    checkpoint_path = "/home/yama/rl/tianshou/AlphaGo/checkpoints/jialian"
+    board_file = np.genfromtxt("/home/yama/rl/tianshou/leela-zero/src/mcts_nn_files/board_" + call_number, dtype='str')
+    human_board = np.zeros((17, 19, 19))
+    #TODO : is it ok to ignore the last channel?
+    for i in range(17):
+        human_board[i] = np.array(list(board_file[i])).reshape(19, 19)
+    feed_board = human_board.transpose(1, 2, 0).reshape(1, 19, 19, 17)
+    #print(feed_board.shape)
+    #npz_board = np.load("/home/yama/rl/tianshou/AlphaGo/data/7f83928932f64a79bc1efdea268698ae.npz")
+    #print(npz_board["boards"].shape)
+    #feed_board = npz_board["boards"][10].reshape(-1, 19, 19, 17)
+    ##print(feed_board)
+    #show_board = feed_board[0].transpose(2, 0, 1)
+    #print("board shape : ", show_board.shape)
+    #print(show_board)
+    itflag = True
     with multi_gpu.create_session() as sess:
         sess.run(tf.global_variables_initializer())
-        ckpt_file = tf.train.latest_checkpoint(result_path)
+        ckpt_file = tf.train.latest_checkpoint(checkpoint_path)
         if ckpt_file is not None:
-            print('Restoring model from {}...'.format(ckpt_file))
+            #print('Restoring model from {}...'.format(ckpt_file))
             saver.restore(sess, ckpt_file)
         else:
             raise ValueError("No model loaded")
-        res = sess.run([tf.nn.softmax(p), v], feed_dict={x: board, is_training: itflag})
+        res = sess.run([tf.nn.softmax(p), v], feed_dict={x: feed_board, is_training: itflag})
         #res = sess.run([tf.nn.softmax(p),v], feed_dict={x:fix_board["boards"][300].reshape(-1, 19, 19, 17), is_training:False})
         #res = sess.run([tf.nn.softmax(p),v], feed_dict={x:fix_board["boards"][50].reshape(-1, 19, 19, 17), is_training:True})
+        print(res)
+        # print(res[0].tolist()[0])
         #print(np.argmax(res[0]))
+    np.savetxt(sys.stdout, res[0][0], fmt="%.6f", newline=" ")
+    np.savetxt(sys.stdout, res[1][0], fmt="%.6f", newline=" ")
+    pv_file = "/home/yama/rl/tianshou/leela-zero/src/mcts_nn_files/policy_value"
+    np.savetxt(pv_file, np.concatenate((res[0][0], res[1][0])), fmt="%.6f", newline=" ")
+    #np.savetxt(pv_file, res[1][0], fmt="%.6f", newline=" ")
     return res

 if __name__ == '__main__':
-    train()
-    # if sys.argv[1] == "test":
-    #     forward(None)
+    np.set_printoptions(threshold='nan')
+    #time.sleep(2)
+    forward(sys.argv[1])
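
The change wires the TensorFlow network to the modified leela-zero MCTS through a simple file-based handshake: the MCTS side writes the 17 input feature planes to mcts_nn_files/board_<call_number> as 17 rows of 361 characters, forward() reshapes them to (1, 19, 19, 17), evaluates the policy and value heads, and writes the result to mcts_nn_files/policy_value as one line of space-separated floats (policy first, value last). A minimal sketch of what the calling side has to do is below; it is not part of this commit, and the script name network.py, the helper query_network, and the 362-entry policy head (19*19 moves plus pass) are assumptions.

# Hypothetical caller on the MCTS side -- a sketch only, not part of this commit.
# Assumptions: the network file shown above is saved as network.py, the policy
# head emits 19*19 + 1 = 362 probabilities (pass included), and the hard-coded
# mcts_nn_files paths from forward() are reused here.
import subprocess
import numpy as np

MCTS_DIR = "/home/yama/rl/tianshou/leela-zero/src/mcts_nn_files"

def query_network(call_number, planes):
    """planes: (17, 19, 19) array of 0/1 input features, one row per plane."""
    # Each plane becomes one 361-character string, which is what
    # np.genfromtxt(..., dtype='str') followed by np.array(list(...)) expects.
    rows = ["".join(str(int(v)) for v in plane.flatten()) for plane in planes]
    with open("%s/board_%s" % (MCTS_DIR, call_number), "w") as f:
        f.write("\n".join(rows) + "\n")

    # forward(call_number) loads the checkpoint, runs the net and writes
    # softmax(policy) followed by the value into mcts_nn_files/policy_value.
    subprocess.check_call(["python", "network.py", str(call_number)])

    out = np.loadtxt("%s/policy_value" % MCTS_DIR)
    return out[:-1], out[-1]   # 362 move priors, then the scalar value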
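Two consequences of this design are worth noting. Every call to forward() starts a new Python process, builds the graph, opens a session and restores the latest checkpoint from disk, so each MCTS evaluation pays the full TensorFlow start-up cost; a resident process serving many boards would avoid that, but is not attempted in this commit. Setting TF_CPP_MIN_LOG_LEVEL = '3' in the first hunk presumably keeps TensorFlow's C++-side log messages out of the process output, since forward() also prints the raw policy and value to stdout.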