Check for the latest checkpoint during self-play

This commit is contained in:
rtz19970824 2018-01-12 19:16:44 +08:00
parent c217aa165d
commit 90ffdcbb1f
2 changed files with 29 additions and 8 deletions

View File

@ -99,19 +99,19 @@ class ResNet(object):
self.sess = multi_gpu.create_session() self.sess = multi_gpu.create_session()
self.sess.run(tf.global_variables_initializer()) self.sess.run(tf.global_variables_initializer())
if black_checkpoint_path is not None: if black_checkpoint_path is not None:
ckpt_file = tf.train.latest_checkpoint(black_checkpoint_path) self.black_ckpt_file = tf.train.latest_checkpoint(black_checkpoint_path)
if ckpt_file is not None: if self.black_ckpt_file is not None:
print('Restoring model from {}...'.format(ckpt_file)) print('Restoring model from {}...'.format(self.black_ckpt_file))
self.black_saver.restore(self.sess, ckpt_file) self.black_saver.restore(self.sess, self.black_ckpt_file)
print('Successfully loaded') print('Successfully loaded')
else: else:
raise ValueError("No model in path {}".format(black_checkpoint_path)) raise ValueError("No model in path {}".format(black_checkpoint_path))
if white_checkpoint_path is not None: if white_checkpoint_path is not None:
ckpt_file = tf.train.latest_checkpoint(white_checkpoint_path) self.white_ckpt_file = tf.train.latest_checkpoint(white_checkpoint_path)
if ckpt_file is not None: if self.white_ckpt_file is not None:
print('Restoring model from {}...'.format(ckpt_file)) print('Restoring model from {}...'.format(self.white_ckpt_file))
self.white_saver.restore(self.sess, ckpt_file) self.white_saver.restore(self.sess, self.white_ckpt_file)
print('Successfully loaded') print('Successfully loaded')
else: else:
raise ValueError("No model in path {}".format(white_checkpoint_path)) raise ValueError("No model in path {}".format(white_checkpoint_path))
@ -124,6 +124,9 @@ class ResNet(object):
self.training_data = {'states': deque(maxlen=self.window_length), 'probs': deque(maxlen=self.window_length), self.training_data = {'states': deque(maxlen=self.window_length), 'probs': deque(maxlen=self.window_length),
'winner': deque(maxlen=self.window_length), 'length': deque(maxlen=self.window_length)} 'winner': deque(maxlen=self.window_length), 'length': deque(maxlen=self.window_length)}
# training or not
self.training = False
def _build_network(self, scope, residual_block_num): def _build_network(self, scope, residual_block_num):
""" """
build the network build the network
@ -184,6 +187,22 @@ class ResNet(object):
return self.sess.run([self.white_prob, self.white_v], return self.sess.run([self.white_prob, self.white_v],
feed_dict={self.x: eval_state, self.is_training: False}) feed_dict={self.x: eval_state, self.is_training: False})
def check_latest_model(self):
    """
    Reload the black and white networks when a newer checkpoint exists.

    Does nothing unless ``self.training`` is true. For each color the newest
    checkpoint under ``self.save_path + "<color>/"`` is looked up; when it
    differs from the checkpoint recorded on the instance
    (``self.<color>_ckpt_file``) it is restored into ``self.sess`` with the
    matching saver and recorded as the current one.

    A directory that holds no checkpoint yet is skipped, so the weights
    already loaded in the session are kept.
    """
    if not self.training:
        return
    for color in ('black', 'white'):
        latest_ckpt = tf.train.latest_checkpoint(self.save_path + color + "/")
        # latest_checkpoint yields None when nothing has been saved yet;
        # restoring from None would raise, so keep the current weights.
        if latest_ckpt is None:
            continue
        attr_name = color + '_ckpt_file'
        # The attribute is only assigned at construction time when a
        # checkpoint path was supplied, so it may not exist here.
        if getattr(self, attr_name, None) == latest_ckpt:
            continue
        print('Loading model from {}...'.format(latest_ckpt))
        getattr(self, color + '_saver').restore(self.sess, latest_ckpt)
        # Record the path only after a successful restore so that a failed
        # load is retried on the next call.
        setattr(self, attr_name, latest_ckpt)
        print('{} Model Updated!'.format(color.capitalize()))
def _history2state(self, history, color): def _history2state(self, history, color):
""" """
convert the history to the state we need convert the history to the state we need
@ -215,6 +234,7 @@ class ResNet(object):
:param target: a string, which to optimize, can only be "both", "black" and "white" :param target: a string, which to optimize, can only be "both", "black" and "white"
:param mode: a string, how to optimize, can only be "memory" and "file" :param mode: a string, how to optimize, can only be "memory" and "file"
""" """
self.training = True
if mode == 'memory': if mode == 'memory':
pass pass
if mode == 'file': if mode == 'file':

View File

@ -61,6 +61,7 @@ if __name__ == '__main__':
while True: while True:
#while game_num < evaluate_rounds: #while game_num < evaluate_rounds:
start_time = time.time() start_time = time.time()
game.model.check_latest_model()
num = 0 num = 0
pass_flag = [False, False] pass_flag = [False, False]
print("Start game {}".format(game_num)) print("Start game {}".format(game_num))