From 9ad53de54f0ef28aea0df9de31c9d2c405186d15 Mon Sep 17 00:00:00 2001
From: rtz19970824 <1289226405@qq.com>
Date: Thu, 21 Dec 2017 23:30:24 +0800
Subject: [PATCH] implement the training process

---
 .gitignore       |   1 +
 AlphaGo/game.py  |   2 +-
 AlphaGo/model.py | 106 ++++++++++++++++++++++++++++++++++++++++++-----
 AlphaGo/play.py  |  28 ++++++++-----
 4 files changed, 114 insertions(+), 23 deletions(-)

diff --git a/.gitignore b/.gitignore
index 36d134c..d697b92 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,3 +8,4 @@ checkpoints
 checkpoints_origin
 *.json
 .DS_Store
+data
diff --git a/AlphaGo/game.py b/AlphaGo/game.py
index bf0d084..c342d0c 100644
--- a/AlphaGo/game.py
+++ b/AlphaGo/game.py
@@ -60,7 +60,7 @@ class Game:
 
     def think(self, latest_boards, color):
         mcts = MCTS(self.game_engine, self.evaluator, [latest_boards, color], self.size ** 2 + 1, inverse=True)
-        mcts.search(max_step=1)
+        mcts.search(max_step=20)
         temp = 1
         prob = mcts.root.N ** temp / np.sum(mcts.root.N ** temp)
         choice = np.random.choice(self.size ** 2 + 1, 1, p=prob).tolist()[0]
diff --git a/AlphaGo/model.py b/AlphaGo/model.py
index fab864e..41f3a47 100644
--- a/AlphaGo/model.py
+++ b/AlphaGo/model.py
@@ -2,6 +2,7 @@ import os
 import time
 import sys
 import cPickle
+from collections import deque
 
 import numpy as np
 import tensorflow as tf
@@ -71,6 +72,13 @@ def value_head(input, is_training):
     return h
 
 
+class Data(object):
+    def __init__(self):
+        self.boards = []
+        self.probs = []
+        self.winner = 0
+
+
 class ResNet(object):
     def __init__(self, board_size, action_num, history_length=1, residual_block_num=20, checkpoint_path=None):
         """
@@ -85,11 +93,18 @@ class ResNet(object):
         self.board_size = board_size
         self.action_num = action_num
         self.history_length = history_length
+        self.checkpoint_path = checkpoint_path
         self.x = tf.placeholder(tf.float32, shape=[None, self.board_size, self.board_size, 2 * self.history_length + 1])
         self.is_training = tf.placeholder(tf.bool, shape=[])
         self.z = tf.placeholder(tf.float32, shape=[None, 1])
         self.pi = tf.placeholder(tf.float32, shape=[None, self.action_num])
-        self._build_network(residual_block_num, checkpoint_path)
+        self._build_network(residual_block_num, self.checkpoint_path)
+
+        # training hyper-parameters:
+        self.window_length = 1000
+        self.save_freq = 1000
+        self.training_data = {'states': deque(maxlen=self.window_length), 'probs': deque(maxlen=self.window_length),
+                              'winner': deque(maxlen=self.window_length)}
 
     def _build_network(self, residual_block_num, checkpoint_path):
         """
@@ -118,7 +133,7 @@ class ResNet(object):
         with tf.control_dependencies(self.update_ops):
             self.train_op = tf.train.AdamOptimizer(1e-4).minimize(self.total_loss)
         self.var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
-        self.saver = tf.train.Saver(max_to_keep=10, var_list=self.var_list)
+        self.saver = tf.train.Saver(var_list=self.var_list)
         self.sess = multi_gpu.create_session()
         self.sess.run(tf.global_variables_initializer())
         if checkpoint_path is not None:
@@ -166,21 +181,90 @@ class ResNet(object):
             state[0, :, :, 2 * self.history_length] = np.zeros([self.board_size, self.board_size])
         return state
 
-    #TODO: design the interface between the environment and training
+    # TODO: design the interface between the environment and training
    def train(self, mode='memory', *args, **kwargs):
         if mode == 'memory':
             pass
         if mode == 'file':
-            self.train_with_file(data_path=kwargs['data_path'], checkpoint_path=kwargs['checkpoint_path'])
+            self._train_with_file(data_path=kwargs['data_path'], batch_size=kwargs['batch_size'],
+                                  checkpoint_path=kwargs['checkpoint_path'])
 
-    def train_with_file(self, data_path, checkpoint_path):
+    def _train_with_file(self, data_path, batch_size, checkpoint_path):
+        # check if the path is valid
         if not os.path.exists(data_path):
             raise ValueError("{} doesn't exist".format(data_path))
+        self.checkpoint_path = checkpoint_path
+        if not os.path.exists(self.checkpoint_path):
+            os.mkdir(self.checkpoint_path)
 
-        file_list = os.listdir(data_path)
-        if file_list <= 50:
-            time.sleep(1)
-        else:
-            file_list.sort(key=lambda file: os.path.getmtime(data_path + file) if not os.path.isdir(
-                data_path + file) else 0)
+        new_file_list = []
+        all_file_list = []
+        training_data = {}
+        iters = 0
+        while True:
+            new_file_list = list(set(os.listdir(data_path)).difference(all_file_list))
+            all_file_list = os.listdir(data_path)
+            new_file_list.sort(
+                key=lambda file: os.path.getmtime(data_path + file) if not os.path.isdir(data_path + file) else 0)
+            if new_file_list:
+                for file in new_file_list:
+                    states, probs, winner = self._file_to_training_data(data_path + file)
+                    assert states.shape[0] == probs.shape[0]
+                    assert states.shape[0] == winner.shape[0]
+                    self.training_data['states'].append(states)
+                    self.training_data['probs'].append(probs)
+                    self.training_data['winner'].append(winner)
+                training_data['states'] = np.concatenate(self.training_data['states'], axis=0)
+                training_data['probs'] = np.concatenate(self.training_data['probs'], axis=0)
+                training_data['winner'] = np.concatenate(self.training_data['winner'], axis=0)
+            if len(self.training_data['states']) != self.window_length:
+                continue
+            else:
+                data_num = training_data['states'].shape[0]
+                index = np.arange(data_num)
+                np.random.shuffle(index)
+                start_time = time.time()
+                value_loss, policy_loss, reg, _ = self.sess.run(
+                    [self.value_loss, self.policy_loss, self.reg, self.train_op],
+                    feed_dict={self.x: training_data['states'][index[:batch_size]],
+                               self.z: training_data['winner'][index[:batch_size]],
+                               self.pi: training_data['probs'][index[:batch_size]],
+                               self.is_training: True})
+                print("Iteration: {}, Time: {}, Value Loss: {}, Policy Loss: {}, Reg: {}".format(iters,
+                      time.time() - start_time,
+                      value_loss,
+                      policy_loss, reg))
+                iters += 1
+                if iters % self.save_freq == 0:
+                    save_path = "Iteration{}.ckpt".format(iters)
+                    self.saver.save(self.sess, self.checkpoint_path + save_path)
+
+    def _file_to_training_data(self, file_name):
+        with open(file_name, 'r') as file:
+            data = cPickle.load(file)
+        history = deque(maxlen=self.history_length)
+        states = []
+        probs = []
+        winner = []
+        for _ in range(self.history_length):
+            # Note that 0 is specified, need a more general way like config
+            history.append([0] * self.board_size ** 2)
+        # Still, +1 is specified
+        color = +1
+
+        for [board, prob] in zip(data.boards, data.probs):
+            history.append(board)
+            states.append(self._history2state(history, color))
+            probs.append(np.array(prob).reshape(1, self.board_size ** 2 + 1))
+            winner.append(np.array(data.winner).reshape(1, 1))
+            color *= -1
+        states = np.concatenate(states, axis=0)
+        probs = np.concatenate(probs, axis=0)
+        winner = np.concatenate(winner, axis=0)
+        return states, probs, winner
+
+
+if __name__ == "__main__":
+    model = ResNet(board_size=9, action_num=82)
+    model.train("file", data_path="./data/", batch_size=128, checkpoint_path="./checkpoint/")
\ No newline at end of file
diff --git a/AlphaGo/play.py b/AlphaGo/play.py
index 562dd14..bd3776e 100644
--- a/AlphaGo/play.py
+++ b/AlphaGo/play.py
@@ -76,6 +76,7 @@ if __name__ == '__main__':
 
     color = ['b', 'w']
     pattern = "[A-Z]{1}[0-9]{1}"
+    space = re.compile("\s+")
     size = 9
     show = ['.', 'X', 'O']
 
@@ -83,12 +84,20 @@
     game_num = 0
     try:
         while True:
+            start_time = time.time()
             num = 0
             pass_flag = [False, False]
             print("Start game {}".format(game_num))
             # end the game if both palyer chose to pass, or play too much turns
             while not (pass_flag[0] and pass_flag[1]) and num < size ** 2 * 2:
                 turn = num % 2
+                board = player[turn].run_cmd(str(num) + ' show_board')
+                board = eval(board[board.index('['):board.index(']') + 1])
+                for i in range(size):
+                    for j in range(size):
+                        print show[board[i * size + j]] + " ",
+                    print "\n",
+                data.boards.append(board)
                 move = player[turn].run_cmd(str(num) + ' genmove ' + color[turn] + '\n')
                 print role[turn] + " : " + str(move),
                 num += 1
@@ -102,21 +111,18 @@
                     play_or_pass = ' PASS'
                     pass_flag[turn] = True
                 result = player[1 - turn].run_cmd(str(num) + ' play ' + color[turn] + ' ' + play_or_pass + '\n')
-                board = player[turn].run_cmd(str(num) + ' show_board')
-                board = eval(board[board.index('['):board.index(']') + 1])
-                for i in range(size):
-                    for j in range(size):
-                        print show[board[i * size + j]] + " ",
-                    print "\n",
-                data.boards.append(board)
                 prob = player[turn].run_cmd(str(num) + ' get_prob')
+                prob = space.sub(',', prob[prob.index('['):prob.index(']') + 1])
+                prob = prob.replace('[,', '[')
+                prob = prob.replace('],', ']')
+                prob = eval(prob)
                 data.probs.append(prob)
             score = player[turn].run_cmd(str(num) + ' get_score')
             print "Finished : ", score.split(" ")[1]
             # TODO: generalize the player
-            if score > 0:
+            if eval(score.split(" ")[1]) > 0:
                 data.winner = 1
-            if score < 0:
+            if eval(score.split(" ")[1]) < 0:
                 data.winner = -1
             player[0].run_cmd(str(num) + ' clear_board')
             player[1].run_cmd(str(num) + ' clear_board')
@@ -127,12 +133,12 @@
             file_list.sort(key=lambda file: os.path.getmtime(args.result_path + file) if not os.path.isdir(
                 args.result_path + file) else 0)
             data_num = eval(file_list[-1][:-4]) + 1
-            print(file_list)
             with open("./data/" + str(data_num) + ".pkl", "w") as file:
                 picklestring = cPickle.dump(data, file)
             data.reset()
             game_num += 1
-    except KeyboardInterrupt:
+            print("Time {}".format(time.time() - start_time))
+    except Exception:
        subprocess.call(["kill", "-9", str(agent_v0.pid)])
        subprocess.call(["kill", "-9", str(agent_v1.pid)])
        print "Kill all player, finish all game."
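
For reference, a minimal sketch (not part of the patch) of the array shapes that the new
training path passes from _file_to_training_data into the feed_dict of _train_with_file,
assuming board_size=9 and history_length=1 as in the __main__ block above; the dummy
arrays, row count, and batch size below are illustrative only:

    import numpy as np

    batch_size = 128
    # shapes mirror the placeholders self.x, self.pi and self.z in AlphaGo/model.py:
    states = np.zeros([1000, 9, 9, 3], dtype=np.float32)  # [N, board, board, 2 * history_length + 1]
    probs = np.zeros([1000, 82], dtype=np.float32)         # [N, board_size ** 2 + 1] move probabilities
    winner = np.zeros([1000, 1], dtype=np.float32)         # [N, 1], +1 / -1 from Data.winner

    # _train_with_file samples one random minibatch per iteration from the window:
    index = np.arange(states.shape[0])
    np.random.shuffle(index)
    batch = index[:batch_size]
    x_batch, pi_batch, z_batch = states[batch], probs[batch], winner[batch]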