clean code

This commit is contained in:
rtz19970824 2018-03-14 19:17:28 +08:00
parent 52e6b09768
commit f70dfb0559

View File

@ -4,7 +4,6 @@ import sys
import re import re
import time import time
import os import os
import threading
from game import Game from game import Game
from engine import GTPEngine from engine import GTPEngine
from utils import Data from utils import Data
@ -104,7 +103,6 @@ if __name__ == '__main__':
parser.add_argument("--save_path", type=str, default="./go/") parser.add_argument("--save_path", type=str, default="./go/")
parser.add_argument("--debug", action="store_true", default=False) parser.add_argument("--debug", action="store_true", default=False)
parser.add_argument("--game", type=str, default="go") parser.add_argument("--game", type=str, default="go")
parser.add_argument("--train", action="store_true", default=False)
args = parser.parse_args() args = parser.parse_args()
if not os.path.exists(args.data_path): if not os.path.exists(args.data_path):
@ -121,20 +119,4 @@ if __name__ == '__main__':
white_checkpoint_path=args.white_weight_path, white_checkpoint_path=args.white_weight_path,
debug=args.debug) debug=args.debug)
engine = GTPEngine(game_obj=game, name='tianshou', version=0) engine = GTPEngine(game_obj=game, name='tianshou', version=0)
play(engine, args.data_path)
thread_list = []
thread_train = threading.Thread(target=game.model.train, args=("file",),
kwargs={'data_path':args.data_path, 'batch_size':128, 'save_path':args.save_path})
thread_play = threading.Thread(target=play, args=(engine, args.data_path))
if args.train:
thread_list.append(thread_train)
thread_list.append(thread_play)
for t in thread_list:
t.daemon = True
t.start()
while True:
time.sleep(1)
#for t in thread_list:
# t.join()