2018-01-17 15:54:46 +08:00
|
|
|
from __future__ import division
|
2017-12-19 15:39:31 +08:00
|
|
|
import argparse
|
2017-12-09 21:41:11 +08:00
|
|
|
import sys
|
|
|
|
import re
|
|
|
|
import time
|
2017-12-16 14:55:19 +08:00
|
|
|
import os
|
2018-01-13 15:57:41 +08:00
|
|
|
import threading
|
2018-01-10 23:27:17 +08:00
|
|
|
from game import Game
|
|
|
|
from engine import GTPEngine
|
2018-01-11 17:02:36 +08:00
|
|
|
from utils import Data
|
2017-12-26 19:29:35 +08:00
|
|
|
import utils
|
|
|
|
from time import gmtime, strftime
|
2017-12-25 15:33:17 +08:00
|
|
|
|
|
|
|
python_version = sys.version_info
|
|
|
|
|
|
|
|
if python_version < (3, 0):
|
|
|
|
import cPickle
|
|
|
|
else:
|
|
|
|
import _pickle as cPickle
|
|
|
|
|
2018-01-09 20:09:48 +08:00
|
|
|
|
2018-01-13 15:57:41 +08:00
|
|
|
def play(engine, data_path):
    """Run a fixed number of self-play games and pickle each game's record.

    On every turn the engine is asked, through text commands passed to
    ``engine.run_cmd``, for the current board, a move for the side to play
    and the move probabilities.  Boards and probabilities are accumulated in
    a ``Data`` object, the winner is derived from the sign of the final
    score, and the record is dumped to ``<data_path><UTC timestamp>.pkl``.
    The wall-clock time of each game and the average over all games are
    written to ``time.txt`` in the current working directory.

    Args:
        engine: object exposing ``run_cmd(str)`` and ``_game`` (which in turn
            exposes ``name`` and ``model.check_latest_model()``) --
            presumably a ``GTPEngine``; confirm against the caller.
        data_path: string prefix for the output files.  It is concatenated
            directly with the file name, so it should end with a path
            separator.

    Returns:
        None.

    Raises:
        KeyError: if ``engine._game.name`` is not ``"go"`` or ``"reversi"``.
    """
    data = Data()
    role = ["BLACK", "WHITE"]
    color = ['b', 'w']

    # A move reply containing a capital letter followed by a digit is treated
    # as a real move; anything else counts as a pass.
    pattern = "[A-Z]{1}[0-9]{1}"
    space = re.compile(r"\s+")
    size = {"go": 9, "reversi": 8}
    show = ['.', 'X', 'O']

    evaluate_rounds = 5
    game_num = 0
    total_time = 0

    # The context manager guarantees the timing log is closed even when a
    # game aborts with an exception.
    with open('time.txt', 'w') as time_log:
        while game_num < evaluate_rounds:
            start = time.time()
            engine._game.model.check_latest_model()
            num = 0
            pass_flag = [False, False]
            print("Start game {}".format(game_num))
            board_size = size[engine._game.name]

            # End the game if both players chose to pass, or after too many
            # turns (twice the number of board cells).
            while not (pass_flag[0] and pass_flag[1]) and num < board_size ** 2 * 2:
                turn = num % 2

                board = engine.run_cmd(str(num) + ' show_board')
                # NOTE(review): eval() on the engine reply -- acceptable only
                # while the engine output is trusted.
                board = eval(board[board.index('['):board.index(']') + 1])

                # sys.stdout.write keeps the board dump working on both
                # Python 2 and Python 3 (the print statement is 2-only).
                for i in range(board_size):
                    cells = [show[board[i * board_size + j]] for j in range(board_size)]
                    sys.stdout.write("".join(cell + "  " for cell in cells) + "\n")
                data.boards.append(board)

                # Drop the last character of the reply (presumably a trailing
                # newline -- confirm against the engine's reply format).
                move = engine.run_cmd(str(num) + ' genmove ' + color[turn])[:-1]
                sys.stdout.write("\n" + role[turn] + " : " + str(move) + "\n")
                num += 1

                pass_flag[turn] = re.search(pattern, move) is None

                prob = engine.run_cmd(str(num) + ' get_prob')
                # Turn the whitespace-separated bracketed list into a
                # comma-separated Python list literal before evaluating it.
                prob = space.sub(',', prob[prob.index('['):prob.index(']') + 1])
                prob = prob.replace('[,', '[')
                prob = prob.replace('],', ']')
                # NOTE(review): eval() on the engine reply, see above.
                prob = eval(prob)
                data.probs.append(prob)

            score = engine.run_cmd(str(num) + ' get_score')
            score_text = score.split(" ")[1]
            print("Finished : {}".format(score_text))
            # NOTE(review): eval() on the engine reply, see above.
            score_value = eval(score_text)
            if score_value > 0:
                data.winner = utils.BLACK
            elif score_value < 0:
                data.winner = utils.WHITE
            engine.run_cmd(str(num) + ' clear_board')

            # File names have one-second resolution; if the name is already
            # taken, wait a second and take a fresh timestamp.
            current_time = strftime("%Y%m%d_%H%M%S", gmtime())
            if os.path.exists(data_path + current_time + ".pkl"):
                time.sleep(1)
                current_time = strftime("%Y%m%d_%H%M%S", gmtime())
            with open(data_path + current_time + ".pkl", "wb") as record_file:
                cPickle.dump(data, record_file)
            data.reset()
            game_num += 1

            this_time = time.time() - start
            # Accumulate into total_time; the original referenced an
            # undefined name ``total`` here and raised NameError.
            total_time += this_time
            time_log.write('time:' + str(this_time) + '\n')

        time_log.write('Avg time:' + str(total_time / evaluate_rounds))
