add multi-thread for end-to-end training

This commit is contained in:
rtz19970824 2018-01-13 15:57:41 +08:00
parent fcaa571b42
commit 2e8662889f
2 changed files with 81 additions and 73 deletions

View File

@ -119,13 +119,12 @@ class ResNet(object):
zip(self.black_var_list, self.white_var_list)] zip(self.black_var_list, self.white_var_list)]
# training hyper-parameters: # training hyper-parameters:
self.window_length = 900 self.window_length = 500
self.save_freq = 5000 self.save_freq = 5000
self.training_data = {'states': deque(maxlen=self.window_length), 'probs': deque(maxlen=self.window_length), self.training_data = {'states': deque(maxlen=self.window_length), 'probs': deque(maxlen=self.window_length),
'winner': deque(maxlen=self.window_length), 'length': deque(maxlen=self.window_length)} 'winner': deque(maxlen=self.window_length), 'length': deque(maxlen=self.window_length)}
# training or not self.use_latest = False
self.training = False
def _build_network(self, scope, residual_block_num): def _build_network(self, scope, residual_block_num):
""" """
@ -188,16 +187,16 @@ class ResNet(object):
feed_dict={self.x: eval_state, self.is_training: False}) feed_dict={self.x: eval_state, self.is_training: False})
def check_latest_model(self): def check_latest_model(self):
if self.training: if self.use_latest:
black_ckpt_file = tf.train.latest_checkpoint(self.save_path + "black/") black_ckpt_file = tf.train.latest_checkpoint(self.save_path + "black/")
if self.black_ckpt_file != black_ckpt_file: if self.black_ckpt_file != black_ckpt_file and black_ckpt_file is not None:
self.black_ckpt_file = black_ckpt_file self.black_ckpt_file = black_ckpt_file
print('Loading model from {}...'.format(self.black_ckpt_file)) print('Loading model from {}...'.format(self.black_ckpt_file))
self.black_saver.restore(self.sess, self.black_ckpt_file) self.black_saver.restore(self.sess, self.black_ckpt_file)
print('Black Model Updated!') print('Black Model Updated!')
white_ckpt_file = tf.train.latest_checkpoint(self.save_path + "white/") white_ckpt_file = tf.train.latest_checkpoint(self.save_path + "white/")
if self.white_ckpt_file != white_ckpt_file: if self.white_ckpt_file != white_ckpt_file and white_ckpt_file is not None:
self.white_ckpt_file = white_ckpt_file self.white_ckpt_file = white_ckpt_file
print('Loading model from {}...'.format(self.white_ckpt_file)) print('Loading model from {}...'.format(self.white_ckpt_file))
self.white_saver.restore(self.sess, self.white_ckpt_file) self.white_saver.restore(self.sess, self.white_ckpt_file)
@ -234,7 +233,7 @@ class ResNet(object):
:param target: a string, which to optimize, can only be "both", "black" and "white" :param target: a string, which to optimize, can only be "both", "black" and "white"
:param mode: a string, how to optimize, can only be "memory" and "file" :param mode: a string, how to optimize, can only be "memory" and "file"
""" """
self.training = True self.use_latest = True
if mode == 'memory': if mode == 'memory':
pass pass
if mode == 'file': if mode == 'file':
@ -401,5 +400,5 @@ class ResNet(object):
if __name__ == "__main__": if __name__ == "__main__":
model = ResNet(board_size=8, action_num=65, history_length=1, black_checkpoint_path="./checkpoint/black", white_checkpoint_path="./checkpoint/white") model = ResNet(board_size=9, action_num=82, history_length=8, black_checkpoint_path="./checkpoint/black", white_checkpoint_path="./checkpoint/white")
model.train(mode="file", data_path="./data/", batch_size=128, save_path="./checkpoint/") model.train(mode="file", data_path="./data/", batch_size=128, save_path="./go-v2/")

View File

