From 543d876f129ac353f73bd546afea7536c207b4a8 Mon Sep 17 00:00:00 2001 From: rtz19970824 Date: Mon, 4 Dec 2017 11:01:49 +0800 Subject: [PATCH] merge gtp --- AlphaGo/engine.py | 2 +- AlphaGo/game.py | 14 ++- AlphaGo/network_small.py | 211 +++++++++++++++++++++++++++++++++++++ AlphaGo/strategy.py | 30 +++--- AlphaGo/test.py | 165 ++++------------------------- tianshou/core/mcts/mcts.py | 1 - 6 files changed, 251 insertions(+), 172 deletions(-) create mode 100644 AlphaGo/network_small.py diff --git a/AlphaGo/engine.py b/AlphaGo/engine.py index b55a8d5..fef194f 100644 --- a/AlphaGo/engine.py +++ b/AlphaGo/engine.py @@ -11,7 +11,7 @@ import utils class GTPEngine(): def __init__(self, **kwargs): - self.size = 19 + self.size = 9 self.komi = 6.5 try: self._game = kwargs['game_obj'] diff --git a/AlphaGo/game.py b/AlphaGo/game.py index 88e5d17..3ef0918 100644 --- a/AlphaGo/game.py +++ b/AlphaGo/game.py @@ -148,10 +148,7 @@ class Executor: def _find_empty(self): idx = [i for i,x in enumerate(self.game.board) if x == utils.EMPTY ][0] return self.game._deflatten(idx) - - - def get_score(self): ''' return score from BLACK perspective. @@ -182,12 +179,12 @@ class Executor: class Game: - def __init__(self, size=19, komi=6.5): + def __init__(self, size=9, komi=6.5): self.size = size - self.komi = 6.5 + self.komi = komi self.board = [utils.EMPTY] * (self.size * self.size) - #self.strategy = strategy() - self.strategy = None + self.strategy = strategy() + # self.strategy = None self.executor = Executor(game = self) self.history = [] self.past = deque(maxlen=8) @@ -227,11 +224,12 @@ class Game: # move = self.strategy.gen_move(color) # return move move = self.strategy.gen_move(self.past, color) + print(move) self.do_move(color, move) return move def status2symbol(self, s): - pool = {utils.WHITE: '#', utils.EMPTY: '.', utils.BLACK: '*', utils.FILL: 'F', utils.UNKNOWN: '?'} + pool = {utils.WHITE: 'O', utils.EMPTY: '.', utils.BLACK: 'X', utils.FILL: 'F', utils.UNKNOWN: '?'} return pool[s] def show_board(self): diff --git a/AlphaGo/network_small.py b/AlphaGo/network_small.py new file mode 100644 index 0000000..8dd5140 --- /dev/null +++ b/AlphaGo/network_small.py @@ -0,0 +1,211 @@ +import os +import time +import sys + +import numpy as np +import time +import tensorflow as tf +import tensorflow.contrib.layers as layers + +import multi_gpu +import time + +# os.environ["CUDA_VISIBLE_DEVICES"] = "1" +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' + + +def residual_block(input, is_training): + normalizer_params = {'is_training': is_training, + 'updates_collections': tf.GraphKeys.UPDATE_OPS} + h = layers.conv2d(input, 256, kernel_size=3, stride=1, activation_fn=tf.nn.relu, + normalizer_fn=layers.batch_norm, normalizer_params=normalizer_params, + weights_regularizer=layers.l2_regularizer(1e-4)) + h = layers.conv2d(h, 256, kernel_size=3, stride=1, activation_fn=tf.identity, + normalizer_fn=layers.batch_norm, normalizer_params=normalizer_params, + weights_regularizer=layers.l2_regularizer(1e-4)) + h = h + input + return tf.nn.relu(h) + + +def policy_heads(input, is_training): + normalizer_params = {'is_training': is_training, + 'updates_collections': tf.GraphKeys.UPDATE_OPS} + h = layers.conv2d(input, 2, kernel_size=1, stride=1, activation_fn=tf.nn.relu, + normalizer_fn=layers.batch_norm, normalizer_params=normalizer_params, + weights_regularizer=layers.l2_regularizer(1e-4)) + h = layers.flatten(h) + h = layers.fully_connected(h, 82, activation_fn=tf.identity, weights_regularizer=layers.l2_regularizer(1e-4)) + return h + + +def value_heads(input, is_training): + normalizer_params = {'is_training': is_training, + 'updates_collections': tf.GraphKeys.UPDATE_OPS} + h = layers.conv2d(input, 2, kernel_size=1, stride=1, activation_fn=tf.nn.relu, + normalizer_fn=layers.batch_norm, normalizer_params=normalizer_params, + weights_regularizer=layers.l2_regularizer(1e-4)) + h = layers.flatten(h) + h = layers.fully_connected(h, 256, activation_fn=tf.nn.relu, weights_regularizer=layers.l2_regularizer(1e-4)) + h = layers.fully_connected(h, 1, activation_fn=tf.nn.tanh, weights_regularizer=layers.l2_regularizer(1e-4)) + return h + + +class Network(object): + def __init__(self): + self.x = tf.placeholder(tf.float32, shape=[None, 9, 9, 17]) + self.is_training = tf.placeholder(tf.bool, shape=[]) + self.z = tf.placeholder(tf.float32, shape=[None, 1]) + self.pi = tf.placeholder(tf.float32, shape=[None, 82]) + self.build_network() + + def build_network(self): + h = layers.conv2d(self.x, 256, kernel_size=3, stride=1, activation_fn=tf.nn.relu, normalizer_fn=layers.batch_norm, + normalizer_params={'is_training': self.is_training, + 'updates_collections': tf.GraphKeys.UPDATE_OPS}, + weights_regularizer=layers.l2_regularizer(1e-4)) + for i in range(19): + h = residual_block(h, self.is_training) + self.v = value_heads(h, self.is_training) + self.p = policy_heads(h, self.is_training) + # loss = tf.reduce_mean(tf.square(z-v)) - tf.multiply(pi, tf.log(tf.clip_by_value(tf.nn.softmax(p), 1e-8, tf.reduce_max(tf.nn.softmax(p))))) + self.value_loss = tf.reduce_mean(tf.square(self.z - self.v)) + self.policy_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=self.pi, logits=self.p)) + + self.reg = tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) + self.total_loss = self.value_loss + self.policy_loss + self.reg + # train_op = tf.train.MomentumOptimizer(1e-4, momentum=0.9, use_nesterov=True).minimize(total_loss) + self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) + with tf.control_dependencies(self.update_ops): + self.train_op = tf.train.RMSPropOptimizer(1e-4).minimize(self.total_loss) + self.var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) + self.saver = tf.train.Saver(max_to_keep=10, var_list=self.var_list) + + def train(self): + data_path = "/home/tongzheng/data/" + data_name = os.listdir("/home/tongzheng/data/") + epochs = 100 + batch_size = 128 + + result_path = "./checkpoints/" + with multi_gpu.create_session() as sess: + sess.run(tf.global_variables_initializer()) + ckpt_file = tf.train.latest_checkpoint(result_path) + if ckpt_file is not None: + print('Restoring model from {}...'.format(ckpt_file)) + self.saver.restore(sess, ckpt_file) + for epoch in range(epochs): + for name in data_name: + data = np.load(data_path + name) + boards = data["boards"] + wins = data["wins"] + ps = data["ps"] + print (boards.shape) + print (wins.shape) + print (ps.shape) + batch_num = boards.shape[0] // batch_size + index = np.arange(boards.shape[0]) + np.random.shuffle(index) + value_losses = [] + policy_losses = [] + regs = [] + time_train = -time.time() + for iter in range(batch_num): + lv, lp, r, value, prob, _ = sess.run( + [self.value_loss, self.policy_loss, self.reg, self.v, tf.nn.softmax(p), self.train_op], + feed_dict={self.x: boards[ + index[iter * batch_size:(iter + 1) * batch_size]], + self.z: wins[index[ + iter * batch_size:(iter + 1) * batch_size]], + self.pi: ps[index[ + iter * batch_size:(iter + 1) * batch_size]], + self.is_training: True}) + value_losses.append(lv) + policy_losses.append(lp) + regs.append(r) + if iter % 1 == 0: + print( + "Epoch: {}, Part {}, Iteration: {}, Time: {}, Value Loss: {}, Policy Loss: {}, Reg: {}".format( + epoch, name, iter, time.time() + time_train, np.mean(np.array(value_losses)), + np.mean(np.array(policy_losses)), np.mean(np.array(regs)))) + time_train = -time.time() + value_losses = [] + policy_losses = [] + regs = [] + if iter % 20 == 0: + save_path = "Epoch{}.Part{}.Iteration{}.ckpt".format(epoch, name, iter) + self.saver.save(sess, result_path + save_path) + del data, boards, wins, ps + + + # def forward(call_number): + # # checkpoint_path = "/home/yama/rl/tianshou/AlphaGo/checkpoints" + # checkpoint_path = "/home/jialian/stuGo/tianshou/stuGo/checkpoints/" + # board_file = np.genfromtxt("/home/jialian/stuGo/tianshou/leela-zero/src/mcts_nn_files/board_" + call_number, + # dtype='str'); + # human_board = np.zeros((17, 19, 19)) + # + # # TODO : is it ok to ignore the last channel? + # for i in range(17): + # human_board[i] = np.array(list(board_file[i])).reshape(19, 19) + # # print("============================") + # # print("human board sum : " + str(np.sum(human_board[-1]))) + # # print("============================") + # # print(human_board) + # # print("============================") + # # rint(human_board) + # feed_board = human_board.transpose(1, 2, 0).reshape(1, 19, 19, 17) + # # print(feed_board[:,:,:,-1]) + # # print(feed_board.shape) + # + # # npz_board = np.load("/home/yama/rl/tianshou/AlphaGo/data/7f83928932f64a79bc1efdea268698ae.npz") + # # print(npz_board["boards"].shape) + # # feed_board = npz_board["boards"][10].reshape(-1, 19, 19, 17) + # ##print(feed_board) + # # show_board = feed_board[0].transpose(2, 0, 1) + # # print("board shape : ", show_board.shape) + # # print(show_board) + # + # itflag = False + # with multi_gpu.create_session() as sess: + # sess.run(tf.global_variables_initializer()) + # ckpt_file = tf.train.latest_checkpoint(checkpoint_path) + # if ckpt_file is not None: + # # print('Restoring model from {}...'.format(ckpt_file)) + # saver.restore(sess, ckpt_file) + # else: + # raise ValueError("No model loaded") + # res = sess.run([tf.nn.softmax(p), v], feed_dict={x: feed_board, is_training: itflag}) + # # res = sess.run([tf.nn.softmax(p),v], feed_dict={x:fix_board["boards"][300].reshape(-1, 19, 19, 17), is_training:False}) + # # res = sess.run([tf.nn.softmax(p),v], feed_dict={x:fix_board["boards"][50].reshape(-1, 19, 19, 17), is_training:True}) + # # print(np.argmax(res[0])) + # np.savetxt(sys.stdout, res[0][0], fmt="%.6f", newline=" ") + # np.savetxt(sys.stdout, res[1][0], fmt="%.6f", newline=" ") + # pv_file = "/home/jialian/stuGotianshou/leela-zero/src/mcts_nn_files/policy_value" + # np.savetxt(pv_file, np.concatenate((res[0][0], res[1][0])), fmt="%.6f", newline=" ") + # # np.savetxt(pv_file, res[1][0], fmt="%.6f", newline=" ") + # return res + + def forward(self): + # checkpoint_path = "/home/tongzheng/tianshou/AlphaGo/checkpoints/" + sess = multi_gpu.create_session() + sess.run(tf.global_variables_initializer()) + # ckpt_file = tf.train.latest_checkpoint(checkpoint_path) + # if ckpt_file is not None: + # print('Restoring model from {}...'.format(ckpt_file)) + # self.saver.restore(sess, ckpt_file) + # print('Successfully loaded') + # else: + # raise ValueError("No model loaded") + # prior, value = sess.run([tf.nn.softmax(p), v], feed_dict={x: state, is_training: False}) + # return prior, value + return sess + + +if __name__ == '__main__': + state = np.random.randint(0, 1, [1, 9, 9, 17]) + net = Network() + sess = net.forward() + start = time.time() + for i in range(100): + sess.run([tf.nn.softmax(net.p), net.v], feed_dict={net.x: state, net.is_training: False}) + print("Step {}, Cumulative time {}".format(i, time.time() - start)) diff --git a/AlphaGo/strategy.py b/AlphaGo/strategy.py index f2b2bb9..d450a00 100644 --- a/AlphaGo/strategy.py +++ b/AlphaGo/strategy.py @@ -3,14 +3,14 @@ sys.path.append(os.path.join(os.path.dirname(__file__), os.path.pardir)) import numpy as np import utils import time -import Network +import network_small import tensorflow as tf from collections import deque from tianshou.core.mcts.mcts import MCTS class GoEnv: - def __init__(self, size=19, komi=6.5): + def __init__(self, size=9, komi=6.5): self.size = size self.komi = 6.5 self.board = [utils.EMPTY] * (self.size * self.size) @@ -138,15 +138,15 @@ class GoEnv: color = 1 else: color = -1 - if action == 361: + if action == 81: vertex = (0, 0) else: - vertex = (action / 19 + 1, action % 19) + vertex = (action / 9 + 1, action % 9) self.do_move(color, vertex) new_state = np.concatenate( - [state[:, :, :, 1:8], (np.array(self.board) == 1).reshape(1, 19, 19, 1), - state[:, :, :, 9:16], (np.array(self.board) == -1).reshape(1, 19, 19, 1), - np.array(1 - state[:, :, :, -1]).reshape(1, 19, 19, 1)], + [state[:, :, :, 1:8], (np.array(self.board) == 1).reshape(1, 9, 9, 1), + state[:, :, :, 9:16], (np.array(self.board) == -1).reshape(1, 9, 9, 1), + np.array(1 - state[:, :, :, -1]).reshape(1, 9, 9, 1)], axis=3) return new_state, 0 @@ -154,20 +154,20 @@ class GoEnv: class strategy(object): def __init__(self): self.simulator = GoEnv() - self.net = Network.Network() + self.net = network_small.Network() self.sess = self.net.forward() self.evaluator = lambda state: self.sess.run([tf.nn.softmax(self.net.p), self.net.v], feed_dict={self.net.x: state, self.net.is_training: False}) def data_process(self, history, color): - state = np.zeros([1, 19, 19, 17]) + state = np.zeros([1, 9, 9, 17]) for i in range(8): state[0, :, :, i] = history[i] == 1 state[0, :, :, i + 8] = history[i] == -1 if color == 1: - state[0, :, :, 16] = np.ones([19, 19]) + state[0, :, :, 16] = np.ones([9, 9]) if color == -1: - state[0, :, :, 16] = np.zeros([19, 19]) + state[0, :, :, 16] = np.zeros([9, 9]) return state def gen_move(self, history, color): @@ -175,12 +175,12 @@ class strategy(object): self.simulator.board = history[-1] state = self.data_process(history, color) prior = self.evaluator(state)[0] - mcts = MCTS(self.simulator, self.evaluator, state, 362, prior, inverse=True, max_step=100) + mcts = MCTS(self.simulator, self.evaluator, state, 82, prior, inverse=True, max_step=100) temp = 1 p = mcts.root.N ** temp / np.sum(mcts.root.N ** temp) - choice = np.random.choice(362, 1, p=p).tolist()[0] - if choice == 361: + choice = np.random.choice(82, 1, p=p).tolist()[0] + if choice == 81: move = (0, 0) else: - move = (choice / 19 + 1, choice % 19 + 1) + move = (choice / 9 + 1, choice % 9 + 1) return move diff --git a/AlphaGo/test.py b/AlphaGo/test.py index 66bf131..ce74bbd 100644 --- a/AlphaGo/test.py +++ b/AlphaGo/test.py @@ -11,153 +11,24 @@ import utils g = Game() e = GTPEngine(game_obj=g) -res = e.run_cmd('1 protocol_version') -print(e.known_commands) -print(res) - -#res = e.run_cmd('2 name') -#print(res) - -#res = e.run_cmd('3 known_command quit') -#print(res) - -#res = e.run_cmd('4 unknown_command quitagain') -#print(res) - -#res = e.run_cmd('5 list_commands') -#print(res) - -#res = e.run_cmd('6 komi 6') -#print(res) - -#res = e.run_cmd('7 play BLACK C3') -#print(res) - -# res = e.run_cmd('play BLACK C4') -# res = e.run_cmd('play BLACK C5') -# res = e.run_cmd('play BLACK C6') -# res = e.run_cmd('play BLACK D3') -# print(res) - - -#res = e.run_cmd('8 genmove WHITE') -#print(res) -#g.show_board() - -# res = e.run_cmd('8 genmove BLACK') -# print(res) -# g.show_board() -# -# res = e.run_cmd('8 genmove WHITE') -# print(res) -# g.show_board() -# -# res = e.run_cmd('8 genmove BLACK') -# print(res) -# g.show_board() -# -# res = e.run_cmd('8 genmove WHITE') -# print(res) -# g.show_board() -# #g.show_board() -# print(g.check_valid((10, 9))) -# print(g.executor._neighbor((1,1))) -# print(g.do_move(utils.WHITE, (4, 6))) -# #g.show_board() -# -# -# res = e.run_cmd('play BLACK L10') -# res = e.run_cmd('play BLACK L11') -# res = e.run_cmd('play BLACK L12') -# res = e.run_cmd('play BLACK L13') -# res = e.run_cmd('play BLACK L14') -# res = e.run_cmd('play BLACK m15') -# res = e.run_cmd('play BLACK m9') -# res = e.run_cmd('play BLACK C9') -# res = e.run_cmd('play BLACK D9') -# res = e.run_cmd('play BLACK E9') -# res = e.run_cmd('play BLACK F9') -# res = e.run_cmd('play BLACK G9') -# res = e.run_cmd('play BLACK H9') -# res = e.run_cmd('play BLACK I9') -# -# res = e.run_cmd('play BLACK N9') -# res = e.run_cmd('play BLACK N15') -# res = e.run_cmd('play BLACK O10') -# res = e.run_cmd('play BLACK O11') -# res = e.run_cmd('play BLACK O12') -# res = e.run_cmd('play BLACK O13') -# res = e.run_cmd('play BLACK O14') -# res = e.run_cmd('play BLACK M12') -# -# res = e.run_cmd('play WHITE M10') -# res = e.run_cmd('play WHITE M11') -# res = e.run_cmd('play WHITE N10') -# res = e.run_cmd('play WHITE N11') -# -# res = e.run_cmd('play WHITE M13') -# res = e.run_cmd('play WHITE M14') -# res = e.run_cmd('play WHITE N13') -# res = e.run_cmd('play WHITE N14') -# print(res) -# -# res = e.run_cmd('play BLACK N12') -# print(res) -# #g.show_board() -# -res = e.run_cmd('play BLACK P16') -res = e.run_cmd('play BLACK P17') -res = e.run_cmd('play BLACK P18') -res = e.run_cmd('play BLACK P19') -res = e.run_cmd('play BLACK Q16') -res = e.run_cmd('play BLACK R16') -res = e.run_cmd('play BLACK S16') - -res = e.run_cmd('play WHITE S18') -res = e.run_cmd('play WHITE S17') -res = e.run_cmd('play WHITE Q19') -res = e.run_cmd('play WHITE Q18') -res = e.run_cmd('play WHITE Q17') -res = e.run_cmd('play WHITE R18') -res = e.run_cmd('play WHITE R17') -res = e.run_cmd('play BLACK S19') -# print(res) -# #g.show_board() -# -res = e.run_cmd('play WHITE R19') -# g.show_board() -# -res = e.run_cmd('play BLACK S19') -# print(res) -# g.show_board() -# -res = e.run_cmd('play BLACK S19') -# print(res) -# -# -# res = e.run_cmd('play BLACK E17') -# res = e.run_cmd('play BLACK F16') -# res = e.run_cmd('play BLACK F18') -# res = e.run_cmd('play BLACK G17') -# res = e.run_cmd('play WHITE G16') -# res = e.run_cmd('play WHITE G18') -# res = e.run_cmd('play WHITE H17') -# g.show_board() -# -# res = e.run_cmd('play WHITE F17') -# g.show_board() -# -# res = e.run_cmd('play BLACK G17') -# print(res) -# g.show_board() -# -# res = e.run_cmd('play BLACK G19') -# res = e.run_cmd('play BLACK G17') +e.run_cmd("genmove BLACK") g.show_board() - -res = e.run_cmd('play WHITE S18') +e.run_cmd("genmove WHITE") +g.show_board() +e.run_cmd("genmove BLACK") +g.show_board() +e.run_cmd("genmove WHITE") +g.show_board() +e.run_cmd("genmove BLACK") +g.show_board() +e.run_cmd("genmove WHITE") +g.show_board() +e.run_cmd("genmove BLACK") +g.show_board() +e.run_cmd("genmove WHITE") +g.show_board() +e.run_cmd("genmove BLACK") +g.show_board() +e.run_cmd("genmove WHITE") g.show_board() - -res = g.executor.get_score() -print(res) diff --git a/tianshou/core/mcts/mcts.py b/tianshou/core/mcts/mcts.py index 37fc2a8..5a1cbac 100644 --- a/tianshou/core/mcts/mcts.py +++ b/tianshou/core/mcts/mcts.py @@ -160,7 +160,6 @@ class MCTS: def expand(self): node, new_action = self.root.selection(self.simulator) value = node.children[new_action].expansion(self.evaluator, self.action_num) - print("Value:{}".format(value)) node.children[new_action].backpropagation(value + 0.)