connect our own network to mcts
This commit is contained in:
parent
ab1b3775e7
commit
d9368c9a78
2
AlphaGo/.gitignore
vendored
Normal file
2
AlphaGo/.gitignore
vendored
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
data
|
||||||
|
checkpoints
|
@ -1,14 +1,16 @@
|
|||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
import sys
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
import tensorflow.contrib.layers as layers
|
import tensorflow.contrib.layers as layers
|
||||||
|
|
||||||
import multi_gpu
|
import multi_gpu
|
||||||
|
import time
|
||||||
|
|
||||||
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
|
#os.environ["CUDA_VISIBLE_DEVICES"] = "1"
|
||||||
|
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
||||||
|
|
||||||
def residual_block(input, is_training):
|
def residual_block(input, is_training):
|
||||||
normalizer_params = {'is_training': is_training,
|
normalizer_params = {'is_training': is_training,
|
||||||
@ -127,34 +129,46 @@ def train():
|
|||||||
saver.save(sess, result_path + save_path)
|
saver.save(sess, result_path + save_path)
|
||||||
del data, boards, wins, ps
|
del data, boards, wins, ps
|
||||||
|
|
||||||
|
def forward(call_number):
|
||||||
|
#checkpoint_path = "/home/yama/rl/tianshou/AlphaGo/checkpoints"
|
||||||
|
checkpoint_path = "/home/yama/rl/tianshou/AlphaGo/checkpoints/jialian"
|
||||||
|
board_file = np.genfromtxt("/home/yama/rl/tianshou/leela-zero/src/mcts_nn_files/board_" + call_number, dtype='str');
|
||||||
|
human_board = np.zeros((17, 19, 19))
|
||||||
|
#TODO : is it ok to ignore the last channel?
|
||||||
|
for i in range(17):
|
||||||
|
human_board[i] = np.array(list(board_file[i])).reshape(19, 19)
|
||||||
|
feed_board = human_board.transpose(1, 2, 0).reshape(1, 19, 19, 17)
|
||||||
|
#print(feed_board.shape)
|
||||||
|
|
||||||
|
#npz_board = np.load("/home/yama/rl/tianshou/AlphaGo/data/7f83928932f64a79bc1efdea268698ae.npz")
|
||||||
|
#print(npz_board["boards"].shape)
|
||||||
|
#feed_board = npz_board["boards"][10].reshape(-1, 19, 19, 17)
|
||||||
|
##print(feed_board)
|
||||||
|
#show_board = feed_board[0].transpose(2, 0, 1)
|
||||||
|
#print("board shape : ", show_board.shape)
|
||||||
|
#print(show_board)
|
||||||
|
|
||||||
def forward(board):
|
|
||||||
result_path = "./results/"
|
|
||||||
itflag = False
|
|
||||||
res = None
|
|
||||||
if board is None:
|
|
||||||
board = np.load("/home/yama/tongzheng/AG/self_play_204/d7d7d552b7be4b51883de99d74a8e51b.npz")
|
|
||||||
board = board["boards"][100].reshape(-1, 19, 19, 17)
|
|
||||||
result_path = "../parameters/checkpoints"
|
|
||||||
itflag = True
|
itflag = True
|
||||||
with multi_gpu.create_session() as sess:
|
with multi_gpu.create_session() as sess:
|
||||||
sess.run(tf.global_variables_initializer())
|
sess.run(tf.global_variables_initializer())
|
||||||
ckpt_file = tf.train.latest_checkpoint(result_path)
|
ckpt_file = tf.train.latest_checkpoint(checkpoint_path)
|
||||||
if ckpt_file is not None:
|
if ckpt_file is not None:
|
||||||
print('Restoring model from {}...'.format(ckpt_file))
|
#print('Restoring model from {}...'.format(ckpt_file))
|
||||||
saver.restore(sess, ckpt_file)
|
saver.restore(sess, ckpt_file)
|
||||||
else:
|
else:
|
||||||
raise ValueError("No model loaded")
|
raise ValueError("No model loaded")
|
||||||
res = sess.run([tf.nn.softmax(p), v], feed_dict={x: board, is_training: itflag})
|
res = sess.run([tf.nn.softmax(p),v], feed_dict={x:feed_board, is_training:itflag})
|
||||||
# res = sess.run([tf.nn.softmax(p),v], feed_dict={x:fix_board["boards"][300].reshape(-1, 19, 19, 17), is_training:False})
|
#res = sess.run([tf.nn.softmax(p),v], feed_dict={x:fix_board["boards"][300].reshape(-1, 19, 19, 17), is_training:False})
|
||||||
# res = sess.run([tf.nn.softmax(p),v], feed_dict={x:fix_board["boards"][50].reshape(-1, 19, 19, 17), is_training:True})
|
#res = sess.run([tf.nn.softmax(p),v], feed_dict={x:fix_board["boards"][50].reshape(-1, 19, 19, 17), is_training:True})
|
||||||
print(res)
|
#print(np.argmax(res[0]))
|
||||||
# print(res[0].tolist()[0])
|
np.savetxt(sys.stdout, res[0][0], fmt="%.6f", newline=" ")
|
||||||
# print(np.argmax(res[0]))
|
np.savetxt(sys.stdout, res[1][0], fmt="%.6f", newline=" ")
|
||||||
|
pv_file = "/home/yama/rl/tianshou/leela-zero/src/mcts_nn_files/policy_value"
|
||||||
|
np.savetxt(pv_file, np.concatenate((res[0][0], res[1][0])), fmt="%.6f", newline=" ")
|
||||||
|
#np.savetxt(pv_file, res[1][0], fmt="%.6f", newline=" ")
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
if __name__=='__main__':
|
||||||
if __name__ == '__main__':
|
np.set_printoptions(threshold='nan')
|
||||||
train()
|
#time.sleep(2)
|
||||||
# if sys.argv[1] == "test":
|
forward(sys.argv[1])
|
||||||
# forward(None)
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user