@ -3,6 +3,7 @@ import sys
import re import re
import time import time
import os import os
import threading
from game import Game from game import Game
from engine import GTPEngine from engine import GTPEngine
from utils import Data from utils import Data
@ -17,6 +18,67 @@ else:
import _pickle as cPickle import _pickle as cPickle
def play(engine, data_path):
data = Data()
role = ["BLACK", "WHITE"]
color = ['b', 'w']
pattern = "[A-Z]{1}[0-9]{1}"
space = re.compile("\s+")
size = {"go": 9, "reversi": 8}
show = ['.', 'X', 'O']
# evaluate_rounds = 100
game_num = 0
while True:
# while game_num < evaluate_rounds:
engine._game.model.check_latest_model()
num = 0
pass_flag = [False, False]
print("Start game {}".format(game_num))
# end the game if both palyer chose to pass, or play too much turns
while not (pass_flag[0] and pass_flag[1]) and num < size[engine._game.name] ** 2 * 2:
turn = num % 2
board = engine.run_cmd(str(num) + ' show_board')
board = eval(board[board.index('['):board.index(']') + 1])
for i in range(size[engine._game.name]):
for j in range(size[engine._game.name]):
print show[board[i * size[engine._game.name] + j]] + " ",
print "\n",
data.boards.append(board)
move = engine.run_cmd(str(num) + ' genmove ' + color[turn])[:-1]
print("\n" + role[turn] + " : " + str(move)),
num += 1
match = re.search(pattern, move)
if match is not None:
# print "match : " + str(match.group())
pass_flag[turn] = False
else:
# print "no match"
pass_flag[turn] = True
prob = engine.run_cmd(str(num) + ' get_prob')
prob = space.sub(',', prob[prob.index('['):prob.index(']') + 1])
prob = prob.replace('[,', '[')
prob = prob.replace('],', ']')
prob = eval(prob)
data.probs.append(prob)
score = engine.run_cmd(str(num) + ' get_score')
print("Finished : {}".format(score.split(" ")[1]))
if eval(score.split(" ")[1]) > 0:
data.winner = utils.BLACK
if eval(score.split(" ")[1]) < 0:
data.winner = utils.WHITE
engine.run_cmd(str(num) + ' clear_board')
current_time = strftime("%Y%m%d_%H%M%S", gmtime())
if os.path.exists(data_path + current_time + ".pkl"):
time.sleep(1)
current_time = strftime("%Y%m%d_%H%M%S", gmtime())
with open(data_path + current_time + ".pkl", "wb") as file:
cPickle.dump(data, file)
data.reset()
game_num += 1
if __name__ == '__main__': if __name__ == '__main__':
""" """
Starting two different players which load network weights to evaluate the winning ratio. Starting two different players which load network weights to evaluate the winning ratio.
@ -27,6 +89,7 @@ if __name__ == '__main__':
parser.add_argument("--data_path", type=str, default="./data/") parser.add_argument("--data_path", type=str, default="./data/")
parser.add_argument("--black_weight_path", type=str, default=None) parser.add_argument("--black_weight_path", type=str, default=None)
parser.add_argument("--white_weight_path", type=str, default=None) parser.add_argument("--white_weight_path", type=str, default=None)
parser.add_argument("--save_path", type=str, default="./go/")
parser.add_argument("--debug", type=bool, default=False) parser.add_argument("--debug", type=bool, default=False)
parser.add_argument("--game", type=str, default="go") parser.add_argument("--game", type=str, default="go")
args = parser.parse_args() args = parser.parse_args()
@ -46,69 +109,15 @@ if __name__ == '__main__':
debug=args.debug) debug=args.debug)
engine = GTPEngine(game_obj=game, name='tianshou', version=0) engine = GTPEngine(game_obj=game, name='tianshou', version=0)
data = Data() thread_list = []
role = ["BLACK", "WHITE"] thread_train = threading.Thread(target=game.model.train, args=("file",),
color = ['b', 'w'] kwargs={'data_path':args.data_path, 'batch_size':128, 'save_path':args.save_path})
thread_play = threading.Thread(target=play, args=(engine, args.data_path))
thread_list.append(thread_train)
thread_list.append(thread_play)
pattern = "[A-Z]{1}[0-9]{1}" for t in thread_list:
space = re.compile("\s+") t.start()
size = {"go":9, "reversi":8}
show = ['.', 'X', 'O']
evaluate_rounds = 100 for t in thread_list:
game_num = 0 t.join()
try:
while True:
#while game_num < evaluate_rounds:
start_time = time.time()
game.model.check_latest_model()
num = 0
pass_flag = [False, False]
print("Start game {}".format(game_num))
# end the game if both players chose to pass, or after too many turns
while not (pass_flag[0] and pass_flag[1]) and num < size[args.game] ** 2 * 2:
turn = num % 2
board = engine.run_cmd(str(num) + ' show_board')
board = eval(board[board.index('['):board.index(']') + 1])
for i in range(size[args.game]):
for j in range(size[args.game]):
print show[board[i * size[args.game] + j]] + " ",
print "\n",
data.boards.append(board)
start_time = time.time()
move = engine.run_cmd(str(num) + ' genmove ' + color[turn])[:-1]
print("\n" + role[turn] + " : " + str(move)),
num += 1
match = re.search(pattern, move)
if match is not None:
# print "match : " + str(match.group())
play_or_pass = match.group()
pass_flag[turn] = False
else:
# print "no match"
play_or_pass = ' PASS'
pass_flag[turn] = True
prob = engine.run_cmd(str(num) + ' get_prob')
prob = space.sub(',', prob[prob.index('['):prob.index(']') + 1])
prob = prob.replace('[,', '[')
prob = prob.replace('],', ']')
prob = eval(prob)
data.probs.append(prob)
score = engine.run_cmd(str(num) + ' get_score')
print("Finished : {}".format(score.split(" ")[1]))
if eval(score.split(" ")[1]) > 0:
data.winner = utils.BLACK
if eval(score.split(" ")[1]) < 0:
data.winner = utils.WHITE
engine.run_cmd(str(num) + ' clear_board')
file_list = os.listdir(args.data_path)
current_time = strftime("%Y%m%d_%H%M%S", gmtime())
if os.path.exists(args.data_path + current_time + ".pkl"):
time.sleep(1)
current_time = strftime("%Y%m%d_%H%M%S", gmtime())
with open(args.data_path + current_time + ".pkl", "wb") as file:
picklestring = cPickle.dump(data, file)
data.reset()
game_num += 1
except KeyboardInterrupt:
pass