Merge remote-tracking branch 'origin/master'
This commit is contained in:
commit
ace59787ed
@ -4,7 +4,6 @@ import sys
|
|||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
import os
|
import os
|
||||||
import threading
|
|
||||||
from game import Game
|
from game import Game
|
||||||
from engine import GTPEngine
|
from engine import GTPEngine
|
||||||
from utils import Data
|
from utils import Data
|
||||||
@ -104,7 +103,6 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument("--save_path", type=str, default="./go/")
|
parser.add_argument("--save_path", type=str, default="./go/")
|
||||||
parser.add_argument("--debug", action="store_true", default=False)
|
parser.add_argument("--debug", action="store_true", default=False)
|
||||||
parser.add_argument("--game", type=str, default="go")
|
parser.add_argument("--game", type=str, default="go")
|
||||||
parser.add_argument("--train", action="store_true", default=False)
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if not os.path.exists(args.data_path):
|
if not os.path.exists(args.data_path):
|
||||||
@ -121,20 +119,4 @@ if __name__ == '__main__':
|
|||||||
white_checkpoint_path=args.white_weight_path,
|
white_checkpoint_path=args.white_weight_path,
|
||||||
debug=args.debug)
|
debug=args.debug)
|
||||||
engine = GTPEngine(game_obj=game, name='tianshou', version=0)
|
engine = GTPEngine(game_obj=game, name='tianshou', version=0)
|
||||||
|
play(engine, args.data_path)
|
||||||
thread_list = []
|
|
||||||
thread_train = threading.Thread(target=game.model.train, args=("file",),
|
|
||||||
kwargs={'data_path':args.data_path, 'batch_size':128, 'save_path':args.save_path})
|
|
||||||
thread_play = threading.Thread(target=play, args=(engine, args.data_path))
|
|
||||||
if args.train:
|
|
||||||
thread_list.append(thread_train)
|
|
||||||
thread_list.append(thread_play)
|
|
||||||
|
|
||||||
for t in thread_list:
|
|
||||||
t.daemon = True
|
|
||||||
t.start()
|
|
||||||
|
|
||||||
while True:
|
|
||||||
time.sleep(1)
|
|
||||||
#for t in thread_list:
|
|
||||||
# t.join()
|
|
||||||
|
16
AlphaGo/train.sh
Normal file
16
AlphaGo/train.sh
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
GPU_play=(0)
|
||||||
|
GPU_train=(3,4,5,6)
|
||||||
|
str_play='python play.py --data_path=./data/ --save_path=./go/ --game=go &'
|
||||||
|
str_train='python model.py &'
|
||||||
|
play_each_GPU=4
|
||||||
|
|
||||||
|
$str_train
|
||||||
|
echo 'Start training'
|
||||||
|
for gpu in $GPU
|
||||||
|
do
|
||||||
|
export CUDA_VISIBLE_DEVICES=$gpu
|
||||||
|
for ((i=1;i<=$play_each_GPU;i++))
|
||||||
|
do
|
||||||
|
$str_play
|
||||||
|
done
|
||||||
|
done
|
Loading…
x
Reference in New Issue
Block a user