From 2e8662889f366521c14de0c339b64e8595e393e2 Mon Sep 17 00:00:00 2001 From: rtz19970824 Date: Sat, 13 Jan 2018 15:57:41 +0800 Subject: [PATCH] add multi-thread for end-to-end training --- AlphaGo/model.py | 17 +++--- AlphaGo/play.py | 137 +++++++++++++++++++++++++---------------------- 2 files changed, 81 insertions(+), 73 deletions(-) diff --git a/AlphaGo/model.py b/AlphaGo/model.py index 5037173..404da4a 100644 --- a/AlphaGo/model.py +++ b/AlphaGo/model.py @@ -119,13 +119,12 @@ class ResNet(object): zip(self.black_var_list, self.white_var_list)] # training hyper-parameters: - self.window_length = 900 + self.window_length = 500 self.save_freq = 5000 self.training_data = {'states': deque(maxlen=self.window_length), 'probs': deque(maxlen=self.window_length), 'winner': deque(maxlen=self.window_length), 'length': deque(maxlen=self.window_length)} - # training or not - self.training = False + self.use_latest = False def _build_network(self, scope, residual_block_num): """ @@ -188,16 +187,16 @@ class ResNet(object): feed_dict={self.x: eval_state, self.is_training: False}) def check_latest_model(self): - if self.training: + if self.use_latest: black_ckpt_file = tf.train.latest_checkpoint(self.save_path + "black/") - if self.black_ckpt_file != black_ckpt_file: + if self.black_ckpt_file != black_ckpt_file and black_ckpt_file is not None: self.black_ckpt_file = black_ckpt_file print('Loading model from {}...'.format(self.black_ckpt_file)) self.black_saver.restore(self.sess, self.black_ckpt_file) print('Black Model Updated!') white_ckpt_file = tf.train.latest_checkpoint(self.save_path + "white/") - if self.white_ckpt_file != white_ckpt_file: + if self.white_ckpt_file != white_ckpt_file and white_ckpt_file is not None: self.white_ckpt_file = white_ckpt_file print('Loading model from {}...'.format(self.white_ckpt_file)) self.white_saver.restore(self.sess, self.white_ckpt_file) @@ -234,7 +233,7 @@ class ResNet(object): :param target: a string, which to optimize, can only be "both", "black" and "white" :param mode: a string, how to optimize, can only be "memory" and "file" """ - self.training = True + self.use_latest = True if mode == 'memory': pass if mode == 'file': @@ -401,5 +400,5 @@ class ResNet(object): if __name__ == "__main__": - model = ResNet(board_size=8, action_num=65, history_length=1, black_checkpoint_path="./checkpoint/black", white_checkpoint_path="./checkpoint/white") - model.train(mode="file", data_path="./data/", batch_size=128, save_path="./checkpoint/") + model = ResNet(board_size=9, action_num=82, history_length=8, black_checkpoint_path="./checkpoint/black", white_checkpoint_path="./checkpoint/white") + model.train(mode="file", data_path="./data/", batch_size=128, save_path="./go-v2/") diff --git a/AlphaGo/play.py b/AlphaGo/play.py index a000b78..c677c04 100644 --- a/AlphaGo/play.py +++ b/AlphaGo/play.py @@ -3,6 +3,7 @@ import sys import re import time import os +import threading from game import Game from engine import GTPEngine from utils import Data @@ -17,6 +18,67 @@ else: import _pickle as cPickle +def play(engine, data_path): + data = Data() + role = ["BLACK", "WHITE"] + color = ['b', 'w'] + + pattern = "[A-Z]{1}[0-9]{1}" + space = re.compile("\s+") + size = {"go": 9, "reversi": 8} + show = ['.', 'X', 'O'] + + # evaluate_rounds = 100 + game_num = 0 + while True: + # while game_num < evaluate_rounds: + engine._game.model.check_latest_model() + num = 0 + pass_flag = [False, False] + print("Start game {}".format(game_num)) + # end the game if both palyer chose to pass, or play too much turns + while not (pass_flag[0] and pass_flag[1]) and num < size[engine._game.name] ** 2 * 2: + turn = num % 2 + board = engine.run_cmd(str(num) + ' show_board') + board = eval(board[board.index('['):board.index(']') + 1]) + for i in range(size[engine._game.name]): + for j in range(size[engine._game.name]): + print show[board[i * size[engine._game.name] + j]] + " ", + print "\n", + data.boards.append(board) + move = engine.run_cmd(str(num) + ' genmove ' + color[turn])[:-1] + print("\n" + role[turn] + " : " + str(move)), + num += 1 + match = re.search(pattern, move) + if match is not None: + # print "match : " + str(match.group()) + pass_flag[turn] = False + else: + # print "no match" + pass_flag[turn] = True + prob = engine.run_cmd(str(num) + ' get_prob') + prob = space.sub(',', prob[prob.index('['):prob.index(']') + 1]) + prob = prob.replace('[,', '[') + prob = prob.replace('],', ']') + prob = eval(prob) + data.probs.append(prob) + score = engine.run_cmd(str(num) + ' get_score') + print("Finished : {}".format(score.split(" ")[1])) + if eval(score.split(" ")[1]) > 0: + data.winner = utils.BLACK + if eval(score.split(" ")[1]) < 0: + data.winner = utils.WHITE + engine.run_cmd(str(num) + ' clear_board') + current_time = strftime("%Y%m%d_%H%M%S", gmtime()) + if os.path.exists(data_path + current_time + ".pkl"): + time.sleep(1) + current_time = strftime("%Y%m%d_%H%M%S", gmtime()) + with open(data_path + current_time + ".pkl", "wb") as file: + cPickle.dump(data, file) + data.reset() + game_num += 1 + + if __name__ == '__main__': """ Starting two different players which load network weights to evaluate the winning ratio. @@ -27,6 +89,7 @@ if __name__ == '__main__': parser.add_argument("--data_path", type=str, default="./data/") parser.add_argument("--black_weight_path", type=str, default=None) parser.add_argument("--white_weight_path", type=str, default=None) + parser.add_argument("--save_path", type=str, default="./go/") parser.add_argument("--debug", type=bool, default=False) parser.add_argument("--game", type=str, default="go") args = parser.parse_args() @@ -46,69 +109,15 @@ if __name__ == '__main__': debug=args.debug) engine = GTPEngine(game_obj=game, name='tianshou', version=0) - data = Data() - role = ["BLACK", "WHITE"] - color = ['b', 'w'] + thread_list = [] + thread_train = threading.Thread(target=game.model.train, args=("file",), + kwargs={'data_path':args.data_path, 'batch_size':128, 'save_path':args.save_path}) + thread_play = threading.Thread(target=play, args=(engine, args.data_path)) + thread_list.append(thread_train) + thread_list.append(thread_play) - pattern = "[A-Z]{1}[0-9]{1}" - space = re.compile("\s+") - size = {"go":9, "reversi":8} - show = ['.', 'X', 'O'] + for t in thread_list: + t.start() - evaluate_rounds = 100 - game_num = 0 - try: - while True: - #while game_num < evaluate_rounds: - start_time = time.time() - game.model.check_latest_model() - num = 0 - pass_flag = [False, False] - print("Start game {}".format(game_num)) - # end the game if both palyer chose to pass, or play too much turns - while not (pass_flag[0] and pass_flag[1]) and num < size[args.game] ** 2 * 2: - turn = num % 2 - board = engine.run_cmd(str(num) + ' show_board') - board = eval(board[board.index('['):board.index(']') + 1]) - for i in range(size[args.game]): - for j in range(size[args.game]): - print show[board[i * size[args.game] + j]] + " ", - print "\n", - data.boards.append(board) - start_time = time.time() - move = engine.run_cmd(str(num) + ' genmove ' + color[turn])[:-1] - print("\n" + role[turn] + " : " + str(move)), - num += 1 - match = re.search(pattern, move) - if match is not None: - # print "match : " + str(match.group()) - play_or_pass = match.group() - pass_flag[turn] = False - else: - # print "no match" - play_or_pass = ' PASS' - pass_flag[turn] = True - prob = engine.run_cmd(str(num) + ' get_prob') - prob = space.sub(',', prob[prob.index('['):prob.index(']') + 1]) - prob = prob.replace('[,', '[') - prob = prob.replace('],', ']') - prob = eval(prob) - data.probs.append(prob) - score = engine.run_cmd(str(num) + ' get_score') - print("Finished : {}".format(score.split(" ")[1])) - if eval(score.split(" ")[1]) > 0: - data.winner = utils.BLACK - if eval(score.split(" ")[1]) < 0: - data.winner = utils.WHITE - engine.run_cmd(str(num) + ' clear_board') - file_list = os.listdir(args.data_path) - current_time = strftime("%Y%m%d_%H%M%S", gmtime()) - if os.path.exists(args.data_path + current_time + ".pkl"): - time.sleep(1) - current_time = strftime("%Y%m%d_%H%M%S", gmtime()) - with open(args.data_path + current_time + ".pkl", "wb") as file: - picklestring = cPickle.dump(data, file) - data.reset() - game_num += 1 - except KeyboardInterrupt: - pass + for t in thread_list: + t.join()