|
|
|
|
|
2018-01-13 15:57:41 +08:00
|
|
|
|
|
|
|
|
2017-12-16 14:33:31 +08:00
|
|
|
if __name__ == '__main__':
    """
    Starting two different players which load network weights to evaluate the winning ratio.
    Note that, this function requires the installation of the Pyro4 library.
    """
    # NOTE(review): nothing in this entry point uses Pyro4 directly -- confirm
    # whether the requirement stated above still holds.
    # TODO : we should set the network path in a more configurable way.

    # Command-line options.
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_path", default="./data/", type=str)
    parser.add_argument("--black_weight_path", default=None, type=str)
    parser.add_argument("--white_weight_path", default=None, type=str)
    parser.add_argument("--save_path", default="./go/", type=str)
    parser.add_argument("--debug", default=False, action="store_true")
    parser.add_argument("--game", default="go", type=str)
    parser.add_argument("--train", default=False, action="store_true")
    args = parser.parse_args()

    # Create the directory that receives the game records when it is missing.
    if not os.path.exists(args.data_path):
        os.mkdir(args.data_path)

    # An explicitly supplied weight path must exist; None skips the check.
    black_path = args.black_weight_path
    if black_path is not None and not os.path.exists(black_path):
        raise ValueError("Can't find the network weights for black player.")
    white_path = args.white_weight_path
    if white_path is not None and not os.path.exists(white_path):
        raise ValueError("Can't find the network weights for white player.")

    game = Game(
        name=args.game,
        black_checkpoint_path=black_path,
        white_checkpoint_path=white_path,
        debug=args.debug,
    )
    engine = GTPEngine(game_obj=game, name='tianshou', version=0)

    # Both thread objects are always constructed; the training thread is only
    # started when --train is given.
    train_kwargs = {
        'data_path': args.data_path,
        'batch_size': 128,
        'save_path': args.save_path,
    }
    train_worker = threading.Thread(target=game.model.train, args=("file",), kwargs=train_kwargs)
    play_worker = threading.Thread(target=play, args=(engine, args.data_path))

    workers = [train_worker] if args.train else []
    workers.append(play_worker)

    # Start every selected thread, then wait for all of them to finish.
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()
|