commit
284cc64c18
@ -12,6 +12,7 @@ import numpy as np
|
|||||||
import sys, os
|
import sys, os
|
||||||
import model
|
import model
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
|
||||||
sys.path.append(os.path.join(os.path.dirname(__file__), os.path.pardir))
|
sys.path.append(os.path.join(os.path.dirname(__file__), os.path.pardir))
|
||||||
from tianshou.core.mcts.mcts import MCTS
|
from tianshou.core.mcts.mcts import MCTS
|
||||||
|
|
||||||
@ -19,6 +20,7 @@ import go
|
|||||||
import reversi
|
import reversi
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
|
||||||
class Game:
|
class Game:
|
||||||
'''
|
'''
|
||||||
Load the real game and trained weights.
|
Load the real game and trained weights.
|
||||||
@ -26,11 +28,9 @@ class Game:
|
|||||||
TODO : Maybe merge with the engine class in future,
|
TODO : Maybe merge with the engine class in future,
|
||||||
currently leave it untouched for interacting with Go UI.
|
currently leave it untouched for interacting with Go UI.
|
||||||
'''
|
'''
|
||||||
def __init__(self, name=None, role=None, debug=False, checkpoint_path=None):
|
|
||||||
|
def __init__(self, name=None, debug=False, black_checkpoint_path=None, white_checkpoint_path=None):
|
||||||
self.name = name
|
self.name = name
|
||||||
if role is None:
|
|
||||||
raise ValueError("Need a role!")
|
|
||||||
self.role = role
|
|
||||||
self.debug = debug
|
self.debug = debug
|
||||||
if self.name == "go":
|
if self.name == "go":
|
||||||
self.size = 9
|
self.size = 9
|
||||||
@ -38,7 +38,7 @@ class Game:
|
|||||||
self.history_length = 8
|
self.history_length = 8
|
||||||
self.history = []
|
self.history = []
|
||||||
self.history_set = set()
|
self.history_set = set()
|
||||||
self.game_engine = go.Go(size=self.size, komi=self.komi, role=self.role)
|
self.game_engine = go.Go(size=self.size, komi=self.komi)
|
||||||
self.board = [utils.EMPTY] * (self.size ** 2)
|
self.board = [utils.EMPTY] * (self.size ** 2)
|
||||||
elif self.name == "reversi":
|
elif self.name == "reversi":
|
||||||
self.size = 8
|
self.size = 8
|
||||||
@ -49,8 +49,9 @@ class Game:
|
|||||||
else:
|
else:
|
||||||
raise ValueError(name + " is an unknown game...")
|
raise ValueError(name + " is an unknown game...")
|
||||||
|
|
||||||
self.evaluator = model.ResNet(self.size, self.size ** 2 + 1, history_length=self.history_length,
|
self.model = model.ResNet(self.size, self.size ** 2 + 1, history_length=self.history_length,
|
||||||
checkpoint_path=checkpoint_path)
|
black_checkpoint_path=black_checkpoint_path,
|
||||||
|
white_checkpoint_path=white_checkpoint_path)
|
||||||
self.latest_boards = deque(maxlen=self.history_length)
|
self.latest_boards = deque(maxlen=self.history_length)
|
||||||
for _ in range(self.history_length):
|
for _ in range(self.history_length):
|
||||||
self.latest_boards.append(self.board)
|
self.latest_boards.append(self.board)
|
||||||
@ -72,15 +73,22 @@ class Game:
|
|||||||
self.komi = k
|
self.komi = k
|
||||||
|
|
||||||
def think(self, latest_boards, color):
|
def think(self, latest_boards, color):
|
||||||
mcts = MCTS(self.game_engine, self.evaluator, [latest_boards, color],
|
if color == utils.BLACK:
|
||||||
self.size ** 2 + 1, role=self.role, debug=self.debug, inverse=True)
|
role = 'black'
|
||||||
|
elif color == utils.WHITE:
|
||||||
|
role = 'white'
|
||||||
|
else:
|
||||||
|
raise ValueError("game.py[think] - unknown color : {}".format(color))
|
||||||
|
evaluator = lambda state:self.model(role, state)
|
||||||
|
mcts = MCTS(self.game_engine, evaluator, [latest_boards, color],
|
||||||
|
self.size ** 2 + 1, role=role, debug=self.debug, inverse=True)
|
||||||
mcts.search(max_step=100)
|
mcts.search(max_step=100)
|
||||||
if self.debug:
|
if self.debug:
|
||||||
file = open("mcts_debug.log", 'ab')
|
file = open("mcts_debug.log", 'ab')
|
||||||
np.savetxt(file, mcts.root.Q, header="\n" + self.role + " Q value : ", fmt='%.4f', newline=", ")
|
np.savetxt(file, mcts.root.Q, header="\n" + role + " Q value : ", fmt='%.4f', newline=", ")
|
||||||
np.savetxt(file, mcts.root.W, header="\n" + self.role + " W value : ", fmt='%.4f', newline=", ")
|
np.savetxt(file, mcts.root.W, header="\n" + role + " W value : ", fmt='%.4f', newline=", ")
|
||||||
np.savetxt(file, mcts.root.N, header="\n" + self.role + " N value : ", fmt="%d", newline=", ")
|
np.savetxt(file, mcts.root.N, header="\n" + role + " N value : ", fmt="%d", newline=", ")
|
||||||
np.savetxt(file, mcts.root.prior, header="\n" + self.role + " prior : ", fmt='%.4f', newline=", ")
|
np.savetxt(file, mcts.root.prior, header="\n" + role + " prior : ", fmt='%.4f', newline=", ")
|
||||||
file.close()
|
file.close()
|
||||||
temp = 1
|
temp = 1
|
||||||
prob = mcts.root.N ** temp / np.sum(mcts.root.N ** temp)
|
prob = mcts.root.N ** temp / np.sum(mcts.root.N ** temp)
|
||||||
@ -98,7 +106,8 @@ class Game:
|
|||||||
if self.name == "reversi":
|
if self.name == "reversi":
|
||||||
res = self.game_engine.executor_do_move(self.history, self.latest_boards, self.board, color, vertex)
|
res = self.game_engine.executor_do_move(self.history, self.latest_boards, self.board, color, vertex)
|
||||||
if self.name == "go":
|
if self.name == "go":
|
||||||
res = self.game_engine.executor_do_move(self.history, self.history_set, self.latest_boards, self.board, color, vertex)
|
res = self.game_engine.executor_do_move(self.history, self.history_set, self.latest_boards, self.board,
|
||||||
|
color, vertex)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def think_play_move(self, color):
|
def think_play_move(self, color):
|
||||||
@ -128,8 +137,8 @@ class Game:
|
|||||||
print('')
|
print('')
|
||||||
sys.stdout.flush()
|
sys.stdout.flush()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
game = Game(name="reversi", role="black", checkpoint_path=None)
|
game = Game(name="reversi", checkpoint_path=None)
|
||||||
game.debug = True
|
game.debug = True
|
||||||
game.think_play_move(utils.BLACK)
|
game.think_play_move(utils.BLACK)
|
||||||
|
|
||||||
|
|||||||
@ -18,7 +18,6 @@ class Go:
|
|||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
self.size = kwargs['size']
|
self.size = kwargs['size']
|
||||||
self.komi = kwargs['komi']
|
self.komi = kwargs['komi']
|
||||||
self.role = kwargs['role']
|
|
||||||
|
|
||||||
def _flatten(self, vertex):
|
def _flatten(self, vertex):
|
||||||
x, y = vertex
|
x, y = vertex
|
||||||
@ -332,7 +331,7 @@ class Go:
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
go = Go(size=9, komi=3.75, role = utils.BLACK)
|
go = Go(size=9, komi=3.75)
|
||||||
endgame = [
|
endgame = [
|
||||||
1, 0, 1, 0, 1, 1, -1, 0, -1,
|
1, 0, 1, 0, 1, 1, -1, 0, -1,
|
||||||
1, 1, 1, 1, 1, 1, -1, -1, -1,
|
1, 1, 1, 1, 1, 1, -1, -1, -1,
|
||||||
|
|||||||
137
AlphaGo/model.py
137
AlphaGo/model.py
@ -80,7 +80,8 @@ class Data(object):
|
|||||||
|
|
||||||
|
|
||||||
class ResNet(object):
|
class ResNet(object):
|
||||||
def __init__(self, board_size, action_num, history_length=1, residual_block_num=10, checkpoint_path=None):
|
def __init__(self, board_size, action_num, history_length=1, residual_block_num=10, black_checkpoint_path=None,
|
||||||
|
white_checkpoint_path=None):
|
||||||
"""
|
"""
|
||||||
the resnet model
|
the resnet model
|
||||||
|
|
||||||
@ -88,25 +89,49 @@ class ResNet(object):
|
|||||||
:param action_num: an integer, number of unique actions at any state
|
:param action_num: an integer, number of unique actions at any state
|
||||||
:param history_length: an integer, the history length to use, default is 1
|
:param history_length: an integer, the history length to use, default is 1
|
||||||
:param residual_block_num: an integer, the number of residual block, default is 20, at least 1
|
:param residual_block_num: an integer, the number of residual block, default is 20, at least 1
|
||||||
:param checkpoint_path: a string, the path to the checkpoint, default is None,
|
:param black_checkpoint_path: a string, the path to the black checkpoint, default is None,
|
||||||
|
:param white_checkpoint_path: a string, the path to the white checkpoint, default is None,
|
||||||
"""
|
"""
|
||||||
self.board_size = board_size
|
self.board_size = board_size
|
||||||
self.action_num = action_num
|
self.action_num = action_num
|
||||||
self.history_length = history_length
|
self.history_length = history_length
|
||||||
self.checkpoint_path = checkpoint_path
|
self.black_checkpoint_path = black_checkpoint_path
|
||||||
|
self.white_checkpoint_path = white_checkpoint_path
|
||||||
self.x = tf.placeholder(tf.float32, shape=[None, self.board_size, self.board_size, 2 * self.history_length + 1])
|
self.x = tf.placeholder(tf.float32, shape=[None, self.board_size, self.board_size, 2 * self.history_length + 1])
|
||||||
self.is_training = tf.placeholder(tf.bool, shape=[])
|
self.is_training = tf.placeholder(tf.bool, shape=[])
|
||||||
self.z = tf.placeholder(tf.float32, shape=[None, 1])
|
self.z = tf.placeholder(tf.float32, shape=[None, 1])
|
||||||
self.pi = tf.placeholder(tf.float32, shape=[None, self.action_num])
|
self.pi = tf.placeholder(tf.float32, shape=[None, self.action_num])
|
||||||
self._build_network(residual_block_num, self.checkpoint_path)
|
self._build_network('black', residual_block_num)
|
||||||
|
self._build_network('white', residual_block_num)
|
||||||
|
self.sess = multi_gpu.create_session()
|
||||||
|
self.sess.run(tf.global_variables_initializer())
|
||||||
|
if black_checkpoint_path is not None:
|
||||||
|
ckpt_file = tf.train.latest_checkpoint(black_checkpoint_path)
|
||||||
|
if ckpt_file is not None:
|
||||||
|
print('Restoring model from {}...'.format(ckpt_file))
|
||||||
|
self.black_saver.restore(self.sess, ckpt_file)
|
||||||
|
print('Successfully loaded')
|
||||||
|
else:
|
||||||
|
raise ValueError("No model in path {}".format(black_checkpoint_path))
|
||||||
|
|
||||||
|
if white_checkpoint_path is not None:
|
||||||
|
ckpt_file = tf.train.latest_checkpoint(white_checkpoint_path)
|
||||||
|
if ckpt_file is not None:
|
||||||
|
print('Restoring model from {}...'.format(ckpt_file))
|
||||||
|
self.white_saver.restore(self.sess, ckpt_file)
|
||||||
|
print('Successfully loaded')
|
||||||
|
else:
|
||||||
|
raise ValueError("No model in path {}".format(white_checkpoint_path))
|
||||||
|
self.update = [tf.assign(white_params, black_params) for black_params, white_params in
|
||||||
|
zip(self.black_var_list, self.white_var_list)]
|
||||||
|
|
||||||
# training hyper-parameters:
|
# training hyper-parameters:
|
||||||
self.window_length = 3
|
self.window_length = 900
|
||||||
self.save_freq = 5000
|
self.save_freq = 5000
|
||||||
self.training_data = {'states': deque(maxlen=self.window_length), 'probs': deque(maxlen=self.window_length),
|
self.training_data = {'states': deque(maxlen=self.window_length), 'probs': deque(maxlen=self.window_length),
|
||||||
'winner': deque(maxlen=self.window_length), 'length': deque(maxlen=self.window_length)}
|
'winner': deque(maxlen=self.window_length), 'length': deque(maxlen=self.window_length)}
|
||||||
|
|
||||||
def _build_network(self, residual_block_num, checkpoint_path):
|
def _build_network(self, scope, residual_block_num):
|
||||||
"""
|
"""
|
||||||
build the network
|
build the network
|
||||||
|
|
||||||
@ -114,7 +139,7 @@ class ResNet(object):
|
|||||||
:param checkpoint_path: a string, the path to the checkpoint, if None, use random initialization parameter
|
:param checkpoint_path: a string, the path to the checkpoint, if None, use random initialization parameter
|
||||||
:return: None
|
:return: None
|
||||||
"""
|
"""
|
||||||
|
with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
|
||||||
h = layers.conv2d(self.x, 256, kernel_size=3, stride=1, activation_fn=tf.nn.relu,
|
h = layers.conv2d(self.x, 256, kernel_size=3, stride=1, activation_fn=tf.nn.relu,
|
||||||
normalizer_fn=layers.batch_norm,
|
normalizer_fn=layers.batch_norm,
|
||||||
normalizer_params={'is_training': self.is_training,
|
normalizer_params={'is_training': self.is_training,
|
||||||
@ -122,31 +147,28 @@ class ResNet(object):
|
|||||||
weights_regularizer=layers.l2_regularizer(1e-4))
|
weights_regularizer=layers.l2_regularizer(1e-4))
|
||||||
for i in range(residual_block_num - 1):
|
for i in range(residual_block_num - 1):
|
||||||
h = residual_block(h, self.is_training)
|
h = residual_block(h, self.is_training)
|
||||||
self.v = value_head(h, self.is_training)
|
self.__setattr__(scope + '_v', value_head(h, self.is_training))
|
||||||
self.p = policy_head(h, self.is_training, self.action_num)
|
self.__setattr__(scope + '_p', policy_head(h, self.is_training, self.action_num))
|
||||||
self.prob = tf.nn.softmax(self.p)
|
self.__setattr__(scope + '_prob', tf.nn.softmax(self.__getattribute__(scope + '_p')))
|
||||||
self.value_loss = tf.reduce_mean(tf.square(self.z - self.v))
|
self.__setattr__(scope + '_value_loss', tf.reduce_mean(tf.square(self.z - self.__getattribute__(scope + '_v'))))
|
||||||
self.policy_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=self.pi, logits=self.p))
|
self.__setattr__(scope + '_policy_loss',
|
||||||
|
tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=self.pi,
|
||||||
|
logits=self.__getattribute__(
|
||||||
|
scope + '_p'))))
|
||||||
|
|
||||||
self.reg = tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
|
self.__setattr__(scope + '_reg', tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES, scope=scope)))
|
||||||
self.total_loss = self.value_loss + self.policy_loss + self.reg
|
self.__setattr__(scope + '_total_loss', self.__getattribute__(scope + '_value_loss') + self.__getattribute__(
|
||||||
self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
|
scope + '_policy_loss') + self.__getattribute__(scope + '_reg'))
|
||||||
with tf.control_dependencies(self.update_ops):
|
self.__setattr__(scope + '_update_ops', tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope=scope))
|
||||||
self.train_op = tf.train.AdamOptimizer(1e-4).minimize(self.total_loss)
|
self.__setattr__(scope + '_var_list', tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope))
|
||||||
self.var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
|
with tf.control_dependencies(self.__getattribute__(scope + '_update_ops')):
|
||||||
self.saver = tf.train.Saver(max_to_keep=0, var_list=self.var_list)
|
self.__setattr__(scope + '_train_op',
|
||||||
self.sess = multi_gpu.create_session()
|
tf.train.AdamOptimizer(1e-4).minimize(self.__getattribute__(scope + '_total_loss'),
|
||||||
self.sess.run(tf.global_variables_initializer())
|
var_list=self.__getattribute__(scope + '_var_list')))
|
||||||
if checkpoint_path is not None:
|
self.__setattr__(scope + '_saver',
|
||||||
ckpt_file = tf.train.latest_checkpoint(checkpoint_path)
|
tf.train.Saver(max_to_keep=0, var_list=self.__getattribute__(scope + '_var_list')))
|
||||||
if ckpt_file is not None:
|
|
||||||
print('Restoring model from {}...'.format(ckpt_file))
|
|
||||||
self.saver.restore(self.sess, ckpt_file)
|
|
||||||
print('Successfully loaded')
|
|
||||||
else:
|
|
||||||
raise ValueError("No model in path {}".format(checkpoint_path))
|
|
||||||
|
|
||||||
def __call__(self, state):
|
def __call__(self, role, state):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
:param history: a list, the history
|
:param history: a list, the history
|
||||||
@ -162,7 +184,12 @@ class ResNet(object):
|
|||||||
'The length of history cannot meet the need of the model, given {}, need {}'.format(len(history),
|
'The length of history cannot meet the need of the model, given {}, need {}'.format(len(history),
|
||||||
self.history_length))
|
self.history_length))
|
||||||
eval_state = self._history2state(history, color)
|
eval_state = self._history2state(history, color)
|
||||||
return self.sess.run([self.prob, self.v], feed_dict={self.x: eval_state, self.is_training: False})
|
if role == 'black':
|
||||||
|
return self.sess.run([self.black_prob, self.black_v],
|
||||||
|
feed_dict={self.x: eval_state, self.is_training: False})
|
||||||
|
if role == 'white':
|
||||||
|
return self.sess.run([self.white_prob, self.white_v],
|
||||||
|
feed_dict={self.x: eval_state, self.is_training: False})
|
||||||
|
|
||||||
def _history2state(self, history, color):
|
def _history2state(self, history, color):
|
||||||
"""
|
"""
|
||||||
@ -174,10 +201,12 @@ class ResNet(object):
|
|||||||
"""
|
"""
|
||||||
state = np.zeros([1, self.board_size, self.board_size, 2 * self.history_length + 1])
|
state = np.zeros([1, self.board_size, self.board_size, 2 * self.history_length + 1])
|
||||||
for i in range(self.history_length):
|
for i in range(self.history_length):
|
||||||
state[0, :, :, i] = np.array(np.array(history[i]).flatten() == np.ones(self.board_size ** 2)).reshape(self.board_size,
|
state[0, :, :, i] = np.array(np.array(history[i]).flatten() == np.ones(self.board_size ** 2)).reshape(
|
||||||
|
self.board_size,
|
||||||
self.board_size)
|
self.board_size)
|
||||||
state[0, :, :, i + self.history_length] = np.array(
|
state[0, :, :, i + self.history_length] = np.array(
|
||||||
np.array(history[i]).flatten() == -np.ones(self.board_size ** 2)).reshape(self.board_size, self.board_size)
|
np.array(history[i]).flatten() == -np.ones(self.board_size ** 2)).reshape(self.board_size,
|
||||||
|
self.board_size)
|
||||||
# TODO: need a config to specify the BLACK and WHITE
|
# TODO: need a config to specify the BLACK and WHITE
|
||||||
if color == +1:
|
if color == +1:
|
||||||
state[0, :, :, 2 * self.history_length] = np.ones([self.board_size, self.board_size])
|
state[0, :, :, 2 * self.history_length] = np.ones([self.board_size, self.board_size])
|
||||||
@ -187,19 +216,27 @@ class ResNet(object):
|
|||||||
|
|
||||||
# TODO: design the interface between the environment and training
|
# TODO: design the interface between the environment and training
|
||||||
def train(self, mode='memory', *args, **kwargs):
|
def train(self, mode='memory', *args, **kwargs):
|
||||||
|
"""
|
||||||
|
The method to train the network
|
||||||
|
|
||||||
|
:param target: a string, which to optimize, can only be "both", "black" and "white"
|
||||||
|
:param mode: a string, how to optimize, can only be "memory" and "file"
|
||||||
|
"""
|
||||||
if mode == 'memory':
|
if mode == 'memory':
|
||||||
pass
|
pass
|
||||||
if mode == 'file':
|
if mode == 'file':
|
||||||
self._train_with_file(data_path=kwargs['data_path'], batch_size=kwargs['batch_size'],
|
self._train_with_file(data_path=kwargs['data_path'], batch_size=kwargs['batch_size'],
|
||||||
checkpoint_path=kwargs['checkpoint_path'])
|
save_path=kwargs['save_path'])
|
||||||
|
|
||||||
def _train_with_file(self, data_path, batch_size, checkpoint_path):
|
def _train_with_file(self, data_path, batch_size, save_path):
|
||||||
# check if the path is valid
|
# check if the path is valid
|
||||||
if not os.path.exists(data_path):
|
if not os.path.exists(data_path):
|
||||||
raise ValueError("{} doesn't exist".format(data_path))
|
raise ValueError("{} doesn't exist".format(data_path))
|
||||||
self.checkpoint_path = checkpoint_path
|
self.save_path = save_path
|
||||||
if not os.path.exists(self.checkpoint_path):
|
if not os.path.exists(self.save_path):
|
||||||
os.mkdir(self.checkpoint_path)
|
os.mkdir(self.save_path)
|
||||||
|
os.mkdir(self.save_path + 'black')
|
||||||
|
os.mkdir(self.save_path + 'white')
|
||||||
|
|
||||||
new_file_list = []
|
new_file_list = []
|
||||||
all_file_list = []
|
all_file_list = []
|
||||||
@ -227,7 +264,8 @@ class ResNet(object):
|
|||||||
else:
|
else:
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
priority = np.array(self.training_data['length']) / (0.0 + np.sum(np.array(self.training_data['length'])))
|
priority = np.array(self.training_data['length']) / (
|
||||||
|
0.0 + np.sum(np.array(self.training_data['length'])))
|
||||||
game_num = np.random.choice(self.window_length, 1, p=priority)[0]
|
game_num = np.random.choice(self.window_length, 1, p=priority)[0]
|
||||||
state_num = np.random.randint(self.training_data['length'][game_num])
|
state_num = np.random.randint(self.training_data['length'][game_num])
|
||||||
rotate_times = np.random.randint(4)
|
rotate_times = np.random.randint(4)
|
||||||
@ -237,11 +275,15 @@ class ResNet(object):
|
|||||||
self._preprocession(self.training_data['states'][game_num][state_num], reflect_times,
|
self._preprocession(self.training_data['states'][game_num][state_num], reflect_times,
|
||||||
reflect_orientation, rotate_times))
|
reflect_orientation, rotate_times))
|
||||||
training_data['probs'].append(np.concatenate(
|
training_data['probs'].append(np.concatenate(
|
||||||
[self._preprocession(self.training_data['probs'][game_num][state_num][:-1].reshape(self.board_size, self.board_size, 1), reflect_times,
|
[self._preprocession(
|
||||||
reflect_orientation, rotate_times).reshape(1, self.board_size**2), self.training_data['probs'][game_num][state_num][-1].reshape(1,1)], axis=1))
|
self.training_data['probs'][game_num][state_num][:-1].reshape(self.board_size,
|
||||||
|
self.board_size, 1),
|
||||||
|
reflect_times,
|
||||||
|
reflect_orientation, rotate_times).reshape(1, self.board_size ** 2),
|
||||||
|
self.training_data['probs'][game_num][state_num][-1].reshape(1, 1)], axis=1))
|
||||||
training_data['winner'].append(self.training_data['winner'][game_num][state_num].reshape(1, 1))
|
training_data['winner'].append(self.training_data['winner'][game_num][state_num].reshape(1, 1))
|
||||||
value_loss, policy_loss, reg, _ = self.sess.run(
|
value_loss, policy_loss, reg, _ = self.sess.run(
|
||||||
[self.value_loss, self.policy_loss, self.reg, self.train_op],
|
[self.black_value_loss, self.black_policy_loss, self.black_reg, self.black_train_op],
|
||||||
feed_dict={self.x: np.concatenate(training_data['states'], axis=0),
|
feed_dict={self.x: np.concatenate(training_data['states'], axis=0),
|
||||||
self.z: np.concatenate(training_data['winner'], axis=0),
|
self.z: np.concatenate(training_data['winner'], axis=0),
|
||||||
self.pi: np.concatenate(training_data['probs'], axis=0),
|
self.pi: np.concatenate(training_data['probs'], axis=0),
|
||||||
@ -252,8 +294,11 @@ class ResNet(object):
|
|||||||
value_loss,
|
value_loss,
|
||||||
policy_loss, reg))
|
policy_loss, reg))
|
||||||
if iters % self.save_freq == 0:
|
if iters % self.save_freq == 0:
|
||||||
save_path = "Iteration{}.ckpt".format(iters)
|
ckpt_file = "Iteration{}.ckpt".format(iters)
|
||||||
self.saver.save(self.sess, self.checkpoint_path + save_path)
|
self.black_saver.save(self.sess, self.save_path + 'black/' + ckpt_file)
|
||||||
|
self.sess.run(self.update)
|
||||||
|
self.white_saver.save(self.sess, self.save_path + 'white/' + ckpt_file)
|
||||||
|
|
||||||
for key in training_data.keys():
|
for key in training_data.keys():
|
||||||
training_data[key] = []
|
training_data[key] = []
|
||||||
iters += 1
|
iters += 1
|
||||||
@ -342,5 +387,5 @@ class ResNet(object):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
model = ResNet(board_size=9, action_num=82, history_length=8)
|
model = ResNet(board_size=8, action_num=65, history_length=1, black_checkpoint_path="./checkpoint/black", white_checkpoint_path="./checkpoint/white")
|
||||||
model.train("file", data_path="./data/", batch_size=128, checkpoint_path="./checkpoint/")
|
model.train(mode="file", data_path="./data/", batch_size=128, save_path="./checkpoint/")
|
||||||
|
|||||||
@ -1,10 +1,10 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import subprocess
|
|
||||||
import sys
|
import sys
|
||||||
import re
|
import re
|
||||||
import Pyro4
|
|
||||||
import time
|
import time
|
||||||
import os
|
import os
|
||||||
|
from game import Game
|
||||||
|
from engine import GTPEngine
|
||||||
import utils
|
import utils
|
||||||
from time import gmtime, strftime
|
from time import gmtime, strftime
|
||||||
|
|
||||||
@ -24,6 +24,7 @@ class Data(object):
|
|||||||
def reset(self):
|
def reset(self):
|
||||||
self.__init__()
|
self.__init__()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
"""
|
"""
|
||||||
Starting two different players which load network weights to evaluate the winning ratio.
|
Starting two different players which load network weights to evaluate the winning ratio.
|
||||||
@ -47,65 +48,13 @@ if __name__ == '__main__':
|
|||||||
if args.white_weight_path is not None and (not os.path.exists(args.white_weight_path)):
|
if args.white_weight_path is not None and (not os.path.exists(args.white_weight_path)):
|
||||||
raise ValueError("Can't find the network weights for white player.")
|
raise ValueError("Can't find the network weights for white player.")
|
||||||
|
|
||||||
# kill the old server
|
game = Game(name=args.game,
|
||||||
# kill_old_server = subprocess.Popen(['killall', 'pyro4-ns'])
|
black_checkpoint_path=args.black_weight_path,
|
||||||
# print "kill the old pyro4 name server, the return code is : " + str(kill_old_server.wait())
|
white_checkpoint_path=args.white_weight_path,
|
||||||
# time.sleep(1)
|
debug=args.debug)
|
||||||
|
engine = GTPEngine(game_obj=game, name='tianshou', version=0)
|
||||||
# start a name server if no name server exists
|
|
||||||
if len(os.popen('ps aux | grep pyro4-ns | grep -v grep').readlines()) == 0:
|
|
||||||
start_new_server = subprocess.Popen(['pyro4-ns', '&'])
|
|
||||||
print("Start Name Sever : " + str(start_new_server.pid)) # + str(start_new_server.wait())
|
|
||||||
time.sleep(1)
|
|
||||||
|
|
||||||
# start two different player with different network weights.
|
|
||||||
server_list = subprocess.check_output(['pyro4-nsc', 'list'])
|
|
||||||
current_time = strftime("%Y%m%d_%H%M%S", gmtime())
|
|
||||||
|
|
||||||
black_role_name = 'black' + current_time
|
|
||||||
white_role_name = 'white' + current_time
|
|
||||||
|
|
||||||
black_player = subprocess.Popen(
|
|
||||||
['python', '-u', 'player.py', '--game=' + args.game, '--role=' + black_role_name,
|
|
||||||
'--checkpoint_path=' + str(args.black_weight_path), '--debug=' + str(args.debug)],
|
|
||||||
stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
|
|
||||||
bp_output = black_player.stdout.readline()
|
|
||||||
bp_message = bp_output
|
|
||||||
# '' means player.py failed to start, "Start requestLoop" means player.py start successfully
|
|
||||||
while bp_output != '' and "Start requestLoop" not in bp_output:
|
|
||||||
bp_output = black_player.stdout.readline()
|
|
||||||
bp_message += bp_output
|
|
||||||
print("============ " + black_role_name + " message ============" + "\n" + bp_message),
|
|
||||||
|
|
||||||
white_player = subprocess.Popen(
|
|
||||||
['python', '-u', 'player.py', '--game=' + args.game, '--role=' + white_role_name,
|
|
||||||
'--checkpoint_path=' + str(args.white_weight_path), '--debug=' + str(args.debug)],
|
|
||||||
stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
|
|
||||||
wp_output = white_player.stdout.readline()
|
|
||||||
wp_message = wp_output
|
|
||||||
while wp_output != '' and "Start requestLoop" not in wp_output:
|
|
||||||
wp_output = white_player.stdout.readline()
|
|
||||||
wp_message += wp_output
|
|
||||||
print("============ " + white_role_name + " message ============" + "\n" + wp_message),
|
|
||||||
|
|
||||||
server_list = ""
|
|
||||||
while (black_role_name not in server_list) or (white_role_name not in server_list):
|
|
||||||
if python_version < (3, 0):
|
|
||||||
# TODO : @renyong what is the difference between those two options?
|
|
||||||
server_list = subprocess.check_output(['pyro4-nsc', 'list'])
|
|
||||||
else:
|
|
||||||
server_list = subprocess.check_output(['pyro4-nsc', 'list'])
|
|
||||||
print("Waiting for the server start...")
|
|
||||||
time.sleep(1)
|
|
||||||
print(server_list)
|
|
||||||
print("Start black player at : " + str(black_player.pid))
|
|
||||||
print("Start white player at : " + str(white_player.pid))
|
|
||||||
|
|
||||||
data = Data()
|
data = Data()
|
||||||
player = [None] * 2
|
|
||||||
player[0] = Pyro4.Proxy("PYRONAME:" + black_role_name)
|
|
||||||
player[1] = Pyro4.Proxy("PYRONAME:" + white_role_name)
|
|
||||||
|
|
||||||
role = ["BLACK", "WHITE"]
|
role = ["BLACK", "WHITE"]
|
||||||
color = ['b', 'w']
|
color = ['b', 'w']
|
||||||
|
|
||||||
@ -126,7 +75,7 @@ if __name__ == '__main__':
|
|||||||
# end the game if both palyer chose to pass, or play too much turns
|
# end the game if both palyer chose to pass, or play too much turns
|
||||||
while not (pass_flag[0] and pass_flag[1]) and num < size[args.game] ** 2 * 2:
|
while not (pass_flag[0] and pass_flag[1]) and num < size[args.game] ** 2 * 2:
|
||||||
turn = num % 2
|
turn = num % 2
|
||||||
board = player[turn].run_cmd(str(num) + ' show_board')
|
board = engine.run_cmd(str(num) + ' show_board')
|
||||||
board = eval(board[board.index('['):board.index(']') + 1])
|
board = eval(board[board.index('['):board.index(']') + 1])
|
||||||
for i in range(size[args.game]):
|
for i in range(size[args.game]):
|
||||||
for j in range(size[args.game]):
|
for j in range(size[args.game]):
|
||||||
@ -134,7 +83,7 @@ if __name__ == '__main__':
|
|||||||
print "\n",
|
print "\n",
|
||||||
data.boards.append(board)
|
data.boards.append(board)
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
move = player[turn].run_cmd(str(num) + ' genmove ' + color[turn])[:-1]
|
move = engine.run_cmd(str(num) + ' genmove ' + color[turn])[:-1]
|
||||||
print("\n" + role[turn] + " : " + str(move)),
|
print("\n" + role[turn] + " : " + str(move)),
|
||||||
num += 1
|
num += 1
|
||||||
match = re.search(pattern, move)
|
match = re.search(pattern, move)
|
||||||
@ -146,21 +95,19 @@ if __name__ == '__main__':
|
|||||||
# print "no match"
|
# print "no match"
|
||||||
play_or_pass = ' PASS'
|
play_or_pass = ' PASS'
|
||||||
pass_flag[turn] = True
|
pass_flag[turn] = True
|
||||||
result = player[1 - turn].run_cmd(str(num) + ' play ' + color[turn] + ' ' + play_or_pass + '\n')
|
prob = engine.run_cmd(str(num) + ' get_prob')
|
||||||
prob = player[turn].run_cmd(str(num) + ' get_prob')
|
|
||||||
prob = space.sub(',', prob[prob.index('['):prob.index(']') + 1])
|
prob = space.sub(',', prob[prob.index('['):prob.index(']') + 1])
|
||||||
prob = prob.replace('[,', '[')
|
prob = prob.replace('[,', '[')
|
||||||
prob = prob.replace('],', ']')
|
prob = prob.replace('],', ']')
|
||||||
prob = eval(prob)
|
prob = eval(prob)
|
||||||
data.probs.append(prob)
|
data.probs.append(prob)
|
||||||
score = player[0].run_cmd(str(num) + ' get_score')
|
score = engine.run_cmd(str(num) + ' get_score')
|
||||||
print("Finished : {}".format(score.split(" ")[1]))
|
print("Finished : {}".format(score.split(" ")[1]))
|
||||||
if eval(score.split(" ")[1]) > 0:
|
if eval(score.split(" ")[1]) > 0:
|
||||||
data.winner = utils.BLACK
|
data.winner = utils.BLACK
|
||||||
if eval(score.split(" ")[1]) < 0:
|
if eval(score.split(" ")[1]) < 0:
|
||||||
data.winner = utils.WHITE
|
data.winner = utils.WHITE
|
||||||
player[0].run_cmd(str(num) + ' clear_board')
|
engine.run_cmd(str(num) + ' clear_board')
|
||||||
player[1].run_cmd(str(num) + ' clear_board')
|
|
||||||
file_list = os.listdir(args.data_path)
|
file_list = os.listdir(args.data_path)
|
||||||
current_time = strftime("%Y%m%d_%H%M%S", gmtime())
|
current_time = strftime("%Y%m%d_%H%M%S", gmtime())
|
||||||
if os.path.exists(args.data_path + current_time + ".pkl"):
|
if os.path.exists(args.data_path + current_time + ".pkl"):
|
||||||
@ -172,7 +119,3 @@ if __name__ == '__main__':
|
|||||||
game_num += 1
|
game_num += 1
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
subprocess.call(["kill", "-9", str(black_player.pid)])
|
|
||||||
subprocess.call(["kill", "-9", str(white_player.pid)])
|
|
||||||
print("Kill all player, finish all game.")
|
|
||||||
|
|||||||
@ -1,42 +0,0 @@
|
|||||||
import argparse
|
|
||||||
import Pyro4
|
|
||||||
from game import Game
|
|
||||||
from engine import GTPEngine
|
|
||||||
|
|
||||||
@Pyro4.expose
|
|
||||||
class Player(object):
|
|
||||||
"""
|
|
||||||
This is the class which defines the object called by Pyro4 (Python remote object).
|
|
||||||
It passes the command to our engine, and return the result.
|
|
||||||
"""
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
self.role = kwargs['role']
|
|
||||||
self.engine = kwargs['engine']
|
|
||||||
|
|
||||||
def run_cmd(self, command):
|
|
||||||
return self.engine.run_cmd(command)
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument("--checkpoint_path", type=str, default="None")
|
|
||||||
parser.add_argument("--role", type=str, default="unknown")
|
|
||||||
parser.add_argument("--debug", type=str, default="False")
|
|
||||||
parser.add_argument("--game", type=str, default="go")
|
|
||||||
args = parser.parse_args()
|
|
||||||
if args.checkpoint_path == 'None':
|
|
||||||
args.checkpoint_path = None
|
|
||||||
game = Game(name=args.game, role=args.role,
|
|
||||||
checkpoint_path=args.checkpoint_path,
|
|
||||||
debug=eval(args.debug))
|
|
||||||
engine = GTPEngine(game_obj=game, name='tianshou', version=0)
|
|
||||||
|
|
||||||
daemon = Pyro4.Daemon() # make a Pyro daemon
|
|
||||||
ns = Pyro4.locateNS() # find the name server
|
|
||||||
player = Player(role=args.role, engine=engine)
|
|
||||||
print("Init " + args.role + " player finished")
|
|
||||||
uri = daemon.register(player) # register the greeting maker as a Pyro object
|
|
||||||
print("Start on name " + args.role)
|
|
||||||
ns.register(args.role, uri) # register the object with a name in the name server
|
|
||||||
print("Start requestLoop " + str(uri))
|
|
||||||
daemon.requestLoop() # start the event loop of the server to wait for calls
|
|
||||||
|
|
||||||
Loading…
x
Reference in New Issue
Block a user