modify model.py for multi-player

rtz19970824 2018-01-09 19:50:37 +08:00
parent 891c5b1e47
commit eb0ce95919


@@ -80,7 +80,8 @@ class Data(object):
 class ResNet(object):
-    def __init__(self, board_size, action_num, history_length=1, residual_block_num=10, checkpoint_path=None):
+    def __init__(self, board_size, action_num, history_length=1, residual_block_num=10, black_checkpoint_path=None,
+                 white_checkpoint_path=None):
         """
         the resnet model
@@ -88,25 +89,49 @@ class ResNet(object):
         :param action_num: an integer, number of unique actions at any state
         :param history_length: an integer, the history length to use, default is 1
         :param residual_block_num: an integer, the number of residual block, default is 20, at least 1
-        :param checkpoint_path: a string, the path to the checkpoint, default is None,
+        :param black_checkpoint_path: a string, the path to the black checkpoint, default is None,
+        :param white_checkpoint_path: a string, the path to the white checkpoint, default is None,
         """
         self.board_size = board_size
         self.action_num = action_num
         self.history_length = history_length
-        self.checkpoint_path = checkpoint_path
+        self.black_checkpoint_path = black_checkpoint_path
+        self.white_checkpoint_path = white_checkpoint_path
         self.x = tf.placeholder(tf.float32, shape=[None, self.board_size, self.board_size, 2 * self.history_length + 1])
         self.is_training = tf.placeholder(tf.bool, shape=[])
         self.z = tf.placeholder(tf.float32, shape=[None, 1])
         self.pi = tf.placeholder(tf.float32, shape=[None, self.action_num])
-        self._build_network(residual_block_num, self.checkpoint_path)
+        self._build_network('black', residual_block_num)
+        self._build_network('white', residual_block_num)
+        self.sess = multi_gpu.create_session()
+        self.sess.run(tf.global_variables_initializer())
+        if black_checkpoint_path is not None:
+            ckpt_file = tf.train.latest_checkpoint(black_checkpoint_path)
+            if ckpt_file is not None:
+                print('Restoring model from {}...'.format(ckpt_file))
+                self.black_saver.restore(self.sess, ckpt_file)
+                print('Successfully loaded')
+            else:
+                raise ValueError("No model in path {}".format(black_checkpoint_path))
+        if white_checkpoint_path is not None:
+            ckpt_file = tf.train.latest_checkpoint(white_checkpoint_path)
+            if ckpt_file is not None:
+                print('Restoring model from {}...'.format(ckpt_file))
+                self.white_saver.restore(self.sess, ckpt_file)
+                print('Successfully loaded')
+            else:
+                raise ValueError("No model in path {}".format(white_checkpoint_path))
+        self.update = [tf.assign(black_params, white_params) for black_params, white_params in
+                       zip(self.black_var_list, self.white_var_list)]
 
         # training hyper-parameters:
-        self.window_length = 3
+        self.window_length = 900
         self.save_freq = 5000
         self.training_data = {'states': deque(maxlen=self.window_length), 'probs': deque(maxlen=self.window_length),
                               'winner': deque(maxlen=self.window_length), 'length': deque(maxlen=self.window_length)}
 
-    def _build_network(self, residual_block_num, checkpoint_path):
+    def _build_network(self, scope, residual_block_num):
         """
         build the network
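
The constructor now builds the same residual network twice, once under the 'black' scope and once under the 'white' scope, restores each copy from its own checkpoint directory, and keeps self.update as a list of tf.assign ops over the two variable lists, so one set of parameters can be copied onto the other with a single sess.run call. A minimal self-contained sketch of that two-scope pattern follows; scope names, shapes and the copy direction here are illustrative only, not the commit's code.

    # Illustrative sketch: build one variable under two scopes and copy
    # parameters from 'src' to 'dst' with per-variable tf.assign ops.
    import tensorflow as tf

    def build(scope):
        with tf.variable_scope(scope):
            return tf.get_variable('w', shape=[2, 2], initializer=tf.zeros_initializer())

    src_w = build('src')
    dst_w = build('dst')
    src_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='src')
    dst_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='dst')
    sync = [tf.assign(d, s) for d, s in zip(dst_vars, src_vars)]  # running these copies src -> dst

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.assign(src_w, tf.ones([2, 2])))
        sess.run(sync)
        print(sess.run(dst_w))  # all ones after the copy
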
@@ -114,37 +139,34 @@ class ResNet(object):
         :param checkpoint_path: a string, the path to the checkpoint, if None, use random initialization parameter
         :return: None
         """
-        h = layers.conv2d(self.x, 256, kernel_size=3, stride=1, activation_fn=tf.nn.relu,
-                          normalizer_fn=layers.batch_norm,
-                          normalizer_params={'is_training': self.is_training,
-                                             'updates_collections': tf.GraphKeys.UPDATE_OPS},
-                          weights_regularizer=layers.l2_regularizer(1e-4))
-        for i in range(residual_block_num - 1):
-            h = residual_block(h, self.is_training)
-        self.v = value_head(h, self.is_training)
-        self.p = policy_head(h, self.is_training, self.action_num)
-        self.prob = tf.nn.softmax(self.p)
-        self.value_loss = tf.reduce_mean(tf.square(self.z - self.v))
-        self.policy_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=self.pi, logits=self.p))
-        self.reg = tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
-        self.total_loss = self.value_loss + self.policy_loss + self.reg
-        self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
-        with tf.control_dependencies(self.update_ops):
-            self.train_op = tf.train.AdamOptimizer(1e-4).minimize(self.total_loss)
-        self.var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
-        self.saver = tf.train.Saver(max_to_keep=0, var_list=self.var_list)
-        self.sess = multi_gpu.create_session()
-        self.sess.run(tf.global_variables_initializer())
-        if checkpoint_path is not None:
-            ckpt_file = tf.train.latest_checkpoint(checkpoint_path)
-            if ckpt_file is not None:
-                print('Restoring model from {}...'.format(ckpt_file))
-                self.saver.restore(self.sess, ckpt_file)
-                print('Successfully loaded')
-            else:
-                raise ValueError("No model in path {}".format(checkpoint_path))
+        with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
+            h = layers.conv2d(self.x, 256, kernel_size=3, stride=1, activation_fn=tf.nn.relu,
+                              normalizer_fn=layers.batch_norm,
+                              normalizer_params={'is_training': self.is_training,
+                                                 'updates_collections': tf.GraphKeys.UPDATE_OPS},
+                              weights_regularizer=layers.l2_regularizer(1e-4))
+            for i in range(residual_block_num - 1):
+                h = residual_block(h, self.is_training)
+            self.__setattr__(scope + '_v', value_head(h, self.is_training))
+            self.__setattr__(scope + '_p', policy_head(h, self.is_training, self.action_num))
+            self.__setattr__(scope + '_prob', tf.nn.softmax(self.__getattribute__(scope + '_p')))
+            self.__setattr__(scope + '_value_loss', tf.reduce_mean(tf.square(self.z - self.__getattribute__(scope + '_v'))))
+            self.__setattr__(scope + '_policy_loss',
+                             tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=self.pi,
+                                                                                    logits=self.__getattribute__(
+                                                                                        scope + '_p'))))
+            self.__setattr__(scope + '_reg', tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES, scope=scope)))
+            self.__setattr__(scope + '_total_loss', self.__getattribute__(scope + '_value_loss') + self.__getattribute__(
+                scope + '_policy_loss') + self.__getattribute__(scope + '_reg'))
+            self.__setattr__(scope + '_update_ops', tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope=scope))
+            self.__setattr__(scope + '_var_list', tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope))
+            with tf.control_dependencies(self.__getattribute__(scope + '_update_ops')):
+                self.__setattr__(scope + '_train_op',
+                                 tf.train.AdamOptimizer(1e-4).minimize(self.__getattribute__(scope + '_total_loss'),
                                                                        var_list=self.__getattribute__(scope + '_var_list')))
+            self.__setattr__(scope + '_saver',
+                             tf.train.Saver(max_to_keep=0, var_list=self.__getattribute__(scope + '_var_list')))
 
     def __call__(self, state):
         """
@@ -154,15 +176,20 @@ class ResNet(object):
         :return: a list of tensor, the predicted value and policy given the history and color
         """
         # Note : maybe we can use it for isolating test of MCTS
-        #prob = [1.0 / self.action_num] * self.action_num
-        #return [prob, np.random.uniform(-1, 1)]
+        # prob = [1.0 / self.action_num] * self.action_num
+        # return [prob, np.random.uniform(-1, 1)]
         history, color = state
         if len(history) != self.history_length:
             raise ValueError(
                 'The length of history cannot meet the need of the model, given {}, need {}'.format(len(history),
                                                                                                      self.history_length))
         eval_state = self._history2state(history, color)
-        return self.sess.run([self.prob, self.v], feed_dict={self.x: eval_state, self.is_training: False})
+        if color == +1:
+            return self.sess.run([self.black_prob, self.black_v],
+                                 feed_dict={self.x: eval_state, self.is_training: False})
+        if color == -1:
+            return self.sess.run([self.white_prob, self.white_v],
+                                 feed_dict={self.x: eval_state, self.is_training: False})
 
     def _history2state(self, history, color):
         """
@@ -174,10 +201,12 @@ class ResNet(object):
         """
         state = np.zeros([1, self.board_size, self.board_size, 2 * self.history_length + 1])
         for i in range(self.history_length):
-            state[0, :, :, i] = np.array(np.array(history[i]).flatten() == np.ones(self.board_size ** 2)).reshape(self.board_size,
-                                                                                                                   self.board_size)
+            state[0, :, :, i] = np.array(np.array(history[i]).flatten() == np.ones(self.board_size ** 2)).reshape(
+                self.board_size,
+                self.board_size)
             state[0, :, :, i + self.history_length] = np.array(
-                np.array(history[i]).flatten() == -np.ones(self.board_size ** 2)).reshape(self.board_size, self.board_size)
+                np.array(history[i]).flatten() == -np.ones(self.board_size ** 2)).reshape(self.board_size,
+                                                                                          self.board_size)
         # TODO: need a config to specify the BLACK and WHITE
         if color == +1:
             state[0, :, :, 2 * self.history_length] = np.ones([self.board_size, self.board_size])
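
_history2state packs the history into 2 * history_length + 1 feature planes: the first history_length planes mark the +1 player's stones over the history, the next history_length planes mark the -1 player's stones, and the last plane encodes whose turn it is (all ones when color == +1; the branch for the other colour lies outside this hunk). A standalone numpy sketch of that layout with made-up sizes:

    import numpy as np

    board_size, history_length = 4, 2
    # each history entry is a board of -1 / 0 / +1 stones (random here, for illustration)
    history = [np.random.choice([-1, 0, 1], size=(board_size, board_size)) for _ in range(history_length)]
    color = +1

    state = np.zeros([1, board_size, board_size, 2 * history_length + 1])
    for i in range(history_length):
        state[0, :, :, i] = (history[i] == 1)                      # +1 player's stones
        state[0, :, :, i + history_length] = (history[i] == -1)    # -1 player's stones
    if color == +1:
        state[0, :, :, 2 * history_length] = 1.0                   # colour-to-move plane (assumed 0 otherwise)
    print(state.shape)  # (1, 4, 4, 5)
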
@@ -187,19 +216,27 @@ class ResNet(object):
 
     # TODO: design the interface between the environment and training
     def train(self, mode='memory', *args, **kwargs):
+        """
+        The method to train the network
+
+        :param target: a string, which to optimize, can only be "both", "black" and "white"
+        :param mode: a string, how to optimize, can only be "memory" and "file"
+        """
         if mode == 'memory':
             pass
         if mode == 'file':
             self._train_with_file(data_path=kwargs['data_path'], batch_size=kwargs['batch_size'],
-                                  checkpoint_path=kwargs['checkpoint_path'])
+                                  save_path=kwargs['save_path'])
 
-    def _train_with_file(self, data_path, batch_size, checkpoint_path):
+    def _train_with_file(self, data_path, batch_size, save_path):
         # check if the path is valid
         if not os.path.exists(data_path):
             raise ValueError("{} doesn't exist".format(data_path))
-        self.checkpoint_path = checkpoint_path
-        if not os.path.exists(self.checkpoint_path):
-            os.mkdir(self.checkpoint_path)
+        self.save_path = save_path
+        if not os.path.exists(self.save_path):
+            os.mkdir(self.save_path)
+            os.mkdir(self.save_path + 'black')
+            os.mkdir(self.save_path + 'white')
 
         new_file_list = []
         all_file_list = []
@@ -227,7 +264,8 @@ class ResNet(object):
             else:
                 start_time = time.time()
                 for i in range(batch_size):
-                    priority = np.array(self.training_data['length']) / (0.0 + np.sum(np.array(self.training_data['length'])))
+                    priority = np.array(self.training_data['length']) / (
+                            0.0 + np.sum(np.array(self.training_data['length'])))
                     game_num = np.random.choice(self.window_length, 1, p=priority)[0]
                     state_num = np.random.randint(self.training_data['length'][game_num])
                     rotate_times = np.random.randint(4)
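
The batch is assembled by first drawing a game with probability proportional to its length and then drawing a position uniformly within that game, which makes every stored position equally likely overall. A standalone numpy sketch of that sampling step with made-up lengths:

    import numpy as np

    lengths = np.array([120, 80, 200])                    # positions per stored game (illustrative)
    priority = lengths / (0.0 + np.sum(lengths))          # longer games are drawn more often
    game_num = np.random.choice(len(lengths), 1, p=priority)[0]
    state_num = np.random.randint(lengths[game_num])      # uniform position inside that game
    print(game_num, state_num)
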
@@ -237,11 +275,15 @@ class ResNet(object):
                         self._preprocession(self.training_data['states'][game_num][state_num], reflect_times,
                                             reflect_orientation, rotate_times))
                     training_data['probs'].append(np.concatenate(
-                        [self._preprocession(self.training_data['probs'][game_num][state_num][:-1].reshape(self.board_size, self.board_size, 1), reflect_times,
-                                             reflect_orientation, rotate_times).reshape(1, self.board_size**2), self.training_data['probs'][game_num][state_num][-1].reshape(1,1)], axis=1))
+                        [self._preprocession(
+                            self.training_data['probs'][game_num][state_num][:-1].reshape(self.board_size,
+                                                                                          self.board_size, 1),
+                            reflect_times,
+                            reflect_orientation, rotate_times).reshape(1, self.board_size ** 2),
+                            self.training_data['probs'][game_num][state_num][-1].reshape(1, 1)], axis=1))
                     training_data['winner'].append(self.training_data['winner'][game_num][state_num].reshape(1, 1))
                 value_loss, policy_loss, reg, _ = self.sess.run(
-                    [self.value_loss, self.policy_loss, self.reg, self.train_op],
+                    [self.black_value_loss, self.black_policy_loss, self.black_reg, self.black_train_op],
                     feed_dict={self.x: np.concatenate(training_data['states'], axis=0),
                                self.z: np.concatenate(training_data['winner'], axis=0),
                                self.pi: np.concatenate(training_data['probs'], axis=0),
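
In the policy target, only the first board_size ** 2 entries correspond to board points; the final entry is the pass move. The augmentation therefore reshapes the board part to board_size x board_size, applies the same rotation/reflection as the state, flattens it back, and re-attaches the untouched pass probability. A numpy sketch of that idea; _preprocession itself is not part of this diff, so np.rot90 / np.flip stand in for it here.

    import numpy as np

    board_size = 4
    probs = np.random.dirichlet(np.ones(board_size ** 2 + 1))      # board moves + pass

    board_part = probs[:-1].reshape(board_size, board_size)
    board_part = np.rot90(board_part, k=1)                         # plays the role of rotate_times
    board_part = np.flip(board_part, axis=0)                       # plays the role of reflect_*
    augmented = np.concatenate([board_part.reshape(1, board_size ** 2),
                                probs[-1].reshape(1, 1)], axis=1)
    print(augmented.shape)  # (1, 17)
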
@@ -252,8 +294,11 @@ class ResNet(object):
                                                                            value_loss,
                                                                            policy_loss, reg))
                 if iters % self.save_freq == 0:
-                    save_path = "Iteration{}.ckpt".format(iters)
-                    self.saver.save(self.sess, self.checkpoint_path + save_path)
+                    ckpt_file = "Iteration{}.ckpt".format(iters)
+                    self.black_saver.save(self.sess, self.save_path + 'black/' + ckpt_file)
+                    self.sess.run(self.update)
+                    self.white_saver.save(self.sess, self.save_path + 'white/' + ckpt_file)
                 for key in training_data.keys():
                     training_data[key] = []
                 iters += 1
@@ -342,5 +387,5 @@ class ResNet(object):
 
 if __name__ == "__main__":
-    model = ResNet(board_size=9, action_num=82, history_length=8)
-    model.train("file", data_path="./data/", batch_size=128, checkpoint_path="./checkpoint/")
+    model = ResNet(board_size=8, action_num=65, history_length=1, black_checkpoint_path="./checkpoint/black", white_checkpoint_path="./checkpoint/white")
+    model.train(mode="file", data_path="./data/", batch_size=128, save_path="./checkpoint/")
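
With this change a single ResNet instance drives both players: __call__ dispatches on the colour (+1 uses the black head, -1 the white head), training updates only the black variables, and every save_freq iterations the black weights are written, self.update is run, and the white checkpoint is written. A short usage sketch of this module's ResNet, assuming checkpoints already exist under the two directories used in __main__ above:

    import numpy as np

    model = ResNet(board_size=8, action_num=65, history_length=1,
                   black_checkpoint_path="./checkpoint/black",
                   white_checkpoint_path="./checkpoint/white")
    empty_board = np.zeros([8, 8])
    prob, value = model(([empty_board], +1))   # +1 -> black network, -1 -> white network
    print(prob.shape, value.shape)             # (1, 65) and (1, 1)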