commit 1ff8252e6d (parent 03a6880050)
Author: rtz19970824
Date: 2017-12-09 21:41:11 +08:00

7 changed files with 143 additions and 31 deletions

AlphaGo/engine.py

@@ -182,6 +182,13 @@ class GTPEngine():
         else:
             return 'unknown player', False

+    def cmd_get_score(self, args, **kwargs):
+        return self._game.executor.get_score(), None
+
+    def cmd_show_board(self, args, **kwargs):
+        self._game.show_board()
+        return None, None
+
+if __name__ == "__main__":
+    game = Game()
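Note: the two new handlers plug into the engine's command dispatch, so get_score and show_board become valid GTP commands. A minimal in-process sketch (the constructor arguments and run_cmd usage mirror test.py below; the responses are illustrative, not captured output):

from game import Game
from engine import GTPEngine

game = Game(checkpoint_path="./checkpoints/")
engine = GTPEngine(game_obj=game, name='tianshou', version=0)
print(engine.run_cmd('1 genmove b'))    # engine chooses a black move
print(engine.run_cmd('2 show_board'))   # delegates to Game.show_board()
print(engine.run_cmd('3 get_score'))    # score from the game executor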

AlphaGo/game.py

@@ -181,11 +181,11 @@ class Executor:
 class Game:
-    def __init__(self, size=9, komi=6.5):
+    def __init__(self, size=9, komi=6.5, checkpoint_path=None):
         self.size = size
         self.komi = komi
         self.board = [utils.EMPTY] * (self.size * self.size)
-        self.strategy = strategy()
+        self.strategy = strategy(checkpoint_path)
         # self.strategy = None
         self.executor = Executor(game=self)
         self.history = []
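Note: checkpoint_path defaults to None and is simply forwarded to strategy, which hands it to Network.forward; callers are expected to supply a real directory (test.py below always does), since tf.train.latest_checkpoint(None) would fail. A minimal construction sketch using the default path from this commit:

from game import Game

# play.py runs two engine processes, each built around one Game like this
game = Game(size=9, komi=6.5, checkpoint_path="./checkpoints/")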

AlphaGo/multi_gpu.py

@@ -18,6 +18,8 @@ FLAGS = tf.flags.FLAGS
 def create_session():
     config = tf.ConfigProto(allow_soft_placement=True,
                             log_device_placement=FLAGS.log_device_placement)
+    config.gpu_options.allow_growth = True
     return tf.Session(config=config)
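Note: allow_growth makes TensorFlow allocate GPU memory on demand instead of reserving nearly all of it at session creation. That matters for this commit because play.py starts two test.py processes that share one GPU; without it, the second session can fail to initialize. A small sketch of the two standard ConfigProto knobs (the fraction variant is an alternative, not used in this repo):

import tensorflow as tf

config = tf.ConfigProto(allow_soft_placement=True)
config.gpu_options.allow_growth = True                       # grow allocation on demand
# config.gpu_options.per_process_gpu_memory_fraction = 0.4   # or: hard per-process cap
sess = tf.Session(config=config)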

AlphaGo/network_small.py

@@ -9,6 +9,7 @@ import tensorflow.contrib.layers as layers
 import multi_gpu
 import time
 import copy

+# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
@@ -80,14 +81,15 @@ class Network(object):
         self.train_op = tf.train.RMSPropOptimizer(1e-4).minimize(self.total_loss)
         self.var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
         self.saver = tf.train.Saver(max_to_keep=10, var_list=self.var_list)
+        self.sess = multi_gpu.create_session()

     def train(self):
-        data_path = "/home/tongzheng/data/"
-        data_name = os.listdir("/home/tongzheng/data/")
+        data_path = "./training_data/"
+        data_name = os.listdir(data_path)
         epochs = 100
         batch_size = 128
-        result_path = "./checkpoints/"
+        result_path = "./checkpoints_origin/"
         with multi_gpu.create_session() as sess:
             sess.run(tf.global_variables_initializer())
             ckpt_file = tf.train.latest_checkpoint(result_path)
@@ -112,7 +114,7 @@ class Network(object):
                 time_train = -time.time()
                 for iter in range(batch_num):
                     lv, lp, r, value, prob, _ = sess.run(
-                        [self.value_loss, self.policy_loss, self.reg, self.v, tf.nn.softmax(p), self.train_op],
+                        [self.value_loss, self.policy_loss, self.reg, self.v, tf.nn.softmax(self.p), self.train_op],
                         feed_dict={self.x: boards[
                             index[iter * batch_size:(iter + 1) * batch_size]],
                             self.z: wins[index[
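Note: the corrected line still calls tf.nn.softmax(self.p) inside the batch loop, which adds a new softmax op to the graph on every sess.run. A self-contained sketch of the usual remedy, building the op once and reusing it (the dense layer and the 82-way output, 81 points plus pass on 9x9, are illustrative assumptions, not code from the repo):

import numpy as np
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 82])
logits = tf.layers.dense(x, 82)
prob_op = tf.nn.softmax(logits)   # built once, outside the loop

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    batch = np.zeros([4, 82], np.float32)
    for i in range(100):
        # reusing prob_op keeps the graph a fixed size; calling
        # tf.nn.softmax(logits) here would grow it every iteration
        prob = sess.run(prob_op, feed_dict={x: batch})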
@@ -186,28 +188,35 @@ class Network(object):
     # # np.savetxt(pv_file, res[1][0], fmt="%.6f", newline=" ")
     # return res

-    def forward(self):
-        sess = multi_gpu.create_session()
-        sess.run(tf.global_variables_initializer())
-        # ckpt_file = tf.train.latest_checkpoint(checkpoint_path)
-        # if ckpt_file is not None:
+    def forward(self, checkpoint_path):
+        # checkpoint_path = "/home/tongzheng/tianshou/AlphaGo/checkpoints/"
+        # sess = multi_gpu.create_session()
+        # sess.run(tf.global_variables_initializer())
+        ckpt_file = tf.train.latest_checkpoint(checkpoint_path)
+        if ckpt_file is not None:
             # print('Restoring model from {}...'.format(ckpt_file))
-            # self.saver.restore(sess, ckpt_file)
+            self.saver.restore(self.sess, ckpt_file)
             # print('Successfully loaded')
-            # else:
-            # raise ValueError("No model loaded")
+        else:
+            raise ValueError("No model loaded")
         # prior, value = sess.run([tf.nn.softmax(p), v], feed_dict={x: state, is_training: False})
         # return prior, value
-        return sess
+        return self.sess

 if __name__ == '__main__':
-    state = np.random.randint(0, 1, [256, 9, 9, 17])
-    net = Network()
-    sess = net.forward()
-    start_time = time.time()
-    for i in range(100):
-        sess.run([tf.nn.softmax(net.p), net.v], feed_dict={net.x: state, net.is_training: False})
-        print("Step {}, use time {}".format(i, time.time() - start_time))
-        start_time = time.time()
+    # state = np.random.randint(0, 1, [256, 9, 9, 17])
+    # net = Network()
+    # net.train()
+    # sess = net.forward()
+    # start_time = time.time()
+    # for i in range(100):
+    #     sess.run([tf.nn.softmax(net.p), net.v], feed_dict={net.x: state, net.is_training: False})
+    #     print("Step {}, use time {}".format(i, time.time() - start_time))
+    #     start_time = time.time()
+    net0 = Network()
+    sess0 = net0.forward("./checkpoints/")
+    print("Loaded")
+    while True:
+        pass
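Note: forward now restores from the given directory into the session created in __init__ and returns that session. A minimal inference sketch following the shapes in the commented-out benchmark above (the module name network_small is taken from strategy.py; np.random.randint(0, 1, ...) in the old benchmark always yields zeros, so 2 is used as the exclusive bound here):

import numpy as np
import tensorflow as tf
import network_small

net = network_small.Network()
sess = net.forward("./checkpoints/")              # restores the latest checkpoint
state = np.random.randint(0, 2, [1, 9, 9, 17])    # dummy 17-plane 9x9 features
prior, value = sess.run([tf.nn.softmax(net.p), net.v],
                        feed_dict={net.x: state, net.is_training: False})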

AlphaGo/play.py (new file)

@@ -0,0 +1,89 @@
+import subprocess
+import sys
+import re
+import time
+
+pattern = "[A-Z]{1}[0-9]{1}"
+size = 9
+agent_v1 = subprocess.Popen(['python', '-u', 'test.py'], stdin=subprocess.PIPE,
+                            stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+agent_v0 = subprocess.Popen(['python', '-u', 'test.py', '--checkpoint_path=./checkpoints_origin/'], stdin=subprocess.PIPE,
+                            stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+num = 0
+game_num = 0
+black_pass = False
+white_pass = False
+while game_num < 10:
+    print("Start game {}".format(game_num))
+    while not (black_pass and white_pass) and num < size ** 2 * 2:
+        print(num)
+        if num % 2 == 0:
+            print('BLACK TURN')
+            agent_v1.stdin.write(str(num) + ' genmove b\n')
+            agent_v1.stdin.flush()
+            result = agent_v1.stdout.readline()
+            sys.stdout.write(result)
+            sys.stdout.flush()
+            num += 1
+            match = re.search(pattern, result)
+            print("COPY BLACK")
+            if match is not None:
+                agent_v0.stdin.write(str(num) + ' play b ' + match.group() + '\n')
+                agent_v0.stdin.flush()
+                result = agent_v0.stdout.readline()
+                sys.stdout.flush()
+            else:
+                agent_v0.stdin.write(str(num) + ' play b PASS\n')
+                agent_v0.stdin.flush()
+                result = agent_v0.stdout.readline()
+                sys.stdout.flush()
+            if re.search("pass", result) is not None:
+                black_pass = True
+            else:
+                black_pass = False
+        else:
+            print('WHITE TURN')
+            agent_v0.stdin.write(str(num) + ' genmove w\n')
+            agent_v0.stdin.flush()
+            result = agent_v0.stdout.readline()
+            sys.stdout.write(result)
+            sys.stdout.flush()
+            num += 1
+            match = re.search(pattern, result)
+            print("COPY WHITE")
+            if match is not None:
+                agent_v1.stdin.write(str(num) + ' play w ' + match.group() + '\n')
+                agent_v1.stdin.flush()
+                result = agent_v1.stdout.readline()
+                sys.stdout.flush()
+            else:
+                agent_v1.stdin.write(str(num) + ' play w PASS\n')
+                agent_v1.stdin.flush()
+                result = agent_v1.stdout.readline()
+                sys.stdout.flush()
+            if re.search("pass", result) is not None:
+                white_pass = True
+            else:
+                white_pass = False
+    print("Finished")
+    print("\n")
+    # query the score before the boards are cleared
+    agent_v1.stdin.write('get_score\n')
+    agent_v1.stdin.flush()
+    result = agent_v1.stdout.readline()
+    sys.stdout.write(result)
+    sys.stdout.flush()
+    agent_v1.stdin.write('clear_board\n')
+    agent_v1.stdin.flush()
+    result = agent_v1.stdout.readline()
+    sys.stdout.flush()
+    agent_v0.stdin.write('clear_board\n')
+    agent_v0.stdin.flush()
+    result = agent_v0.stdout.readline()
+    sys.stdout.flush()
+    # reset per-game state so the next game actually plays out
+    num = 0
+    black_pass = False
+    white_pass = False
+    game_num += 1
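Note: the coordinate regex works for 9x9 output (columns A to J, single-digit rows) but is brittle in general: GTP vertices skip the letter I, rows can be two digits on larger boards, and a pass comes back as text rather than a coordinate. A hedged sketch of a more general parser (names here are illustrative, not part of the repo):

import re

# GTP vertex: column letter with I skipped, one- or two-digit row; or pass/resign
MOVE_RE = re.compile(r"\b([A-HJ-Za-hj-z][0-9]{1,2}|pass|resign)\b", re.IGNORECASE)

def parse_gtp_move(line):
    # return the move token from a GTP response line, or None
    m = MOVE_RE.search(line)
    return m.group(1) if m else None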

AlphaGo/strategy.py

@@ -224,10 +224,10 @@ class GoEnv:
 class strategy(object):
-    def __init__(self):
+    def __init__(self, checkpoint_path):
         self.simulator = GoEnv()
         self.net = network_small.Network()
-        self.sess = self.net.forward()
+        self.sess = self.net.forward(checkpoint_path)
         self.evaluator = lambda state: self.sess.run([tf.nn.softmax(self.net.p), self.net.v],
                                                      feed_dict={self.net.x: state, self.net.is_training: False})
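Note: each strategy instance builds its own Network and loads whatever checkpoint directory it is given, so the two processes in play.py can run different models; within one process a second Network would share the same default graph, which is why one model per process is the simpler pattern used here. A minimal evaluator sketch (the module name and the input shape are assumptions based on this diff):

import numpy as np
from strategy import strategy

s = strategy("./checkpoints/")
state = np.random.randint(0, 2, [1, 9, 9, 17])   # dummy 17-plane features
prior, value = s.evaluator(state)                # policy distribution and value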

AlphaGo/test.py

@@ -2,13 +2,18 @@ import sys
 from game import Game
 from engine import GTPEngine
 # import utils
+import argparse
+import time

-game = Game()
-engine = GTPEngine(game_obj=game, name='tianshou')
-cmd = raw_input
+parser = argparse.ArgumentParser()
+parser.add_argument("--checkpoint_path", type=str, default="./checkpoints/")
+args = parser.parse_args()
+
+game = Game(checkpoint_path=args.checkpoint_path)
+engine = GTPEngine(game_obj=game, name='tianshou', version=0)

 while not engine.disconnect:
-    command = cmd()
+    command = sys.stdin.readline()
     result = engine.run_cmd(command)
     sys.stdout.write(result)
     sys.stdout.flush()
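Note: sys.stdin.readline() returns an empty string at EOF, for example when the parent play.py exits, so as written the loop would keep feeding empty commands to run_cmd. A hedged variant with an EOF guard (same names as the code above):

while not engine.disconnect:
    command = sys.stdin.readline()
    if not command:   # EOF: the driving process closed our stdin
        break
    result = engine.run_cmd(command)
    sys.stdout.write(result)
    sys.stdout.flush()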