connect our own network to mcts
This commit is contained in:
parent
ab1b3775e7
commit
d9368c9a78
2
AlphaGo/.gitignore
vendored
Normal file
2
AlphaGo/.gitignore
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
data
|
||||
checkpoints
|
@ -1,14 +1,16 @@
|
||||
import os
|
||||
import time
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import tensorflow.contrib.layers as layers
|
||||
|
||||
import multi_gpu
|
||||
import time
|
||||
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
|
||||
|
||||
#os.environ["CUDA_VISIBLE_DEVICES"] = "1"
|
||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
||||
|
||||
def residual_block(input, is_training):
|
||||
normalizer_params = {'is_training': is_training,
|
||||
@ -127,34 +129,46 @@ def train():
|
||||
saver.save(sess, result_path + save_path)
|
||||
del data, boards, wins, ps
|
||||
|
||||
def forward(call_number):
|
||||
#checkpoint_path = "/home/yama/rl/tianshou/AlphaGo/checkpoints"
|
||||
checkpoint_path = "/home/yama/rl/tianshou/AlphaGo/checkpoints/jialian"
|
||||
board_file = np.genfromtxt("/home/yama/rl/tianshou/leela-zero/src/mcts_nn_files/board_" + call_number, dtype='str');
|
||||
human_board = np.zeros((17, 19, 19))
|
||||
#TODO : is it ok to ignore the last channel?
|
||||
for i in range(17):
|
||||
human_board[i] = np.array(list(board_file[i])).reshape(19, 19)
|
||||
feed_board = human_board.transpose(1, 2, 0).reshape(1, 19, 19, 17)
|
||||
#print(feed_board.shape)
|
||||
|
||||
def forward(board):
|
||||
result_path = "./results/"
|
||||
itflag = False
|
||||
res = None
|
||||
if board is None:
|
||||
board = np.load("/home/yama/tongzheng/AG/self_play_204/d7d7d552b7be4b51883de99d74a8e51b.npz")
|
||||
board = board["boards"][100].reshape(-1, 19, 19, 17)
|
||||
result_path = "../parameters/checkpoints"
|
||||
itflag = True
|
||||
with multi_gpu.create_session() as sess:
|
||||
sess.run(tf.global_variables_initializer())
|
||||
ckpt_file = tf.train.latest_checkpoint(result_path)
|
||||
if ckpt_file is not None:
|
||||
print('Restoring model from {}...'.format(ckpt_file))
|
||||
saver.restore(sess, ckpt_file)
|
||||
else:
|
||||
raise ValueError("No model loaded")
|
||||
res = sess.run([tf.nn.softmax(p), v], feed_dict={x: board, is_training: itflag})
|
||||
# res = sess.run([tf.nn.softmax(p),v], feed_dict={x:fix_board["boards"][300].reshape(-1, 19, 19, 17), is_training:False})
|
||||
# res = sess.run([tf.nn.softmax(p),v], feed_dict={x:fix_board["boards"][50].reshape(-1, 19, 19, 17), is_training:True})
|
||||
print(res)
|
||||
# print(res[0].tolist()[0])
|
||||
# print(np.argmax(res[0]))
|
||||
return res
|
||||
#npz_board = np.load("/home/yama/rl/tianshou/AlphaGo/data/7f83928932f64a79bc1efdea268698ae.npz")
|
||||
#print(npz_board["boards"].shape)
|
||||
#feed_board = npz_board["boards"][10].reshape(-1, 19, 19, 17)
|
||||
##print(feed_board)
|
||||
#show_board = feed_board[0].transpose(2, 0, 1)
|
||||
#print("board shape : ", show_board.shape)
|
||||
#print(show_board)
|
||||
|
||||
itflag = True
|
||||
with multi_gpu.create_session() as sess:
|
||||
sess.run(tf.global_variables_initializer())
|
||||
ckpt_file = tf.train.latest_checkpoint(checkpoint_path)
|
||||
if ckpt_file is not None:
|
||||
#print('Restoring model from {}...'.format(ckpt_file))
|
||||
saver.restore(sess, ckpt_file)
|
||||
else:
|
||||
raise ValueError("No model loaded")
|
||||
res = sess.run([tf.nn.softmax(p),v], feed_dict={x:feed_board, is_training:itflag})
|
||||
#res = sess.run([tf.nn.softmax(p),v], feed_dict={x:fix_board["boards"][300].reshape(-1, 19, 19, 17), is_training:False})
|
||||
#res = sess.run([tf.nn.softmax(p),v], feed_dict={x:fix_board["boards"][50].reshape(-1, 19, 19, 17), is_training:True})
|
||||
#print(np.argmax(res[0]))
|
||||
np.savetxt(sys.stdout, res[0][0], fmt="%.6f", newline=" ")
|
||||
np.savetxt(sys.stdout, res[1][0], fmt="%.6f", newline=" ")
|
||||
pv_file = "/home/yama/rl/tianshou/leela-zero/src/mcts_nn_files/policy_value"
|
||||
np.savetxt(pv_file, np.concatenate((res[0][0], res[1][0])), fmt="%.6f", newline=" ")
|
||||
#np.savetxt(pv_file, res[1][0], fmt="%.6f", newline=" ")
|
||||
return res
|
||||
|
||||
if __name__ == '__main__':
|
||||
train()
|
||||
# if sys.argv[1] == "test":
|
||||
# forward(None)
|
||||
if __name__=='__main__':
|
||||
np.set_printoptions(threshold='nan')
|
||||
#time.sleep(2)
|
||||
forward(sys.argv[1])
|
||||
|
Loading…
x
Reference in New Issue
Block a user