Merge pull request #3 from sproblvem/double-network

Double network
sproblvem authored 2018-01-11 10:55:12 +08:00, committed by GitHub
commit 284cc64c18
5 changed files with 141 additions and 187 deletions


@@ -12,6 +12,7 @@ import numpy as np
 import sys, os
 import model
 from collections import deque
 sys.path.append(os.path.join(os.path.dirname(__file__), os.path.pardir))
 from tianshou.core.mcts.mcts import MCTS
@@ -19,6 +20,7 @@ import go
 import reversi
 import time

 class Game:
     '''
     Load the real game and trained weights.
@@ -26,11 +28,9 @@ class Game:
     TODO : Maybe merge with the engine class in future,
     currently leave it untouched for interacting with Go UI.
     '''
-    def __init__(self, name=None, role=None, debug=False, checkpoint_path=None):
+    def __init__(self, name=None, debug=False, black_checkpoint_path=None, white_checkpoint_path=None):
         self.name = name
-        if role is None:
-            raise ValueError("Need a role!")
-        self.role = role
         self.debug = debug
         if self.name == "go":
             self.size = 9
@@ -38,7 +38,7 @@ class Game:
             self.history_length = 8
             self.history = []
             self.history_set = set()
-            self.game_engine = go.Go(size=self.size, komi=self.komi, role=self.role)
+            self.game_engine = go.Go(size=self.size, komi=self.komi)
             self.board = [utils.EMPTY] * (self.size ** 2)
         elif self.name == "reversi":
             self.size = 8
@@ -49,8 +49,9 @@ class Game:
         else:
             raise ValueError(name + " is an unknown game...")
-        self.evaluator = model.ResNet(self.size, self.size ** 2 + 1, history_length=self.history_length,
-                                      checkpoint_path=checkpoint_path)
+        self.model = model.ResNet(self.size, self.size ** 2 + 1, history_length=self.history_length,
+                                  black_checkpoint_path=black_checkpoint_path,
+                                  white_checkpoint_path=white_checkpoint_path)
         self.latest_boards = deque(maxlen=self.history_length)
         for _ in range(self.history_length):
             self.latest_boards.append(self.board)
@@ -72,15 +73,22 @@ class Game:
         self.komi = k

     def think(self, latest_boards, color):
-        mcts = MCTS(self.game_engine, self.evaluator, [latest_boards, color],
-                    self.size ** 2 + 1, role=self.role, debug=self.debug, inverse=True)
+        if color == utils.BLACK:
+            role = 'black'
+        elif color == utils.WHITE:
+            role = 'white'
+        else:
+            raise ValueError("game.py[think] - unknown color : {}".format(color))
+        evaluator = lambda state: self.model(role, state)
+        mcts = MCTS(self.game_engine, evaluator, [latest_boards, color],
+                    self.size ** 2 + 1, role=role, debug=self.debug, inverse=True)
         mcts.search(max_step=100)
         if self.debug:
             file = open("mcts_debug.log", 'ab')
-            np.savetxt(file, mcts.root.Q, header="\n" + self.role + " Q value : ", fmt='%.4f', newline=", ")
-            np.savetxt(file, mcts.root.W, header="\n" + self.role + " W value : ", fmt='%.4f', newline=", ")
-            np.savetxt(file, mcts.root.N, header="\n" + self.role + " N value : ", fmt="%d", newline=", ")
-            np.savetxt(file, mcts.root.prior, header="\n" + self.role + " prior : ", fmt='%.4f', newline=", ")
+            np.savetxt(file, mcts.root.Q, header="\n" + role + " Q value : ", fmt='%.4f', newline=", ")
+            np.savetxt(file, mcts.root.W, header="\n" + role + " W value : ", fmt='%.4f', newline=", ")
+            np.savetxt(file, mcts.root.N, header="\n" + role + " N value : ", fmt="%d", newline=", ")
+            np.savetxt(file, mcts.root.prior, header="\n" + role + " prior : ", fmt='%.4f', newline=", ")
             file.close()
         temp = 1
         prob = mcts.root.N ** temp / np.sum(mcts.root.N ** temp)
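
The key change in think() is that MCTS no longer receives the model object directly; it gets a closure with the current role baked in, so the search itself stays role-agnostic. A minimal, runnable sketch of that pattern (DummyDualModel is a stand-in for model.ResNet, not code from this commit):

import numpy as np

class DummyDualModel:
    """Stand-in for model.ResNet: one object, two heads selected by role."""
    def __init__(self, action_num):
        self.action_num = action_num

    def __call__(self, role, state):
        # a real network would evaluate `state` with the weights owned by `role`
        prob = np.full(self.action_num, 1.0 / self.action_num)
        value = 0.0 if role == 'black' else 0.1
        return [prob, value]

model = DummyDualModel(action_num=82)
role = 'black'                                # derived from the color passed to think()
evaluator = lambda state: model(role, state)  # this closure is what MCTS receives
prob, value = evaluator(None)
print(prob.shape, value)
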
@@ -98,7 +106,8 @@ class Game:
         if self.name == "reversi":
             res = self.game_engine.executor_do_move(self.history, self.latest_boards, self.board, color, vertex)
         if self.name == "go":
-            res = self.game_engine.executor_do_move(self.history, self.history_set, self.latest_boards, self.board, color, vertex)
+            res = self.game_engine.executor_do_move(self.history, self.history_set, self.latest_boards, self.board,
+                                                    color, vertex)
         return res

     def think_play_move(self, color):
@@ -128,8 +137,8 @@ class Game:
         print('')
         sys.stdout.flush()

 if __name__ == "__main__":
-    game = Game(name="reversi", role="black", checkpoint_path=None)
+    game = Game(name="reversi", checkpoint_path=None)
     game.debug = True
     game.think_play_move(utils.BLACK)


@@ -18,7 +18,6 @@ class Go:
     def __init__(self, **kwargs):
         self.size = kwargs['size']
         self.komi = kwargs['komi']
-        self.role = kwargs['role']

     def _flatten(self, vertex):
         x, y = vertex
@@ -332,7 +331,7 @@ class Go:

 if __name__ == "__main__":
-    go = Go(size=9, komi=3.75, role = utils.BLACK)
+    go = Go(size=9, komi=3.75)
     endgame = [
         1, 0, 1, 0, 1, 1, -1, 0, -1,
         1, 1, 1, 1, 1, 1, -1, -1, -1,


@@ -80,7 +80,8 @@ class Data(object):

 class ResNet(object):
-    def __init__(self, board_size, action_num, history_length=1, residual_block_num=10, checkpoint_path=None):
+    def __init__(self, board_size, action_num, history_length=1, residual_block_num=10, black_checkpoint_path=None,
+                 white_checkpoint_path=None):
         """
         the resnet model
@@ -88,25 +89,49 @@ class ResNet(object):
         :param action_num: an integer, number of unique actions at any state
         :param history_length: an integer, the history length to use, default is 1
         :param residual_block_num: an integer, the number of residual block, default is 20, at least 1
-        :param checkpoint_path: a string, the path to the checkpoint, default is None,
+        :param black_checkpoint_path: a string, the path to the black checkpoint, default is None,
+        :param white_checkpoint_path: a string, the path to the white checkpoint, default is None,
         """
         self.board_size = board_size
         self.action_num = action_num
         self.history_length = history_length
-        self.checkpoint_path = checkpoint_path
+        self.black_checkpoint_path = black_checkpoint_path
+        self.white_checkpoint_path = white_checkpoint_path
         self.x = tf.placeholder(tf.float32, shape=[None, self.board_size, self.board_size, 2 * self.history_length + 1])
         self.is_training = tf.placeholder(tf.bool, shape=[])
         self.z = tf.placeholder(tf.float32, shape=[None, 1])
         self.pi = tf.placeholder(tf.float32, shape=[None, self.action_num])
-        self._build_network(residual_block_num, self.checkpoint_path)
+        self._build_network('black', residual_block_num)
+        self._build_network('white', residual_block_num)
+        self.sess = multi_gpu.create_session()
+        self.sess.run(tf.global_variables_initializer())
+        if black_checkpoint_path is not None:
+            ckpt_file = tf.train.latest_checkpoint(black_checkpoint_path)
+            if ckpt_file is not None:
+                print('Restoring model from {}...'.format(ckpt_file))
+                self.black_saver.restore(self.sess, ckpt_file)
+                print('Successfully loaded')
+            else:
+                raise ValueError("No model in path {}".format(black_checkpoint_path))
+        if white_checkpoint_path is not None:
+            ckpt_file = tf.train.latest_checkpoint(white_checkpoint_path)
+            if ckpt_file is not None:
+                print('Restoring model from {}...'.format(ckpt_file))
+                self.white_saver.restore(self.sess, ckpt_file)
+                print('Successfully loaded')
+            else:
+                raise ValueError("No model in path {}".format(white_checkpoint_path))
+        self.update = [tf.assign(white_params, black_params) for black_params, white_params in
+                       zip(self.black_var_list, self.white_var_list)]

         # training hyper-parameters:
-        self.window_length = 3
+        self.window_length = 900
         self.save_freq = 5000
         self.training_data = {'states': deque(maxlen=self.window_length), 'probs': deque(maxlen=self.window_length),
                               'winner': deque(maxlen=self.window_length), 'length': deque(maxlen=self.window_length)}

-    def _build_network(self, residual_block_num, checkpoint_path):
+    def _build_network(self, scope, residual_block_num):
         """
         build the network
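
The constructor now builds the same graph twice, once under a 'black' scope and once under a 'white' scope, restores each side from its own checkpoint directory, and keeps a list of assign ops (self.update) that overwrite every white variable with its black counterpart. A rough NumPy analogue of that bookkeeping, with hypothetical parameter names (this is not the TF graph built above):

import numpy as np

rng = np.random.default_rng(0)

def fresh_params():
    # stand-in for the variables created under one tf.variable_scope
    return {'conv_w': rng.normal(size=(3, 3)), 'fc_w': rng.normal(size=(4,))}

params = {'black': fresh_params(), 'white': fresh_params()}   # two scopes, same shapes

def update_white_from_black(params):
    """Analogue of sess.run(self.update): copy every black tensor onto white."""
    for name, value in params['black'].items():
        params['white'][name] = value.copy()

update_white_from_black(params)
assert all(np.array_equal(params['black'][k], params['white'][k]) for k in params['black'])
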
@@ -114,39 +139,36 @@ class ResNet(object):
-        :param checkpoint_path: a string, the path to the checkpoint, if None, use random initialization parameter
         :return: None
         """
-        h = layers.conv2d(self.x, 256, kernel_size=3, stride=1, activation_fn=tf.nn.relu,
-                          normalizer_fn=layers.batch_norm,
-                          normalizer_params={'is_training': self.is_training,
-                                             'updates_collections': tf.GraphKeys.UPDATE_OPS},
-                          weights_regularizer=layers.l2_regularizer(1e-4))
-        for i in range(residual_block_num - 1):
-            h = residual_block(h, self.is_training)
-        self.v = value_head(h, self.is_training)
-        self.p = policy_head(h, self.is_training, self.action_num)
-        self.prob = tf.nn.softmax(self.p)
-        self.value_loss = tf.reduce_mean(tf.square(self.z - self.v))
-        self.policy_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=self.pi, logits=self.p))
-        self.reg = tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
-        self.total_loss = self.value_loss + self.policy_loss + self.reg
-        self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
-        with tf.control_dependencies(self.update_ops):
-            self.train_op = tf.train.AdamOptimizer(1e-4).minimize(self.total_loss)
-        self.var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
-        self.saver = tf.train.Saver(max_to_keep=0, var_list=self.var_list)
-        self.sess = multi_gpu.create_session()
-        self.sess.run(tf.global_variables_initializer())
-        if checkpoint_path is not None:
-            ckpt_file = tf.train.latest_checkpoint(checkpoint_path)
-            if ckpt_file is not None:
-                print('Restoring model from {}...'.format(ckpt_file))
-                self.saver.restore(self.sess, ckpt_file)
-                print('Successfully loaded')
-            else:
-                raise ValueError("No model in path {}".format(checkpoint_path))
+        with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
+            h = layers.conv2d(self.x, 256, kernel_size=3, stride=1, activation_fn=tf.nn.relu,
+                              normalizer_fn=layers.batch_norm,
+                              normalizer_params={'is_training': self.is_training,
+                                                 'updates_collections': tf.GraphKeys.UPDATE_OPS},
+                              weights_regularizer=layers.l2_regularizer(1e-4))
+            for i in range(residual_block_num - 1):
+                h = residual_block(h, self.is_training)
+            self.__setattr__(scope + '_v', value_head(h, self.is_training))
+            self.__setattr__(scope + '_p', policy_head(h, self.is_training, self.action_num))
+            self.__setattr__(scope + '_prob', tf.nn.softmax(self.__getattribute__(scope + '_p')))
+            self.__setattr__(scope + '_value_loss', tf.reduce_mean(tf.square(self.z - self.__getattribute__(scope + '_v'))))
+            self.__setattr__(scope + '_policy_loss',
+                             tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=self.pi,
+                                                                                    logits=self.__getattribute__(
+                                                                                        scope + '_p'))))
+            self.__setattr__(scope + '_reg', tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES, scope=scope)))
+            self.__setattr__(scope + '_total_loss', self.__getattribute__(scope + '_value_loss') + self.__getattribute__(
+                scope + '_policy_loss') + self.__getattribute__(scope + '_reg'))
+            self.__setattr__(scope + '_update_ops', tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope=scope))
+            self.__setattr__(scope + '_var_list', tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope))
+            with tf.control_dependencies(self.__getattribute__(scope + '_update_ops')):
+                self.__setattr__(scope + '_train_op',
+                                 tf.train.AdamOptimizer(1e-4).minimize(self.__getattribute__(scope + '_total_loss'),
+                                                                       var_list=self.__getattribute__(scope + '_var_list')))
+            self.__setattr__(scope + '_saver',
+                             tf.train.Saver(max_to_keep=0, var_list=self.__getattribute__(scope + '_var_list')))

-    def __call__(self, state):
+    def __call__(self, role, state):
         """
         :param history: a list, the history
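
Because the two copies of the network differ only in their scope prefix, the new _build_network attaches its tensors to the object under computed names (black_v, black_prob, white_saver, ...) via __setattr__/__getattribute__. A tiny sketch of that naming pattern, using plain Python strings in place of TF tensors:

class TwoHeaded:
    def _build(self, scope):
        # setattr(self, ...) is equivalent to self.__setattr__(...)
        setattr(self, scope + '_v', scope + ':value-tensor')
        setattr(self, scope + '_prob', scope + ':policy-tensor')

    def __init__(self):
        for scope in ('black', 'white'):
            self._build(scope)

net = TwoHeaded()
print(net.black_prob)   # 'black:policy-tensor'
print(net.white_v)      # 'white:value-tensor'
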
@@ -154,15 +176,20 @@ class ResNet(object):
         :return: a list of tensor, the predicted value and policy given the history and color
         """
         # Note : maybe we can use it for isolating test of MCTS
-        #prob = [1.0 / self.action_num] * self.action_num
-        #return [prob, np.random.uniform(-1, 1)]
+        # prob = [1.0 / self.action_num] * self.action_num
+        # return [prob, np.random.uniform(-1, 1)]
         history, color = state
         if len(history) != self.history_length:
             raise ValueError(
                 'The length of history cannot meet the need of the model, given {}, need {}'.format(len(history),
                                                                                                      self.history_length))
         eval_state = self._history2state(history, color)
-        return self.sess.run([self.prob, self.v], feed_dict={self.x: eval_state, self.is_training: False})
+        if role == 'black':
+            return self.sess.run([self.black_prob, self.black_v],
+                                 feed_dict={self.x: eval_state, self.is_training: False})
+        if role == 'white':
+            return self.sess.run([self.white_prob, self.white_v],
+                                 feed_dict={self.x: eval_state, self.is_training: False})

     def _history2state(self, history, color):
         """
@@ -174,10 +201,12 @@ class ResNet(object):
         """
         state = np.zeros([1, self.board_size, self.board_size, 2 * self.history_length + 1])
         for i in range(self.history_length):
-            state[0, :, :, i] = np.array(np.array(history[i]).flatten() == np.ones(self.board_size ** 2)).reshape(self.board_size,
-                                                                                                                   self.board_size)
+            state[0, :, :, i] = np.array(np.array(history[i]).flatten() == np.ones(self.board_size ** 2)).reshape(
+                self.board_size,
+                self.board_size)
             state[0, :, :, i + self.history_length] = np.array(
-                np.array(history[i]).flatten() == -np.ones(self.board_size ** 2)).reshape(self.board_size, self.board_size)
+                np.array(history[i]).flatten() == -np.ones(self.board_size ** 2)).reshape(self.board_size,
+                                                                                           self.board_size)
         # TODO: need a config to specify the BLACK and WHITE
         if color == +1:
             state[0, :, :, 2 * self.history_length] = np.ones([self.board_size, self.board_size])
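
For reference, the plane layout _history2state produces: planes [0, H) mark black stones over the last H boards, planes [H, 2H) mark white stones, and the final plane is all ones exactly when black (+1) is to move. A self-contained NumPy sketch of the same encoding, assuming flat board lists with +1/-1 stone values as in go.py:

import numpy as np

def history2state(history, color, board_size, history_length):
    state = np.zeros([1, board_size, board_size, 2 * history_length + 1])
    for i in range(history_length):
        board = np.array(history[i]).reshape(board_size, board_size)
        state[0, :, :, i] = (board == 1)                      # black-stone planes
        state[0, :, :, i + history_length] = (board == -1)    # white-stone planes
    if color == +1:                                           # black to move
        state[0, :, :, 2 * history_length] = 1
    return state

h = [[0] * 9 for _ in range(2)]        # toy 3x3 board, history length 2
h[1][4] = 1                            # one black stone in the latest board
print(history2state(h, +1, board_size=3, history_length=2).shape)   # (1, 3, 3, 5)
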
@@ -187,19 +216,27 @@ class ResNet(object):
     # TODO: design the interface between the environment and training
     def train(self, mode='memory', *args, **kwargs):
+        """
+        The method to train the network
+
+        :param target: a string, which to optimize, can only be "both", "black" and "white"
+        :param mode: a string, how to optimize, can only be "memory" and "file"
+        """
         if mode == 'memory':
             pass
         if mode == 'file':
             self._train_with_file(data_path=kwargs['data_path'], batch_size=kwargs['batch_size'],
-                                  checkpoint_path=kwargs['checkpoint_path'])
+                                  save_path=kwargs['save_path'])

-    def _train_with_file(self, data_path, batch_size, checkpoint_path):
+    def _train_with_file(self, data_path, batch_size, save_path):
         # check if the path is valid
         if not os.path.exists(data_path):
             raise ValueError("{} doesn't exist".format(data_path))
-        self.checkpoint_path = checkpoint_path
-        if not os.path.exists(self.checkpoint_path):
-            os.mkdir(self.checkpoint_path)
+        self.save_path = save_path
+        if not os.path.exists(self.save_path):
+            os.mkdir(self.save_path)
+            os.mkdir(self.save_path + 'black')
+            os.mkdir(self.save_path + 'white')

         new_file_list = []
         all_file_list = []
@@ -227,7 +264,8 @@ class ResNet(object):
             else:
                 start_time = time.time()
                 for i in range(batch_size):
-                    priority = np.array(self.training_data['length']) / (0.0 + np.sum(np.array(self.training_data['length'])))
+                    priority = np.array(self.training_data['length']) / (
+                            0.0 + np.sum(np.array(self.training_data['length'])))
                     game_num = np.random.choice(self.window_length, 1, p=priority)[0]
                     state_num = np.random.randint(self.training_data['length'][game_num])
                     rotate_times = np.random.randint(4)
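
The sampling scheme above draws a game with probability proportional to its length and then a position uniformly within that game, so every stored position ends up (roughly) equally likely. A stand-alone sketch with made-up game lengths:

import numpy as np

lengths = np.array([30, 60, 90])                      # stand-in for training_data['length']
priority = lengths / np.sum(lengths, dtype=float)     # long games are picked more often
game_num = np.random.choice(len(lengths), p=priority)
state_num = np.random.randint(lengths[game_num])      # uniform position inside that game
print(game_num, state_num)
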
@@ -237,11 +275,15 @@ class ResNet(object):
                         self._preprocession(self.training_data['states'][game_num][state_num], reflect_times,
                                             reflect_orientation, rotate_times))
                     training_data['probs'].append(np.concatenate(
-                        [self._preprocession(self.training_data['probs'][game_num][state_num][:-1].reshape(self.board_size, self.board_size, 1), reflect_times,
-                                             reflect_orientation, rotate_times).reshape(1, self.board_size**2), self.training_data['probs'][game_num][state_num][-1].reshape(1,1)], axis=1))
+                        [self._preprocession(
+                            self.training_data['probs'][game_num][state_num][:-1].reshape(self.board_size,
+                                                                                           self.board_size, 1),
+                            reflect_times,
+                            reflect_orientation, rotate_times).reshape(1, self.board_size ** 2),
+                         self.training_data['probs'][game_num][state_num][-1].reshape(1, 1)], axis=1))
                     training_data['winner'].append(self.training_data['winner'][game_num][state_num].reshape(1, 1))
                     value_loss, policy_loss, reg, _ = self.sess.run(
-                        [self.value_loss, self.policy_loss, self.reg, self.train_op],
+                        [self.black_value_loss, self.black_policy_loss, self.black_reg, self.black_train_op],
                         feed_dict={self.x: np.concatenate(training_data['states'], axis=0),
                                    self.z: np.concatenate(training_data['winner'], axis=0),
                                    self.pi: np.concatenate(training_data['probs'], axis=0),
@@ -252,8 +294,11 @@ class ResNet(object):
                                                                               value_loss,
                                                                               policy_loss, reg))
                     if iters % self.save_freq == 0:
-                        save_path = "Iteration{}.ckpt".format(iters)
-                        self.saver.save(self.sess, self.checkpoint_path + save_path)
+                        ckpt_file = "Iteration{}.ckpt".format(iters)
+                        self.black_saver.save(self.sess, self.save_path + 'black/' + ckpt_file)
+                        self.sess.run(self.update)
+                        self.white_saver.save(self.sess, self.save_path + 'white/' + ckpt_file)
                     for key in training_data.keys():
                         training_data[key] = []
                     iters += 1
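
At save time the order matters: black's weights are written first, self.update copies them into the white scope, and then the white checkpoint is written, so both directories end up holding the same parameters. A file-path-only sketch of that flow, with hypothetical callables standing in for the savers and the assign ops:

import os

def save_checkpoints(save_path, iters, save_black, sync_white_from_black, save_white):
    ckpt_file = "Iteration{}.ckpt".format(iters)
    save_black(os.path.join(save_path, 'black', ckpt_file))
    sync_white_from_black()                    # analogue of sess.run(self.update)
    save_white(os.path.join(save_path, 'white', ckpt_file))

save_checkpoints('./checkpoint', 5000,
                 save_black=lambda p: print('black ->', p),
                 sync_white_from_black=lambda: print('copy black params into white'),
                 save_white=lambda p: print('white ->', p))
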
@@ -342,5 +387,5 @@ class ResNet(object):

 if __name__ == "__main__":
-    model = ResNet(board_size=9, action_num=82, history_length=8)
-    model.train("file", data_path="./data/", batch_size=128, checkpoint_path="./checkpoint/")
+    model = ResNet(board_size=8, action_num=65, history_length=1, black_checkpoint_path="./checkpoint/black", white_checkpoint_path="./checkpoint/white")
+    model.train(mode="file", data_path="./data/", batch_size=128, save_path="./checkpoint/")


@@ -1,10 +1,10 @@
 import argparse
-import subprocess
 import sys
 import re
-import Pyro4
 import time
 import os
+from game import Game
+from engine import GTPEngine
 import utils
 from time import gmtime, strftime
@@ -24,6 +24,7 @@ class Data(object):
     def reset(self):
         self.__init__()

 if __name__ == '__main__':
     """
     Starting two different players which load network weights to evaluate the winning ratio.
@@ -47,65 +48,13 @@ if __name__ == '__main__':
     if args.white_weight_path is not None and (not os.path.exists(args.white_weight_path)):
         raise ValueError("Can't find the network weights for white player.")

-    # kill the old server
-    # kill_old_server = subprocess.Popen(['killall', 'pyro4-ns'])
-    # print "kill the old pyro4 name server, the return code is : " + str(kill_old_server.wait())
-    # time.sleep(1)
-
-    # start a name server if no name server exists
-    if len(os.popen('ps aux | grep pyro4-ns | grep -v grep').readlines()) == 0:
-        start_new_server = subprocess.Popen(['pyro4-ns', '&'])
-        print("Start Name Sever : " + str(start_new_server.pid))  # + str(start_new_server.wait())
-        time.sleep(1)
-
-    # start two different player with different network weights.
-    server_list = subprocess.check_output(['pyro4-nsc', 'list'])
-    current_time = strftime("%Y%m%d_%H%M%S", gmtime())
-    black_role_name = 'black' + current_time
-    white_role_name = 'white' + current_time
-    black_player = subprocess.Popen(
-        ['python', '-u', 'player.py', '--game=' + args.game, '--role=' + black_role_name,
-         '--checkpoint_path=' + str(args.black_weight_path), '--debug=' + str(args.debug)],
-        stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
-    bp_output = black_player.stdout.readline()
-    bp_message = bp_output
-    # '' means player.py failed to start, "Start requestLoop" means player.py start successfully
-    while bp_output != '' and "Start requestLoop" not in bp_output:
-        bp_output = black_player.stdout.readline()
-        bp_message += bp_output
-    print("============ " + black_role_name + " message ============" + "\n" + bp_message),
-    white_player = subprocess.Popen(
-        ['python', '-u', 'player.py', '--game=' + args.game, '--role=' + white_role_name,
-         '--checkpoint_path=' + str(args.white_weight_path), '--debug=' + str(args.debug)],
-        stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
-    wp_output = white_player.stdout.readline()
-    wp_message = wp_output
-    while wp_output != '' and "Start requestLoop" not in wp_output:
-        wp_output = white_player.stdout.readline()
-        wp_message += wp_output
-    print("============ " + white_role_name + " message ============" + "\n" + wp_message),
-    server_list = ""
-    while (black_role_name not in server_list) or (white_role_name not in server_list):
-        if python_version < (3, 0):
-            # TODO : @renyong what is the difference between those two options?
-            server_list = subprocess.check_output(['pyro4-nsc', 'list'])
-        else:
-            server_list = subprocess.check_output(['pyro4-nsc', 'list'])
-        print("Waiting for the server start...")
-        time.sleep(1)
-    print(server_list)
-    print("Start black player at : " + str(black_player.pid))
-    print("Start white player at : " + str(white_player.pid))
+    game = Game(name=args.game,
+                black_checkpoint_path=args.black_weight_path,
+                white_checkpoint_path=args.white_weight_path,
+                debug=args.debug)
+    engine = GTPEngine(game_obj=game, name='tianshou', version=0)

     data = Data()
-    player = [None] * 2
-    player[0] = Pyro4.Proxy("PYRONAME:" + black_role_name)
-    player[1] = Pyro4.Proxy("PYRONAME:" + white_role_name)
     role = ["BLACK", "WHITE"]
     color = ['b', 'w']
@@ -118,7 +67,7 @@ if __name__ == '__main__':
     game_num = 0
     try:
         while True:
-            # while game_num < evaluate_rounds:
+            #while game_num < evaluate_rounds:
             start_time = time.time()
             num = 0
             pass_flag = [False, False]
@@ -126,7 +75,7 @@ if __name__ == '__main__':
             # end the game if both palyer chose to pass, or play too much turns
             while not (pass_flag[0] and pass_flag[1]) and num < size[args.game] ** 2 * 2:
                 turn = num % 2
-                board = player[turn].run_cmd(str(num) + ' show_board')
+                board = engine.run_cmd(str(num) + ' show_board')
                 board = eval(board[board.index('['):board.index(']') + 1])
                 for i in range(size[args.game]):
                     for j in range(size[args.game]):
@@ -134,7 +83,7 @@ if __name__ == '__main__':
                     print "\n",
                 data.boards.append(board)
                 start_time = time.time()
-                move = player[turn].run_cmd(str(num) + ' genmove ' + color[turn])[:-1]
+                move = engine.run_cmd(str(num) + ' genmove ' + color[turn])[:-1]
                 print("\n" + role[turn] + " : " + str(move)),
                 num += 1
                 match = re.search(pattern, move)
@@ -146,21 +95,19 @@ if __name__ == '__main__':
                     # print "no match"
                    play_or_pass = ' PASS'
                    pass_flag[turn] = True
-                result = player[1 - turn].run_cmd(str(num) + ' play ' + color[turn] + ' ' + play_or_pass + '\n')
-                prob = player[turn].run_cmd(str(num) + ' get_prob')
+                prob = engine.run_cmd(str(num) + ' get_prob')
                prob = space.sub(',', prob[prob.index('['):prob.index(']') + 1])
                prob = prob.replace('[,', '[')
                prob = prob.replace('],', ']')
                prob = eval(prob)
                data.probs.append(prob)
-            score = player[0].run_cmd(str(num) + ' get_score')
+            score = engine.run_cmd(str(num) + ' get_score')
             print("Finished : {}".format(score.split(" ")[1]))
             if eval(score.split(" ")[1]) > 0:
                 data.winner = utils.BLACK
             if eval(score.split(" ")[1]) < 0:
                 data.winner = utils.WHITE
-            player[0].run_cmd(str(num) + ' clear_board')
-            player[1].run_cmd(str(num) + ' clear_board')
+            engine.run_cmd(str(num) + ' clear_board')
             file_list = os.listdir(args.data_path)
             current_time = strftime("%Y%m%d_%H%M%S", gmtime())
             if os.path.exists(args.data_path + current_time + ".pkl"):
@@ -172,7 +119,3 @@ if __name__ == '__main__':
             game_num += 1
     except KeyboardInterrupt:
         pass
-    subprocess.call(["kill", "-9", str(black_player.pid)])
-    subprocess.call(["kill", "-9", str(white_player.pid)])
-    print("Kill all player, finish all game.")


@@ -1,42 +0,0 @@
-import argparse
-import Pyro4
-
-from game import Game
-from engine import GTPEngine
-
-
-@Pyro4.expose
-class Player(object):
-    """
-    This is the class which defines the object called by Pyro4 (Python remote object).
-    It passes the command to our engine, and return the result.
-    """
-    def __init__(self, **kwargs):
-        self.role = kwargs['role']
-        self.engine = kwargs['engine']
-
-    def run_cmd(self, command):
-        return self.engine.run_cmd(command)
-
-
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--checkpoint_path", type=str, default="None")
-    parser.add_argument("--role", type=str, default="unknown")
-    parser.add_argument("--debug", type=str, default="False")
-    parser.add_argument("--game", type=str, default="go")
-    args = parser.parse_args()
-
-    if args.checkpoint_path == 'None':
-        args.checkpoint_path = None
-    game = Game(name=args.game, role=args.role,
-                checkpoint_path=args.checkpoint_path,
-                debug=eval(args.debug))
-    engine = GTPEngine(game_obj=game, name='tianshou', version=0)
-
-    daemon = Pyro4.Daemon()            # make a Pyro daemon
-    ns = Pyro4.locateNS()              # find the name server
-    player = Player(role=args.role, engine=engine)
-    print("Init " + args.role + " player finished")
-    uri = daemon.register(player)      # register the greeting maker as a Pyro object
-    print("Start on name " + args.role)
-    ns.register(args.role, uri)        # register the object with a name in the name server
-    print("Start requestLoop " + str(uri))
-    daemon.requestLoop()               # start the event loop of the server to wait for calls