From eda7ed07a1b7b0251745981d71ab9f358f15944e Mon Sep 17 00:00:00 2001 From: rtz19970824 <1289226405@qq.com> Date: Thu, 21 Dec 2017 21:01:25 +0800 Subject: [PATCH] implement data collection and part of training --- AlphaGo/engine.py | 6 ++- AlphaGo/game.py | 19 +------- AlphaGo/model.py | 18 +++++++- AlphaGo/play.py | 115 ++++++++++++++++++++++++++++++---------------- AlphaGo/player.py | 1 + 5 files changed, 101 insertions(+), 58 deletions(-) diff --git a/AlphaGo/engine.py b/AlphaGo/engine.py index bf30083..c9f1a3c 100644 --- a/AlphaGo/engine.py +++ b/AlphaGo/engine.py @@ -183,11 +183,15 @@ class GTPEngine(): return 'unknown player', False def cmd_get_score(self, args, **kwargs): - return self._game.game_engine.executor_get_score(True), None + return self._game.game_engine.executor_get_score(True), True def cmd_show_board(self, args, **kwargs): return self._game.board, True + def cmd_get_prob(self, args, **kwargs): + return self._game.prob, True + + if __name__ == "main": game = Game() engine = GTPEngine(game_obj=Game) diff --git a/AlphaGo/game.py b/AlphaGo/game.py index 5f35c74..bf0d084 100644 --- a/AlphaGo/game.py +++ b/AlphaGo/game.py @@ -58,24 +58,9 @@ class Game: def set_komi(self, k): self.komi = k - def generate_nn_input(self, latest_boards, color): - state = np.zeros([1, self.size, self.size, 17]) - for i in range(8): - state[0, :, :, i] = np.array(np.array(latest_boards[i]) == np.ones(self.size ** 2)).reshape(self.size, self.size) - state[0, :, :, i + 8] = np.array(np.array(latest_boards[i]) == -np.ones(self.size ** 2)).reshape(self.size, self.size) - if color == utils.BLACK: - state[0, :, :, 16] = np.ones([self.size, self.size]) - if color == utils.WHITE: - state[0, :, :, 16] = np.zeros([self.size, self.size]) - return state - def think(self, latest_boards, color): - # TODO : using copy is right, or should we change to deepcopy? 
-        self.game_engine.simulate_latest_boards = copy.copy(latest_boards)
-        self.game_engine.simulate_board = copy.copy(latest_boards[-1])
-        nn_input = self.generate_nn_input(self.game_engine.simulate_latest_boards, color)
-        mcts = MCTS(self.game_engine, self.evaluator, [self.game_engine.simulate_latest_boards, color], self.size ** 2 + 1, inverse=True)
-        mcts.search(max_step=5)
+        mcts = MCTS(self.game_engine, self.evaluator, [latest_boards, color], self.size ** 2 + 1, inverse=True)
+        mcts.search(max_step=1)
         temp = 1
         prob = mcts.root.N ** temp / np.sum(mcts.root.N ** temp)
         choice = np.random.choice(self.size ** 2 + 1, 1, p=prob).tolist()[0]
diff --git a/AlphaGo/model.py b/AlphaGo/model.py
index 725dbd2..fab864e 100644
--- a/AlphaGo/model.py
+++ b/AlphaGo/model.py
@@ -1,6 +1,7 @@
 import os
 import time
 import sys
+import cPickle
 import numpy as np
 import tensorflow as tf
@@ -167,4 +168,19 @@ class ResNet(object):
     #TODO: design the interface between the environment and training
     def train(self, mode='memory', *args, **kwargs):
-        pass
\ No newline at end of file
+        if mode == 'memory':
+            pass
+        if mode == 'file':
+            self.train_with_file(data_path=kwargs['data_path'], checkpoint_path=kwargs['checkpoint_path'])
+
+    def train_with_file(self, data_path, checkpoint_path):
+        if not os.path.exists(data_path):
+            raise ValueError("{} doesn't exist".format(data_path))
+
+        file_list = os.listdir(data_path)
+        if len(file_list) <= 50:
+            time.sleep(1)
+        else:
+            file_list.sort(key=lambda file: os.path.getmtime(data_path + file) if not os.path.isdir(
+                data_path + file) else 0)
+
diff --git a/AlphaGo/play.py b/AlphaGo/play.py
index 7367804..562dd14 100644
--- a/AlphaGo/play.py
+++ b/AlphaGo/play.py
@@ -5,6 +5,18 @@ import re
 import Pyro4
 import time
 import os
+import cPickle
+
+
+class Data(object):
+    def __init__(self):
+        self.boards = []
+        self.probs = []
+        self.winner = 0
+
+    def reset(self):
+        self.__init__()
+

 if __name__ == '__main__':
     """
@@ -13,10 +25,13 @@
     """
     # TODO : we should set the network path in a more configurable way.
     parser = argparse.ArgumentParser()
+    parser.add_argument("--result_path", type=str, default="./data/")
     parser.add_argument("--black_weight_path", type=str, default=None)
     parser.add_argument("--white_weight_path", type=str, default=None)
     args = parser.parse_args()
+    if not os.path.exists(args.result_path):
+        os.mkdir(args.result_path)
     # black_weight_path = "./checkpoints"
     # white_weight_path = "./checkpoints_origin"
     if args.black_weight_path is not None and (not os.path.exists(args.black_weight_path)):
@@ -35,11 +50,13 @@
         time.sleep(1)
     # start two different player with different network weights.
