add multi-thread for end-to-end training
This commit is contained in:
parent
fcaa571b42
commit
2e8662889f
@ -119,13 +119,12 @@ class ResNet(object):
|
|||||||
zip(self.black_var_list, self.white_var_list)]
|
zip(self.black_var_list, self.white_var_list)]
|
||||||
|
|
||||||
# training hyper-parameters:
|
# training hyper-parameters:
|
||||||
self.window_length = 900
|
self.window_length = 500
|
||||||
self.save_freq = 5000
|
self.save_freq = 5000
|
||||||
self.training_data = {'states': deque(maxlen=self.window_length), 'probs': deque(maxlen=self.window_length),
|
self.training_data = {'states': deque(maxlen=self.window_length), 'probs': deque(maxlen=self.window_length),
|
||||||
'winner': deque(maxlen=self.window_length), 'length': deque(maxlen=self.window_length)}
|
'winner': deque(maxlen=self.window_length), 'length': deque(maxlen=self.window_length)}
|
||||||
|
|
||||||
# training or not
|
self.use_latest = False
|
||||||
self.training = False
|
|
||||||
|
|
||||||
def _build_network(self, scope, residual_block_num):
|
def _build_network(self, scope, residual_block_num):
|
||||||
"""
|
"""
|
||||||
@ -188,16 +187,16 @@ class ResNet(object):
|
|||||||
feed_dict={self.x: eval_state, self.is_training: False})
|
feed_dict={self.x: eval_state, self.is_training: False})
|
||||||
|
|
||||||
def check_latest_model(self):
|
def check_latest_model(self):
|
||||||
if self.training:
|
if self.use_latest:
|
||||||
black_ckpt_file = tf.train.latest_checkpoint(self.save_path + "black/")
|
black_ckpt_file = tf.train.latest_checkpoint(self.save_path + "black/")
|
||||||
if self.black_ckpt_file != black_ckpt_file:
|
if self.black_ckpt_file != black_ckpt_file and black_ckpt_file is not None:
|
||||||
self.black_ckpt_file = black_ckpt_file
|
self.black_ckpt_file = black_ckpt_file
|
||||||
print('Loading model from {}...'.format(self.black_ckpt_file))
|
print('Loading model from {}...'.format(self.black_ckpt_file))
|
||||||
self.black_saver.restore(self.sess, self.black_ckpt_file)
|
self.black_saver.restore(self.sess, self.black_ckpt_file)
|
||||||
print('Black Model Updated!')
|
print('Black Model Updated!')
|
||||||
|
|
||||||
white_ckpt_file = tf.train.latest_checkpoint(self.save_path + "white/")
|
white_ckpt_file = tf.train.latest_checkpoint(self.save_path + "white/")
|
||||||
if self.white_ckpt_file != white_ckpt_file:
|
if self.white_ckpt_file != white_ckpt_file and white_ckpt_file is not None:
|
||||||
self.white_ckpt_file = white_ckpt_file
|
self.white_ckpt_file = white_ckpt_file
|
||||||
print('Loading model from {}...'.format(self.white_ckpt_file))
|
print('Loading model from {}...'.format(self.white_ckpt_file))
|
||||||
self.white_saver.restore(self.sess, self.white_ckpt_file)
|
self.white_saver.restore(self.sess, self.white_ckpt_file)
|
||||||
@ -234,7 +233,7 @@ class ResNet(object):
|
|||||||
:param target: a string, which to optimize, can only be "both", "black" and "white"
|
:param target: a string, which to optimize, can only be "both", "black" and "white"
|
||||||
:param mode: a string, how to optimize, can only be "memory" and "file"
|
:param mode: a string, how to optimize, can only be "memory" and "file"
|
||||||
"""
|
"""
|
||||||
self.training = True
|
self.use_latest = True
|
||||||
if mode == 'memory':
|
if mode == 'memory':
|
||||||
pass
|
pass
|
||||||
if mode == 'file':
|
if mode == 'file':
|
||||||
@ -401,5 +400,5 @@ class ResNet(object):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
model = ResNet(board_size=8, action_num=65, history_length=1, black_checkpoint_path="./checkpoint/black", white_checkpoint_path="./checkpoint/white")
|
model = ResNet(board_size=9, action_num=82, history_length=8, black_checkpoint_path="./checkpoint/black", white_checkpoint_path="./checkpoint/white")
|
||||||
model.train(mode="file", data_path="./data/", batch_size=128, save_path="./checkpoint/")
|
model.train(mode="file", data_path="./data/", batch_size=128, save_path="./go-v2/")
|
||||||
|
137
AlphaGo/play.py
137
AlphaGo/play.py
@ -3,6 +3,7 @@ import sys
|
|||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
import os
|
import os
|
||||||
|
import threading
|
||||||
from game import Game
|
from game import Game
|
||||||
from engine import GTPEngine
|
from engine import GTPEngine
|
||||||
from utils import Data
|
from utils import Data
|
||||||
@ -17,6 +18,67 @@ else:
|
|||||||
import _pickle as cPickle
|
import _pickle as cPickle
|
||||||
|
|
||||||
|
|
||||||
|
def play(engine, data_path):
|
||||||
|
data = Data()
|
||||||
|
role = ["BLACK", "WHITE"]
|
||||||
|
color = ['b', 'w']
|
||||||
|
|
||||||
|
pattern = "[A-Z]{1}[0-9]{1}"
|
||||||
|
space = re.compile("\s+")
|
||||||
|
size = {"go": 9, "reversi": 8}
|
||||||
|
show = ['.', 'X', 'O']
|
||||||
|
|
||||||
|
# evaluate_rounds = 100
|
||||||
|
game_num = 0
|
||||||
|
while True:
|
||||||
|
# while game_num < evaluate_rounds:
|
||||||
|
engine._game.model.check_latest_model()
|
||||||
|
num = 0
|
||||||
|
pass_flag = [False, False]
|
||||||
|
print("Start game {}".format(game_num))
|
||||||
|
# end the game if both palyer chose to pass, or play too much turns
|
||||||
|
while not (pass_flag[0] and pass_flag[1]) and num < size[engine._game.name] ** 2 * 2:
|
||||||
|
turn = num % 2
|
||||||
|
board = engine.run_cmd(str(num) + ' show_board')
|
||||||
|
board = eval(board[board.index('['):board.index(']') + 1])
|
||||||
|
for i in range(size[engine._game.name]):
|
||||||
|
for j in range(size[engine._game.name]):
|
||||||
|
print show[board[i * size[engine._game.name] + j]] + " ",
|
||||||
|
print "\n",
|
||||||
|
data.boards.append(board)
|
||||||
|
move = engine.run_cmd(str(num) + ' genmove ' + color[turn])[:-1]
|
||||||
|
print("\n" + role[turn] + " : " + str(move)),
|
||||||
|
num += 1
|
||||||
|
match = re.search(pattern, move)
|
||||||
|
if match is not None:
|
||||||
|
# print "match : " + str(match.group())
|
||||||
|
pass_flag[turn] = False
|
||||||
|
else:
|
||||||
|
# print "no match"
|
||||||
|
pass_flag[turn] = True
|
||||||
|
prob = engine.run_cmd(str(num) + ' get_prob')
|
||||||
|
prob = space.sub(',', prob[prob.index('['):prob.index(']') + 1])
|
||||||
|
prob = prob.replace('[,', '[')
|
||||||
|
prob = prob.replace('],', ']')
|
||||||
|
prob = eval(prob)
|
||||||
|
data.probs.append(prob)
|
||||||
|
score = engine.run_cmd(str(num) + ' get_score')
|
||||||
|
print("Finished : {}".format(score.split(" ")[1]))
|
||||||
|
if eval(score.split(" ")[1]) > 0:
|
||||||
|
data.winner = utils.BLACK
|
||||||
|
if eval(score.split(" ")[1]) < 0:
|
||||||
|
data.winner = utils.WHITE
|
||||||
|
engine.run_cmd(str(num) + ' clear_board')
|
||||||
|
current_time = strftime("%Y%m%d_%H%M%S", gmtime())
|
||||||
|
if os.path.exists(data_path + current_time + ".pkl"):
|
||||||
|
time.sleep(1)
|
||||||
|
current_time = strftime("%Y%m%d_%H%M%S", gmtime())
|
||||||
|
with open(data_path + current_time + ".pkl", "wb") as file:
|
||||||
|
cPickle.dump(data, file)
|
||||||
|
data.reset()
|
||||||
|
game_num += 1
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
"""
|
"""
|
||||||
Starting two different players which load network weights to evaluate the winning ratio.
|
Starting two different players which load network weights to evaluate the winning ratio.
|
||||||
@ -27,6 +89,7 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument("--data_path", type=str, default="./data/")
|
parser.add_argument("--data_path", type=str, default="./data/")
|
||||||
parser.add_argument("--black_weight_path", type=str, default=None)
|
parser.add_argument("--black_weight_path", type=str, default=None)
|
||||||
parser.add_argument("--white_weight_path", type=str, default=None)
|
parser.add_argument("--white_weight_path", type=str, default=None)
|
||||||
|
parser.add_argument("--save_path", type=str, default="./go/")
|
||||||
parser.add_argument("--debug", type=bool, default=False)
|
parser.add_argument("--debug", type=bool, default=False)
|
||||||
parser.add_argument("--game", type=str, default="go")
|
parser.add_argument("--game", type=str, default="go")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
@ -46,69 +109,15 @@ if __name__ == '__main__':
|
|||||||
debug=args.debug)
|
debug=args.debug)
|
||||||
engine = GTPEngine(game_obj=game, name='tianshou', version=0)
|
engine = GTPEngine(game_obj=game, name='tianshou', version=0)
|
||||||
|
|
||||||
data = Data()
|
thread_list = []
|
||||||
role = ["BLACK", "WHITE"]
|
thread_train = threading.Thread(target=game.model.train, args=("file",),
|
||||||
color = ['b', 'w']
|
kwargs={'data_path':args.data_path, 'batch_size':128, 'save_path':args.save_path})
|
||||||
|
thread_play = threading.Thread(target=play, args=(engine, args.data_path))
|
||||||
|
thread_list.append(thread_train)
|
||||||
|
thread_list.append(thread_play)
|
||||||
|
|
||||||
pattern = "[A-Z]{1}[0-9]{1}"
|
for t in thread_list:
|
||||||
space = re.compile("\s+")
|
t.start()
|
||||||
size = {"go":9, "reversi":8}
|
|
||||||
show = ['.', 'X', 'O']
|
|
||||||
|
|
||||||
evaluate_rounds = 100
|
for t in thread_list:
|
||||||
game_num = 0
|
t.join()
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
#while game_num < evaluate_rounds:
|
|
||||||
start_time = time.time()
|
|
||||||
game.model.check_latest_model()
|
|
||||||
num = 0
|
|
||||||
pass_flag = [False, False]
|
|
||||||
print("Start game {}".format(game_num))
|
|
||||||
# end the game if both palyer chose to pass, or play too much turns
|
|
||||||
while not (pass_flag[0] and pass_flag[1]) and num < size[args.game] ** 2 * 2:
|
|
||||||
turn = num % 2
|
|
||||||
board = engine.run_cmd(str(num) + ' show_board')
|
|
||||||
board = eval(board[board.index('['):board.index(']') + 1])
|
|
||||||
for i in range(size[args.game]):
|
|
||||||
for j in range(size[args.game]):
|
|
||||||
print show[board[i * size[args.game] + j]] + " ",
|
|
||||||
print "\n",
|
|
||||||
data.boards.append(board)
|
|
||||||
start_time = time.time()
|
|
||||||
move = engine.run_cmd(str(num) + ' genmove ' + color[turn])[:-1]
|
|
||||||
print("\n" + role[turn] + " : " + str(move)),
|
|
||||||
num += 1
|
|
||||||
match = re.search(pattern, move)
|
|
||||||
if match is not None:
|
|
||||||
# print "match : " + str(match.group())
|
|
||||||
play_or_pass = match.group()
|
|
||||||
pass_flag[turn] = False
|
|
||||||
else:
|
|
||||||
# print "no match"
|
|
||||||
play_or_pass = ' PASS'
|
|
||||||
pass_flag[turn] = True
|
|
||||||
prob = engine.run_cmd(str(num) + ' get_prob')
|
|
||||||
prob = space.sub(',', prob[prob.index('['):prob.index(']') + 1])
|
|
||||||
prob = prob.replace('[,', '[')
|
|
||||||
prob = prob.replace('],', ']')
|
|
||||||
prob = eval(prob)
|
|
||||||
data.probs.append(prob)
|
|
||||||
score = engine.run_cmd(str(num) + ' get_score')
|
|
||||||
print("Finished : {}".format(score.split(" ")[1]))
|
|
||||||
if eval(score.split(" ")[1]) > 0:
|
|
||||||
data.winner = utils.BLACK
|
|
||||||
if eval(score.split(" ")[1]) < 0:
|
|
||||||
data.winner = utils.WHITE
|
|
||||||
engine.run_cmd(str(num) + ' clear_board')
|
|
||||||
file_list = os.listdir(args.data_path)
|
|
||||||
current_time = strftime("%Y%m%d_%H%M%S", gmtime())
|
|
||||||
if os.path.exists(args.data_path + current_time + ".pkl"):
|
|
||||||
time.sleep(1)
|
|
||||||
current_time = strftime("%Y%m%d_%H%M%S", gmtime())
|
|
||||||
with open(args.data_path + current_time + ".pkl", "wb") as file:
|
|
||||||
picklestring = cPickle.dump(data, file)
|
|
||||||
data.reset()
|
|
||||||
game_num += 1
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
pass
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user