Merge remote-tracking branch 'origin/master'

This commit is contained in:
haoshengzou 2018-03-28 18:47:54 +08:00
commit ace59787ed
2 changed files with 17 additions and 19 deletions

View File

@ -4,7 +4,6 @@ import sys
import re import re
import time import time
import os import os
import threading
from game import Game from game import Game
from engine import GTPEngine from engine import GTPEngine
from utils import Data from utils import Data
@ -104,7 +103,6 @@ if __name__ == '__main__':
parser.add_argument("--save_path", type=str, default="./go/") parser.add_argument("--save_path", type=str, default="./go/")
parser.add_argument("--debug", action="store_true", default=False) parser.add_argument("--debug", action="store_true", default=False)
parser.add_argument("--game", type=str, default="go") parser.add_argument("--game", type=str, default="go")
parser.add_argument("--train", action="store_true", default=False)
args = parser.parse_args() args = parser.parse_args()
if not os.path.exists(args.data_path): if not os.path.exists(args.data_path):
@ -121,20 +119,4 @@ if __name__ == '__main__':
white_checkpoint_path=args.white_weight_path, white_checkpoint_path=args.white_weight_path,
debug=args.debug) debug=args.debug)
engine = GTPEngine(game_obj=game, name='tianshou', version=0) engine = GTPEngine(game_obj=game, name='tianshou', version=0)
play(engine, args.data_path)
thread_list = []
thread_train = threading.Thread(target=game.model.train, args=("file",),
kwargs={'data_path':args.data_path, 'batch_size':128, 'save_path':args.save_path})
thread_play = threading.Thread(target=play, args=(engine, args.data_path))
if args.train:
thread_list.append(thread_train)
thread_list.append(thread_play)
for t in thread_list:
t.daemon = True
t.start()
while True:
time.sleep(1)
#for t in thread_list:
# t.join()

16
AlphaGo/train.sh Normal file
View File

@ -0,0 +1,16 @@
GPU_play=(0)
GPU_train=(3,4,5,6)
str_play='python play.py --data_path=./data/ --save_path=./go/ --game=go &'
str_train='python model.py &'
play_each_GPU=4
$str_train
echo 'Start training'
for gpu in $GPU
do
export CUDA_VISIBLE_DEVICES=$gpu
for ((i=1;i<=$play_each_GPU;i++))
do
$str_play
done
done