implement data collection and part of training
This commit is contained in:
parent
1e2567c174
commit
2dad8e4020
@ -183,11 +183,15 @@ class GTPEngine():
|
|||||||
return 'unknown player', False
|
return 'unknown player', False
|
||||||
|
|
||||||
def cmd_get_score(self, args, **kwargs):
|
def cmd_get_score(self, args, **kwargs):
|
||||||
return self._game.game_engine.executor_get_score(True), None
|
return self._game.game_engine.executor_get_score(True), True
|
||||||
|
|
||||||
def cmd_show_board(self, args, **kwargs):
|
def cmd_show_board(self, args, **kwargs):
|
||||||
return self._game.board, True
|
return self._game.board, True
|
||||||
|
|
||||||
|
def cmd_get_prob(self, args, **kwargs):
|
||||||
|
return self._game.prob, True
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "main":
|
if __name__ == "main":
|
||||||
game = Game()
|
game = Game()
|
||||||
engine = GTPEngine(game_obj=Game)
|
engine = GTPEngine(game_obj=Game)
|
||||||
|
@ -58,24 +58,9 @@ class Game:
|
|||||||
def set_komi(self, k):
|
def set_komi(self, k):
|
||||||
self.komi = k
|
self.komi = k
|
||||||
|
|
||||||
def generate_nn_input(self, latest_boards, color):
|
|
||||||
state = np.zeros([1, self.size, self.size, 17])
|
|
||||||
for i in range(8):
|
|
||||||
state[0, :, :, i] = np.array(np.array(latest_boards[i]) == np.ones(self.size ** 2)).reshape(self.size, self.size)
|
|
||||||
state[0, :, :, i + 8] = np.array(np.array(latest_boards[i]) == -np.ones(self.size ** 2)).reshape(self.size, self.size)
|
|
||||||
if color == utils.BLACK:
|
|
||||||
state[0, :, :, 16] = np.ones([self.size, self.size])
|
|
||||||
if color == utils.WHITE:
|
|
||||||
state[0, :, :, 16] = np.zeros([self.size, self.size])
|
|
||||||
return state
|
|
||||||
|
|
||||||
def think(self, latest_boards, color):
|
def think(self, latest_boards, color):
|
||||||
# TODO : using copy is right, or should we change to deepcopy?
|
mcts = MCTS(self.game_engine, self.evaluator, [latest_boards, color], self.size ** 2 + 1, inverse=True)
|
||||||
self.game_engine.simulate_latest_boards = copy.copy(latest_boards)
|
mcts.search(max_step=1)
|
||||||
self.game_engine.simulate_board = copy.copy(latest_boards[-1])
|
|
||||||
nn_input = self.generate_nn_input(self.game_engine.simulate_latest_boards, color)
|
|
||||||
mcts = MCTS(self.game_engine, self.evaluator, [self.game_engine.simulate_latest_boards, color], self.size ** 2 + 1, inverse=True)
|
|
||||||
mcts.search(max_step=5)
|
|
||||||
temp = 1
|
temp = 1
|
||||||
prob = mcts.root.N ** temp / np.sum(mcts.root.N ** temp)
|
prob = mcts.root.N ** temp / np.sum(mcts.root.N ** temp)
|
||||||
choice = np.random.choice(self.size ** 2 + 1, 1, p=prob).tolist()[0]
|
choice = np.random.choice(self.size ** 2 + 1, 1, p=prob).tolist()[0]
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import sys
|
import sys
|
||||||
|
import cPickle
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
@ -167,4 +168,19 @@ class ResNet(object):
|
|||||||
|
|
||||||
#TODO: design the interface between the environment and training
|
#TODO: design the interface between the environment and training
|
||||||
def train(self, mode='memory', *args, **kwargs):
|
def train(self, mode='memory', *args, **kwargs):
|
||||||
|
if mode == 'memory':
|
||||||
pass
|
pass
|
||||||
|
if mode == 'file':
|
||||||
|
self.train_with_file(data_path=kwargs['data_path'], checkpoint_path=kwargs['checkpoint_path'])
|
||||||
|
|
||||||
|
def train_with_file(self, data_path, checkpoint_path):
|
||||||
|
if not os.path.exists(data_path):
|
||||||
|
raise ValueError("{} doesn't exist".format(data_path))
|
||||||
|
|
||||||
|
file_list = os.listdir(data_path)
|
||||||
|
if file_list <= 50:
|
||||||
|
time.sleep(1)
|
||||||
|
else:
|
||||||
|
file_list.sort(key=lambda file: os.path.getmtime(data_path + file) if not os.path.isdir(
|
||||||
|
data_path + file) else 0)
|
||||||
|
|
||||||
|
@ -5,6 +5,18 @@ import re
|
|||||||
import Pyro4
|
import Pyro4
|
||||||
import time
|
import time
|
||||||
import os
|
import os
|
||||||
|
import cPickle
|
||||||
|
|
||||||
|
|
||||||
|
class Data(object):
|
||||||
|
def __init__(self):
|
||||||
|
self.boards = []
|
||||||
|
self.probs = []
|
||||||
|
self.winner = 0
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.__init__()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
"""
|
"""
|
||||||
@ -13,10 +25,13 @@ if __name__ == '__main__':
|
|||||||
"""
|
"""
|
||||||
# TODO : we should set the network path in a more configurable way.
|
# TODO : we should set the network path in a more configurable way.
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--result_path", type=str, default="./data/")
|
||||||
parser.add_argument("--black_weight_path", type=str, default=None)
|
parser.add_argument("--black_weight_path", type=str, default=None)
|
||||||
parser.add_argument("--white_weight_path", type=str, default=None)
|
parser.add_argument("--white_weight_path", type=str, default=None)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if not os.path.exists(args.result_path):
|
||||||
|
os.mkdir(args.result_path)
|
||||||
# black_weight_path = "./checkpoints"
|
# black_weight_path = "./checkpoints"
|
||||||
# white_weight_path = "./checkpoints_origin"
|
# white_weight_path = "./checkpoints_origin"
|
||||||
if args.black_weight_path is not None and (not os.path.exists(args.black_weight_path)):
|
if args.black_weight_path is not None and (not os.path.exists(args.black_weight_path)):
|
||||||
@ -35,10 +50,12 @@ if __name__ == '__main__':
|
|||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
|
|
||||||
# start two different player with different network weights.
|
# start two different player with different network weights.
|
||||||
agent_v0 = subprocess.Popen(['python', '-u', 'player.py', '--role=black', '--checkpoint_path=' + str(args.black_weight_path)],
|
agent_v0 = subprocess.Popen(
|
||||||
|
['python', '-u', 'player.py', '--role=black', '--checkpoint_path=' + str(args.black_weight_path)],
|
||||||
stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
|
stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
|
||||||
|
|
||||||
agent_v1 = subprocess.Popen(['python', '-u', 'player.py', '--role=white', '--checkpoint_path=' + str(args.white_weight_path)],
|
agent_v1 = subprocess.Popen(
|
||||||
|
['python', '-u', 'player.py', '--role=white', '--checkpoint_path=' + str(args.white_weight_path)],
|
||||||
stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
|
stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
|
||||||
|
|
||||||
server_list = ""
|
server_list = ""
|
||||||
@ -50,6 +67,7 @@ if __name__ == '__main__':
|
|||||||
print "Start black player at : " + str(agent_v0.pid)
|
print "Start black player at : " + str(agent_v0.pid)
|
||||||
print "Start white player at : " + str(agent_v1.pid)
|
print "Start white player at : " + str(agent_v1.pid)
|
||||||
|
|
||||||
|
data = Data()
|
||||||
player = [None] * 2
|
player = [None] * 2
|
||||||
player[0] = Pyro4.Proxy("PYRONAME:black")
|
player[0] = Pyro4.Proxy("PYRONAME:black")
|
||||||
player[1] = Pyro4.Proxy("PYRONAME:white")
|
player[1] = Pyro4.Proxy("PYRONAME:white")
|
||||||
@ -63,7 +81,8 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
evaluate_rounds = 1
|
evaluate_rounds = 1
|
||||||
game_num = 0
|
game_num = 0
|
||||||
while game_num < evaluate_rounds:
|
try:
|
||||||
|
while True:
|
||||||
num = 0
|
num = 0
|
||||||
pass_flag = [False, False]
|
pass_flag = [False, False]
|
||||||
print("Start game {}".format(game_num))
|
print("Start game {}".format(game_num))
|
||||||
@ -89,13 +108,31 @@ if __name__ == '__main__':
|
|||||||
for j in range(size):
|
for j in range(size):
|
||||||
print show[board[i * size + j]] + " ",
|
print show[board[i * size + j]] + " ",
|
||||||
print "\n",
|
print "\n",
|
||||||
|
data.boards.append(board)
|
||||||
|
prob = player[turn].run_cmd(str(num) + ' get_prob')
|
||||||
|
data.probs.append(prob)
|
||||||
score = player[turn].run_cmd(str(num) + ' get_score')
|
score = player[turn].run_cmd(str(num) + ' get_score')
|
||||||
print "Finished : ", score.split(" ")[1]
|
print "Finished : ", score.split(" ")[1]
|
||||||
|
# TODO: generalize the player
|
||||||
|
if score > 0:
|
||||||
|
data.winner = 1
|
||||||
|
if score < 0:
|
||||||
|
data.winner = -1
|
||||||
player[0].run_cmd(str(num) + ' clear_board')
|
player[0].run_cmd(str(num) + ' clear_board')
|
||||||
player[1].run_cmd(str(num) + ' clear_board')
|
player[1].run_cmd(str(num) + ' clear_board')
|
||||||
|
file_list = os.listdir(args.result_path)
|
||||||
|
if not file_list:
|
||||||
|
data_num = 0
|
||||||
|
else:
|
||||||
|
file_list.sort(key=lambda file: os.path.getmtime(args.result_path + file) if not os.path.isdir(
|
||||||
|
args.result_path + file) else 0)
|
||||||
|
data_num = eval(file_list[-1][:-4]) + 1
|
||||||
|
print(file_list)
|
||||||
|
with open("./data/" + str(data_num) + ".pkl", "w") as file:
|
||||||
|
picklestring = cPickle.dump(data, file)
|
||||||
|
data.reset()
|
||||||
game_num += 1
|
game_num += 1
|
||||||
|
except KeyboardInterrupt:
|
||||||
subprocess.call(["kill", "-9", str(agent_v0.pid)])
|
subprocess.call(["kill", "-9", str(agent_v0.pid)])
|
||||||
subprocess.call(["kill", "-9", str(agent_v1.pid)])
|
subprocess.call(["kill", "-9", str(agent_v1.pid)])
|
||||||
print "Kill all player, finish all game."
|
print "Kill all player, finish all game."
|
||||||
|
@ -20,6 +20,7 @@ class Player(object):
|
|||||||
#return "inside the Player of player.py"
|
#return "inside the Player of player.py"
|
||||||
return self.engine.run_cmd(command)
|
return self.engine.run_cmd(command)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--checkpoint_path", type=str, default=None)
|
parser.add_argument("--checkpoint_path", type=str, default=None)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user