play
This commit is contained in:
parent
03a6880050
commit
1ff8252e6d
@ -182,6 +182,13 @@ class GTPEngine():
|
|||||||
else:
|
else:
|
||||||
return 'unknown player', False
|
return 'unknown player', False
|
||||||
|
|
||||||
|
def cmd_get_score(self, args, **kwargs):
    """Handle the ``get_score`` command.

    Asks the game's executor for the current score and hands it back
    unchanged.

    Args:
        args: Command arguments; not used by this handler.
        **kwargs: Extra keyword arguments; not used by this handler.

    Returns:
        A ``(score, None)`` tuple, where ``score`` is whatever
        ``self._game.executor.get_score()`` returns.
    """
    # NOTE(review): the second tuple element looks like a status flag read
    # by the command dispatcher -- confirm against the caller.
    score = self._game.executor.get_score()
    return score, None
|
||||||
|
|
||||||
|
def cmd_show_board(self, args, **kwargs):
    """Handle the ``show_board`` command.

    Delegates to ``self._game.show_board()``; nothing is returned from
    that call to the caller of this handler.

    Args:
        args: Command arguments; not used by this handler.
        **kwargs: Extra keyword arguments; not used by this handler.

    Returns:
        The tuple ``(None, None)``.
    """
    self._game.show_board()
    return None, None
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "main":
|
if __name__ == "main":
|
||||||
game = Game()
|
game = Game()
|
||||||
|
@ -181,11 +181,11 @@ class Executor:
|
|||||||
|
|
||||||
|
|
||||||
class Game:
|
class Game:
|
||||||
def __init__(self, size=9, komi=6.5):
|
def __init__(self, size=9, komi=6.5, checkpoint_path=None):
|
||||||
self.size = size
|
self.size = size
|
||||||
self.komi = komi
|
self.komi = komi
|
||||||
self.board = [utils.EMPTY] * (self.size * self.size)
|
self.board = [utils.EMPTY] * (self.size * self.size)
|
||||||
self.strategy = strategy()
|
self.strategy = strategy(checkpoint_path)
|
||||||
# self.strategy = None
|
# self.strategy = None
|
||||||
self.executor = Executor(game=self)
|
self.executor = Executor(game=self)
|
||||||
self.history = []
|
self.history = []
|
||||||
|
@ -18,6 +18,8 @@ FLAGS = tf.flags.FLAGS
|
|||||||
def create_session():
|
def create_session():
|
||||||
config = tf.ConfigProto(allow_soft_placement=True,
|
config = tf.ConfigProto(allow_soft_placement=True,
|
||||||
log_device_placement=FLAGS.log_device_placement)
|
log_device_placement=FLAGS.log_device_placement)
|
||||||
|
config.gpu_options.allow_growth = True
|
||||||
|
|
||||||
return tf.Session(config=config)
|
return tf.Session(config=config)
|
||||||
|
|
||||||
|
|
||||||
|
@ -9,6 +9,7 @@ import tensorflow.contrib.layers as layers
|
|||||||
|
|
||||||
import multi_gpu
|
import multi_gpu
|
||||||
import time
|
import time
|
||||||
|
import copy
|
||||||
|
|
||||||
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
|
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
|
||||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
||||||
@ -80,14 +81,15 @@ class Network(object):
|
|||||||
self.train_op = tf.train.RMSPropOptimizer(1e-4).minimize(self.total_loss)
|
self.train_op = tf.train.RMSPropOptimizer(1e-4).minimize(self.total_loss)
|
||||||
self.var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
|
self.var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
|
||||||
self.saver = tf.train.Saver(max_to_keep=10, var_list=self.var_list)
|
self.saver = tf.train.Saver(max_to_keep=10, var_list=self.var_list)
|
||||||
|
self.sess = multi_gpu.create_session()
|
||||||
|
|
||||||
def train(self):
|
def train(self):
|
||||||
data_path = "/home/tongzheng/data/"
|
data_path = "./training_data/"
|
||||||
data_name = os.listdir("/home/tongzheng/data/")
|
data_name = os.listdir(data_path)
|
||||||
epochs = 100
|
epochs = 100
|
||||||
batch_size = 128
|
batch_size = 128
|
||||||
|
|
||||||
result_path = "./checkpoints/"
|
result_path = "./checkpoints_origin/"
|
||||||
with multi_gpu.create_session() as sess:
|
with multi_gpu.create_session() as sess:
|
||||||
sess.run(tf.global_variables_initializer())
|
sess.run(tf.global_variables_initializer())
|
||||||
ckpt_file = tf.train.latest_checkpoint(result_path)
|
ckpt_file = tf.train.latest_checkpoint(result_path)
|
||||||
@ -112,7 +114,7 @@ class Network(object):
|
|||||||
time_train = -time.time()
|
time_train = -time.time()
|
||||||
for iter in range(batch_num):
|
for iter in range(batch_num):
|
||||||
lv, lp, r, value, prob, _ = sess.run(
|
lv, lp, r, value, prob, _ = sess.run(
|
||||||
[self.value_loss, self.policy_loss, self.reg, self.v, tf.nn.softmax(p), self.train_op],
|
[self.value_loss, self.policy_loss, self.reg, self.v, tf.nn.softmax(self.p), self.train_op],
|
||||||
feed_dict={self.x: boards[
|
feed_dict={self.x: boards[
|
||||||
index[iter * batch_size:(iter + 1) * batch_size]],
|
index[iter * batch_size:(iter + 1) * batch_size]],
|
||||||
self.z: wins[index[
|
self.z: wins[index[
|
||||||
@ -186,28 +188,35 @@ class Network(object):
|
|||||||
# # np.savetxt(pv_file, res[1][0], fmt="%.6f", newline=" ")
|
# # np.savetxt(pv_file, res[1][0], fmt="%.6f", newline=" ")
|
||||||
# return res
|
# return res
|
||||||
|
|
||||||
def forward(self):
|
def forward(self, checkpoint_path):
|
||||||
# checkpoint_path = "/home/tongzheng/tianshou/AlphaGo/checkpoints/"
|
# checkpoint_path = "/home/tongzheng/tianshou/AlphaGo/checkpoints/"
|
||||||
sess = multi_gpu.create_session()
|
# sess = multi_gpu.create_session()
|
||||||
sess.run(tf.global_variables_initializer())
|
# sess.run(tf.global_variables_initializer())
|
||||||
# ckpt_file = tf.train.latest_checkpoint(checkpoint_path)
|
ckpt_file = tf.train.latest_checkpoint(checkpoint_path)
|
||||||
# if ckpt_file is not None:
|
if ckpt_file is not None:
|
||||||
# print('Restoring model from {}...'.format(ckpt_file))
|
# print('Restoring model from {}...'.format(ckpt_file))
|
||||||
# self.saver.restore(sess, ckpt_file)
|
self.saver.restore(self.sess, ckpt_file)
|
||||||
# print('Successfully loaded')
|
# print('Successfully loaded')
|
||||||
# else:
|
else:
|
||||||
# raise ValueError("No model loaded")
|
raise ValueError("No model loaded")
|
||||||
# prior, value = sess.run([tf.nn.softmax(p), v], feed_dict={x: state, is_training: False})
|
# prior, value = sess.run([tf.nn.softmax(p), v], feed_dict={x: state, is_training: False})
|
||||||
# return prior, value
|
# return prior, value
|
||||||
return sess
|
return self.sess
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
state = np.random.randint(0, 1, [256, 9, 9, 17])
|
# state = np.random.randint(0, 1, [256, 9, 9, 17])
|
||||||
net = Network()
|
# net = Network()
|
||||||
sess = net.forward()
|
# net.train()
|
||||||
start_time = time.time()
|
# sess = net.forward()
|
||||||
for i in range(100):
|
# start_time = time.time()
|
||||||
sess.run([tf.nn.softmax(net.p), net.v], feed_dict={net.x: state, net.is_training: False})
|
# for i in range(100):
|
||||||
print("Step {}, use time {}".format(i, time.time() - start_time))
|
# sess.run([tf.nn.softmax(net.p), net.v], feed_dict={net.x: state, net.is_training: False})
|
||||||
start_time = time.time()
|
# print("Step {}, use time {}".format(i, time.time() - start_time))
|
||||||
|
# start_time = time.time()
|
||||||
|
net0 = Network()
|
||||||
|
sess0 = net0.forward("./checkpoints/")
|
||||||
|
print("Loaded")
|
||||||
|
while True:
|
||||||
|
pass
|
||||||
|
|
||||||
|
89
AlphaGo/play.py
Normal file
89
AlphaGo/play.py
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
"""Pit two GTP engine processes against each other for a series of games.

``agent_v1`` (default checkpoint) plays black and ``agent_v0`` (the
``./checkpoints_origin/`` checkpoint) plays white.  Every move generated by
one engine is relayed to the other with a ``play`` command so both boards
stay in sync.
"""
import subprocess
import sys
import re
import time

# Regex for a board vertex inside an engine reply, e.g. the "C5" in "=12 C5".
pattern = "[A-Z]{1}[0-9]{1}"
# Board edge length; also used to cap the number of moves per game.
size = 9
# Number of games played per run.
NUM_GAMES = 10


def _launch_agent(extra_args=()):
    """Start one ``test.py`` engine process with piped stdin/stdout."""
    command = ['python', '-u', 'test.py'] + list(extra_args)
    # universal_newlines=True keeps the pipes in text mode so the str
    # commands written below also work on Python 3.
    return subprocess.Popen(command, stdin=subprocess.PIPE,
                            stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                            universal_newlines=True)


def _ask(agent, command):
    """Send one command line to ``agent`` and return its one-line reply."""
    agent.stdin.write(command + '\n')
    agent.stdin.flush()
    return agent.stdout.readline()


def _take_turn(mover, follower, colour, label, num):
    """Let ``mover`` generate a move and relay it to ``follower``.

    Returns ``(num + 1, passed)`` where ``passed`` tells whether the
    generated move was a pass.
    """
    print('{} TURN'.format(label))
    reply = _ask(mover, '{} genmove {}'.format(num, colour))
    sys.stdout.write(reply)
    sys.stdout.flush()
    num += 1

    match = re.search(pattern, reply)
    print('COPY {}'.format(label))
    # Anything that is not a recognisable vertex is relayed as a pass.
    vertex = match.group() if match is not None else 'PASS'
    _ask(follower, '{} play {} {}'.format(num, colour, vertex))

    # Decide "pass" from the genmove reply itself; the reply to the relayed
    # ``play`` command says nothing about which move was generated.
    passed = re.search("pass", reply, re.IGNORECASE) is not None
    return num, passed


def play_game(black, white, max_moves=None):
    """Play one game between two agents and return the number of moves made.

    The game stops once both sides passed on their latest turn or when
    ``max_moves`` (default ``size ** 2 * 2``) moves have been played.  All
    per-game state is local, so consecutive games start from scratch.
    """
    if max_moves is None:
        max_moves = size ** 2 * 2
    num = 0
    black_pass = False
    white_pass = False
    while not (black_pass and white_pass) and num < max_moves:
        print(num)
        if num % 2 == 0:
            num, black_pass = _take_turn(black, white, 'b', 'BLACK', num)
        else:
            num, white_pass = _take_turn(white, black, 'w', 'WHITE', num)
    return num


def _shutdown(agent):
    """Close the agent's stdin and make sure the process is gone."""
    try:
        agent.stdin.close()
    except (IOError, OSError):
        # Best-effort cleanup: the pipe may already be broken.
        pass
    if agent.poll() is None:
        agent.terminate()
    agent.wait()


def main():
    """Run ``NUM_GAMES`` games and print the score reported after each."""
    agents = []
    try:
        agent_v1 = _launch_agent()
        agents.append(agent_v1)
        agent_v0 = _launch_agent(['--checkpoint_path=./checkpoints_origin/'])
        agents.append(agent_v0)

        for game_num in range(NUM_GAMES):
            print("Start game {}".format(game_num))
            play_game(agent_v1, agent_v0)
            print("Finished")
            print("\n")

            # Ask for the score before clearing, otherwise the score of an
            # empty board would be reported.
            score = _ask(agent_v1, 'get_score')
            sys.stdout.write(score)
            sys.stdout.flush()

            _ask(agent_v1, 'clear_board')
            _ask(agent_v0, 'clear_board')
    finally:
        for agent in agents:
            _shutdown(agent)


if __name__ == "__main__":
    main()
|
@ -224,10 +224,10 @@ class GoEnv:
|
|||||||
|
|
||||||
|
|
||||||
class strategy(object):
|
class strategy(object):
|
||||||
def __init__(self):
|
def __init__(self, checkpoint_path):
|
||||||
self.simulator = GoEnv()
|
self.simulator = GoEnv()
|
||||||
self.net = network_small.Network()
|
self.net = network_small.Network()
|
||||||
self.sess = self.net.forward()
|
self.sess = self.net.forward(checkpoint_path)
|
||||||
self.evaluator = lambda state: self.sess.run([tf.nn.softmax(self.net.p), self.net.v],
|
self.evaluator = lambda state: self.sess.run([tf.nn.softmax(self.net.p), self.net.v],
|
||||||
feed_dict={self.net.x: state, self.net.is_training: False})
|
feed_dict={self.net.x: state, self.net.is_training: False})
|
||||||
|
|
||||||
|
@ -2,13 +2,18 @@ import sys
|
|||||||
from game import Game
|
from game import Game
|
||||||
from engine import GTPEngine
|
from engine import GTPEngine
|
||||||
# import utils
|
# import utils
|
||||||
|
import argparse
|
||||||
|
import time
|
||||||
|
|
||||||
game = Game()
|
parser = argparse.ArgumentParser()
|
||||||
engine = GTPEngine(game_obj=game, name='tianshou')
|
parser.add_argument("--checkpoint_path", type=str, default="./checkpoints/")
|
||||||
cmd = raw_input
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
game = Game(checkpoint_path=args.checkpoint_path)
|
||||||
|
engine = GTPEngine(game_obj=game, name='tianshou', version=0)
|
||||||
|
|
||||||
while not engine.disconnect:
|
while not engine.disconnect:
|
||||||
command = cmd()
|
command = sys.stdin.readline()
|
||||||
result = engine.run_cmd(command)
|
result = engine.run_cmd(command)
|
||||||
sys.stdout.write(result)
|
sys.stdout.write(result)
|
||||||
sys.stdout.flush()
|
sys.stdout.flush()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user