diff --git a/AlphaGo/play.py b/AlphaGo/play.py index c677c04..1066624 100644 --- a/AlphaGo/play.py +++ b/AlphaGo/play.py @@ -90,8 +90,9 @@ if __name__ == '__main__': parser.add_argument("--black_weight_path", type=str, default=None) parser.add_argument("--white_weight_path", type=str, default=None) parser.add_argument("--save_path", type=str, default="./go/") - parser.add_argument("--debug", type=bool, default=False) + parser.add_argument("--debug", action="store_true", default=False) parser.add_argument("--game", type=str, default="go") + parser.add_argument("--train", action="store_true", default=False) args = parser.parse_args() if not os.path.exists(args.data_path): @@ -113,7 +114,8 @@ if __name__ == '__main__': thread_train = threading.Thread(target=game.model.train, args=("file",), kwargs={'data_path':args.data_path, 'batch_size':128, 'save_path':args.save_path}) thread_play = threading.Thread(target=play, args=(engine, args.data_path)) - thread_list.append(thread_train) + if args.train: + thread_list.append(thread_train) thread_list.append(thread_play) for t in thread_list: