implement data collection and part of training

This commit is contained in:
rtz19970824 2017-12-21 21:01:25 +08:00
parent ced63af18f
commit eda7ed07a1
5 changed files with 101 additions and 58 deletions

View File

@ -183,11 +183,15 @@ class GTPEngine():
return 'unknown player', False return 'unknown player', False
def cmd_get_score(self, args, **kwargs): def cmd_get_score(self, args, **kwargs):
return self._game.game_engine.executor_get_score(True), None return self._game.game_engine.executor_get_score(True), True
def cmd_show_board(self, args, **kwargs): def cmd_show_board(self, args, **kwargs):
return self._game.board, True return self._game.board, True
def cmd_get_prob(self, args, **kwargs):
return self._game.prob, True
if __name__ == "main": if __name__ == "main":
game = Game() game = Game()
engine = GTPEngine(game_obj=Game) engine = GTPEngine(game_obj=Game)

View File

@ -58,24 +58,9 @@ class Game:
def set_komi(self, k): def set_komi(self, k):
self.komi = k self.komi = k
def generate_nn_input(self, latest_boards, color):
state = np.zeros([1, self.size, self.size, 17])
for i in range(8):
state[0, :, :, i] = np.array(np.array(latest_boards[i]) == np.ones(self.size ** 2)).reshape(self.size, self.size)
state[0, :, :, i + 8] = np.array(np.array(latest_boards[i]) == -np.ones(self.size ** 2)).reshape(self.size, self.size)
if color == utils.BLACK:
state[0, :, :, 16] = np.ones([self.size, self.size])
if color == utils.WHITE:
state[0, :, :, 16] = np.zeros([self.size, self.size])
return state
def think(self, latest_boards, color): def think(self, latest_boards, color):
# TODO : using copy is right, or should we change to deepcopy? mcts = MCTS(self.game_engine, self.evaluator, [latest_boards, color], self.size ** 2 + 1, inverse=True)
self.game_engine.simulate_latest_boards = copy.copy(latest_boards) mcts.search(max_step=1)
self.game_engine.simulate_board = copy.copy(latest_boards[-1])
nn_input = self.generate_nn_input(self.game_engine.simulate_latest_boards, color)
mcts = MCTS(self.game_engine, self.evaluator, [self.game_engine.simulate_latest_boards, color], self.size ** 2 + 1, inverse=True)
mcts.search(max_step=5)
temp = 1 temp = 1
prob = mcts.root.N ** temp / np.sum(mcts.root.N ** temp) prob = mcts.root.N ** temp / np.sum(mcts.root.N ** temp)
choice = np.random.choice(self.size ** 2 + 1, 1, p=prob).tolist()[0] choice = np.random.choice(self.size ** 2 + 1, 1, p=prob).tolist()[0]

View File

@ -1,6 +1,7 @@
import os import os
import time import time
import sys import sys
import cPickle
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
@ -167,4 +168,19 @@ class ResNet(object):
#TODO: design the interface between the environment and training #TODO: design the interface between the environment and training
def train(self, mode='memory', *args, **kwargs): def train(self, mode='memory', *args, **kwargs):
pass if mode == 'memory':
pass
if mode == 'file':
self.train_with_file(data_path=kwargs['data_path'], checkpoint_path=kwargs['checkpoint_path'])
def train_with_file(self, data_path, checkpoint_path):
if not os.path.exists(data_path):
raise ValueError("{} doesn't exist".format(data_path))
file_list = os.listdir(data_path)
if file_list <= 50:
time.sleep(1)
else:
file_list.sort(key=lambda file: os.path.getmtime(data_path + file) if not os.path.isdir(
data_path + file) else 0)

View File

