diff --git a/AlphaGo/play.py b/AlphaGo/play.py index f6740ca..67fbb7c 100644 --- a/AlphaGo/play.py +++ b/AlphaGo/play.py @@ -4,7 +4,6 @@ import sys import re import time import os -import threading from game import Game from engine import GTPEngine from utils import Data @@ -104,7 +103,6 @@ if __name__ == '__main__': parser.add_argument("--save_path", type=str, default="./go/") parser.add_argument("--debug", action="store_true", default=False) parser.add_argument("--game", type=str, default="go") - parser.add_argument("--train", action="store_true", default=False) args = parser.parse_args() if not os.path.exists(args.data_path): @@ -121,20 +119,4 @@ if __name__ == '__main__': white_checkpoint_path=args.white_weight_path, debug=args.debug) engine = GTPEngine(game_obj=game, name='tianshou', version=0) - - thread_list = [] - thread_train = threading.Thread(target=game.model.train, args=("file",), - kwargs={'data_path':args.data_path, 'batch_size':128, 'save_path':args.save_path}) - thread_play = threading.Thread(target=play, args=(engine, args.data_path)) - if args.train: - thread_list.append(thread_train) - thread_list.append(thread_play) - - for t in thread_list: - t.daemon = True - t.start() - - while True: - time.sleep(1) - #for t in thread_list: - # t.join() + play(engine, args.data_path) diff --git a/AlphaGo/train.sh b/AlphaGo/train.sh new file mode 100644 index 0000000..aecf1cd --- /dev/null +++ b/AlphaGo/train.sh @@ -0,0 +1,16 @@ +GPU_play=(0) +GPU_train=(3,4,5,6) +str_play='python play.py --data_path=./data/ --save_path=./go/ --game=go &' +str_train='python model.py &' +play_each_GPU=4 + +$str_train +echo 'Start training' +for gpu in $GPU +do +export CUDA_VISIBLE_DEVICES=$gpu +for ((i=1;i<=$play_each_GPU;i++)) +do +$str_play +done +done