-    agent_v0 = subprocess.Popen(['python', '-u', 'player.py', '--role=black', '--checkpoint_path=' + str(args.black_weight_path)],
-                                stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+    agent_v0 = subprocess.Popen(
+        ['python', '-u', 'player.py', '--role=black', '--checkpoint_path=' + str(args.black_weight_path)],
+        stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)

-    agent_v1 = subprocess.Popen(['python', '-u', 'player.py', '--role=white', '--checkpoint_path=' + str(args.white_weight_path)],
-                                stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+    agent_v1 = subprocess.Popen(
+        ['python', '-u', 'player.py', '--role=white', '--checkpoint_path=' + str(args.white_weight_path)],
+        stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)

     server_list = ""
     while ("black" not in server_list) or ("white" not in server_list):
@@ -50,6 +67,7 @@

     print "Start black player at : " + str(agent_v0.pid)
     print "Start white player at : " + str(agent_v1.pid)
+    data = Data()
     player = [None] * 2
     player[0] = Pyro4.Proxy("PYRONAME:black")
     player[1] = Pyro4.Proxy("PYRONAME:white")
@@ -63,39 +81,58 @@

     evaluate_rounds = 1
     game_num = 0
-    while game_num < evaluate_rounds:
-        num = 0
-        pass_flag = [False, False]
-        print("Start game {}".format(game_num))
-        # end the game if both palyer chose to pass, or play too much turns
-        while not (pass_flag[0] and pass_flag[1]) and num < size ** 2 * 2:
-            turn = num % 2
-            move = player[turn].run_cmd(str(num) + ' genmove ' + color[turn] + '\n')
-            print role[turn] + " : " + str(move),
-            num += 1
-            match = re.search(pattern, move)
-            if match is not None:
-                # print "match : " + str(match.group())
-                play_or_pass = match.group()
-                pass_flag[turn] = False
+    try:
+        while True:
+            num = 0
+            pass_flag = [False, False]
+            print("Start game {}".format(game_num))
+            # end the game if both players chose to pass, or if too many turns have been played
+            while not (pass_flag[0] and pass_flag[1]) and num < size ** 2 * 2:
+                turn = num % 2
+                move = player[turn].run_cmd(str(num) + ' genmove ' + color[turn] + '\n')
+                print role[turn] + " : " + str(move),
+                num += 1
+                match = re.search(pattern, move)
+                if match is not None:
+                    # print "match : " + str(match.group())
+                    play_or_pass = match.group()
+                    pass_flag[turn] = False
+                else:
+                    # print "no match"
+                    play_or_pass = ' PASS'
+                    pass_flag[turn] = True
+                result = player[1 - turn].run_cmd(str(num) + ' play ' + color[turn] + ' ' + play_or_pass + '\n')
+                board = player[turn].run_cmd(str(num) + ' show_board')
+                board = eval(board[board.index('['):board.index(']') + 1])
+                for i in range(size):
+                    for j in range(size):
+                        print show[board[i * size + j]] + " ",
+                    print "\n",
+                data.boards.append(board)
+                prob = player[turn].run_cmd(str(num) + ' get_prob')
+                data.probs.append(prob)
+            score = float(player[turn].run_cmd(str(num) + ' get_score').split(" ")[1])
+            print "Finished : ", score
+            # TODO: generalize the player
+            if score > 0:
+                data.winner = 1
+            if score < 0:
+                data.winner = -1
+            player[0].run_cmd(str(num) + ' clear_board')
+            player[1].run_cmd(str(num) + ' clear_board')
+            file_list = os.listdir(args.result_path)
+            if not file_list:
+                data_num = 0
             else:
-                # print "no match"
-                play_or_pass = ' PASS'
-                pass_flag[turn] = True
-            result = player[1 - turn].run_cmd(str(num) + ' play ' + color[turn] + ' ' + play_or_pass + '\n')
-            board = player[turn].run_cmd(str(num) + ' show_board')
-            board = eval(board[board.index('['):board.index(']') + 1])
-            for i in range(size):
-                for j in range(size):
-                    print show[board[i * size + j]] + " ",
-                print "\n",
-
-        score = player[turn].run_cmd(str(num) + ' get_score')
-        print "Finished : ", score.split(" ")[1]
-        player[0].run_cmd(str(num) + ' clear_board')
-        player[1].run_cmd(str(num) + ' clear_board')
-        game_num += 1
-
-    subprocess.call(["kill", "-9", str(agent_v0.pid)])
-    subprocess.call(["kill", "-9", str(agent_v1.pid)])
-    print "Kill all player, finish all game."
+                file_list.sort(key=lambda file: os.path.getmtime(args.result_path + file) if not os.path.isdir(
+                    args.result_path + file) else 0)
+                data_num = int(file_list[-1][:-4]) + 1
+            print(file_list)
+            with open(args.result_path + str(data_num) + ".pkl", "wb") as file:
+                cPickle.dump(data, file)
+            data.reset()
+            game_num += 1
+    except KeyboardInterrupt:
+        subprocess.call(["kill", "-9", str(agent_v0.pid)])
+        subprocess.call(["kill", "-9", str(agent_v1.pid)])
+        print "Kill all player, finish all game."
diff --git a/AlphaGo/player.py b/AlphaGo/player.py
index b468cf3..0e3daff 100644
--- a/AlphaGo/player.py
+++ b/AlphaGo/player.py
@@ -20,6 +20,7 @@ class Player(object):
         #return "inside the Player of player.py"
         return self.engine.run_cmd(command)
+

 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument("--checkpoint_path", type=str, default=None)