Merge remote-tracking branch 'origin/master'

This commit is contained in:
haoshengzou 2018-03-28 18:47:54 +08:00
commit ace59787ed
2 changed files with 17 additions and 19 deletions

View File

@ -4,7 +4,6 @@ import sys
import re
import time
import os
import threading
from game import Game
from engine import GTPEngine
from utils import Data
@ -104,7 +103,6 @@ if __name__ == '__main__':
parser.add_argument("--save_path", type=str, default="./go/")
parser.add_argument("--debug", action="store_true", default=False)
parser.add_argument("--game", type=str, default="go")
parser.add_argument("--train", action="store_true", default=False)
args = parser.parse_args()
if not os.path.exists(args.data_path):
@ -121,20 +119,4 @@ if __name__ == '__main__':
white_checkpoint_path=args.white_weight_path,
debug=args.debug)
engine = GTPEngine(game_obj=game, name='tianshou', version=0)
thread_list = []
thread_train = threading.Thread(target=game.model.train, args=("file",),
kwargs={'data_path':args.data_path, 'batch_size':128, 'save_path':args.save_path})
thread_play = threading.Thread(target=play, args=(engine, args.data_path))
if args.train:
thread_list.append(thread_train)
thread_list.append(thread_play)
for t in thread_list:
t.daemon = True
t.start()
while True:
time.sleep(1)
#for t in thread_list:
# t.join()
play(engine, args.data_path)

16
AlphaGo/train.sh Normal file
View File

@ -0,0 +1,16 @@
GPU_play=(0)
GPU_train=(3,4,5,6)
str_play='python play.py --data_path=./data/ --save_path=./go/ --game=go &'
str_train='python model.py &'
play_each_GPU=4
$str_train
echo 'Start training'
for gpu in $GPU
do
export CUDA_VISIBLE_DEVICES=$gpu
for ((i=1;i<=$play_each_GPU;i++))
do
$str_play
done
done