@ -5,6 +5,18 @@ import re
import Pyro4 import Pyro4
import time import time
import os import os
import cPickle
class Data(object):
def __init__(self):
self.boards = []
self.probs = []
self.winner = 0
def reset(self):
self.__init__()
if __name__ == '__main__': if __name__ == '__main__':
""" """
@ -13,10 +25,13 @@ if __name__ == '__main__':
""" """
# TODO : we should set the network path in a more configurable way. # TODO : we should set the network path in a more configurable way.
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--result_path", type=str, default="./data/")
parser.add_argument("--black_weight_path", type=str, default=None) parser.add_argument("--black_weight_path", type=str, default=None)
parser.add_argument("--white_weight_path", type=str, default=None) parser.add_argument("--white_weight_path", type=str, default=None)
args = parser.parse_args() args = parser.parse_args()
if not os.path.exists(args.result_path):
os.mkdir(args.result_path)
# black_weight_path = "./checkpoints" # black_weight_path = "./checkpoints"
# white_weight_path = "./checkpoints_origin" # white_weight_path = "./checkpoints_origin"
if args.black_weight_path is not None and (not os.path.exists(args.black_weight_path)): if args.black_weight_path is not None and (not os.path.exists(args.black_weight_path)):
@ -35,11 +50,13 @@ if __name__ == '__main__':
time.sleep(1) time.sleep(1)
# start two different player with different network weights. # start two different player with different network weights.
agent_v0 = subprocess.Popen(['python', '-u', 'player.py', '--role=black', '--checkpoint_path=' + str(args.black_weight_path)], agent_v0 = subprocess.Popen(
stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) ['python', '-u', 'player.py', '--role=black', '--checkpoint_path=' + str(args.black_weight_path)],
stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
agent_v1 = subprocess.Popen(['python', '-u', 'player.py', '--role=white', '--checkpoint_path=' + str(args.white_weight_path)], agent_v1 = subprocess.Popen(
stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) ['python', '-u', 'player.py', '--role=white', '--checkpoint_path=' + str(args.white_weight_path)],
stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
server_list = "" server_list = ""
while ("black" not in server_list) or ("white" not in server_list): while ("black" not in server_list) or ("white" not in server_list):
@ -50,6 +67,7 @@ if __name__ == '__main__':
print "Start black player at : " + str(agent_v0.pid) print "Start black player at : " + str(agent_v0.pid)
print "Start white player at : " + str(agent_v1.pid) print "Start white player at : " + str(agent_v1.pid)
data = Data()
player = [None] * 2 player = [None] * 2
player[0] = Pyro4.Proxy("PYRONAME:black") player[0] = Pyro4.Proxy("PYRONAME:black")
player[1] = Pyro4.Proxy("PYRONAME:white") player[1] = Pyro4.Proxy("PYRONAME:white")
@ -63,39 +81,58 @@ if __name__ == '__main__':
evaluate_rounds = 1 evaluate_rounds = 1
game_num = 0 game_num = 0
while game_num < evaluate_rounds: try:
num = 0 while True:
pass_flag = [False, False] num = 0
print("Start game {}".format(game_num)) pass_flag = [False, False]
# end the game if both palyer chose to pass, or play too much turns print("Start game {}".format(game_num))
while not (pass_flag[0] and pass_flag[1]) and num < size ** 2 * 2: # end the game if both palyer chose to pass, or play too much turns
turn = num % 2 while not (pass_flag[0] and pass_flag[1]) and num < size ** 2 * 2:
move = player[turn].run_cmd(str(num) + ' genmove ' + color[turn] + '\n') turn = num % 2
print role[turn] + " : " + str(move), move = player[turn].run_cmd(str(num) + ' genmove ' + color[turn] + '\n')
num += 1 print role[turn] + " : " + str(move),
match = re.search(pattern, move) num += 1
if match is not None: match = re.search(pattern, move)
# print "match : " + str(match.group()) if match is not None:
play_or_pass = match.group() # print "match : " + str(match.group())
pass_flag[turn] = False play_or_pass = match.group()
pass_flag[turn] = False
else:
# print "no match"
play_or_pass = ' PASS'
pass_flag[turn] = True
result = player[1 - turn].run_cmd(str(num) + ' play ' + color[turn] + ' ' + play_or_pass + '\n')
board = player[turn].run_cmd(str(num) + ' show_board')
board = eval(board[board.index('['):board.index(']') + 1])
for i in range(size):
for j in range(size):
print show[board[i * size + j]] + " ",
print "\n",
data.boards.append(board)
prob = player[turn].run_cmd(str(num) + ' get_prob')
data.probs.append(prob)
score = player[turn].run_cmd(str(num) + ' get_score')
print "Finished : ", score.split(" ")[1]
# TODO: generalize the player
if score > 0:
data.winner = 1
if score < 0:
data.winner = -1
player[0].run_cmd(str(num) + ' clear_board')
player[1].run_cmd(str(num) + ' clear_board')
file_list = os.listdir(args.result_path)
if not file_list:
data_num = 0
else: else:
# print "no match" file_list.sort(key=lambda file: os.path.getmtime(args.result_path + file) if not os.path.isdir(
play_or_pass = ' PASS' args.result_path + file) else 0)
pass_flag[turn] = True data_num = eval(file_list[-1][:-4]) + 1
result = player[1 - turn].run_cmd(str(num) + ' play ' + color[turn] + ' ' + play_or_pass + '\n') print(file_list)
board = player[turn].run_cmd(str(num) + ' show_board') with open("./data/" + str(data_num) + ".pkl", "w") as file:
board = eval(board[board.index('['):board.index(']') + 1]) picklestring = cPickle.dump(data, file)
for i in range(size): data.reset()
for j in range(size): game_num += 1
print show[board[i * size + j]] + " ", except KeyboardInterrupt:
print "\n", subprocess.call(["kill", "-9", str(agent_v0.pid)])
subprocess.call(["kill", "-9", str(agent_v1.pid)])
score = player[turn].run_cmd(str(num) + ' get_score') print "Kill all player, finish all game."
print "Finished : ", score.split(" ")[1]
player[0].run_cmd(str(num) + ' clear_board')
player[1].run_cmd(str(num) + ' clear_board')
game_num += 1
subprocess.call(["kill", "-9", str(agent_v0.pid)])
subprocess.call(["kill", "-9", str(agent_v1.pid)])
print "Kill all player, finish all game."

View File

@ -20,6 +20,7 @@ class Player(object):
#return "inside the Player of player.py" #return "inside the Player of player.py"
return self.engine.run_cmd(command) return self.engine.run_cmd(command)
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint_path", type=str, default=None) parser.add_argument("--checkpoint_path", type=str, default=None)