Check for the latest checkpoint during self-play
This commit is contained in:
parent
c217aa165d
commit
90ffdcbb1f
@ -99,19 +99,19 @@ class ResNet(object):
|
|||||||
self.sess = multi_gpu.create_session()
|
self.sess = multi_gpu.create_session()
|
||||||
self.sess.run(tf.global_variables_initializer())
|
self.sess.run(tf.global_variables_initializer())
|
||||||
if black_checkpoint_path is not None:
|
if black_checkpoint_path is not None:
|
||||||
ckpt_file = tf.train.latest_checkpoint(black_checkpoint_path)
|
self.black_ckpt_file = tf.train.latest_checkpoint(black_checkpoint_path)
|
||||||
if ckpt_file is not None:
|
if self.black_ckpt_file is not None:
|
||||||
print('Restoring model from {}...'.format(ckpt_file))
|
print('Restoring model from {}...'.format(self.black_ckpt_file))
|
||||||
self.black_saver.restore(self.sess, ckpt_file)
|
self.black_saver.restore(self.sess, self.black_ckpt_file)
|
||||||
print('Successfully loaded')
|
print('Successfully loaded')
|
||||||
else:
|
else:
|
||||||
raise ValueError("No model in path {}".format(black_checkpoint_path))
|
raise ValueError("No model in path {}".format(black_checkpoint_path))
|
||||||
|
|
||||||
if white_checkpoint_path is not None:
|
if white_checkpoint_path is not None:
|
||||||
ckpt_file = tf.train.latest_checkpoint(white_checkpoint_path)
|
self.white_ckpt_file = tf.train.latest_checkpoint(white_checkpoint_path)
|
||||||
if ckpt_file is not None:
|
if self.white_ckpt_file is not None:
|
||||||
print('Restoring model from {}...'.format(ckpt_file))
|
print('Restoring model from {}...'.format(self.white_ckpt_file))
|
||||||
self.white_saver.restore(self.sess, ckpt_file)
|
self.white_saver.restore(self.sess, self.white_ckpt_file)
|
||||||
print('Successfully loaded')
|
print('Successfully loaded')
|
||||||
else:
|
else:
|
||||||
raise ValueError("No model in path {}".format(white_checkpoint_path))
|
raise ValueError("No model in path {}".format(white_checkpoint_path))
|
||||||
@ -124,6 +124,9 @@ class ResNet(object):
|
|||||||
self.training_data = {'states': deque(maxlen=self.window_length), 'probs': deque(maxlen=self.window_length),
|
self.training_data = {'states': deque(maxlen=self.window_length), 'probs': deque(maxlen=self.window_length),
|
||||||
'winner': deque(maxlen=self.window_length), 'length': deque(maxlen=self.window_length)}
|
'winner': deque(maxlen=self.window_length), 'length': deque(maxlen=self.window_length)}
|
||||||
|
|
||||||
|
# training or not
|
||||||
|
self.training = False
|
||||||
|
|
||||||
def _build_network(self, scope, residual_block_num):
|
def _build_network(self, scope, residual_block_num):
|
||||||
"""
|
"""
|
||||||
build the network
|
build the network
|
||||||
@ -184,6 +187,22 @@ class ResNet(object):
|
|||||||
return self.sess.run([self.white_prob, self.white_v],
|
return self.sess.run([self.white_prob, self.white_v],
|
||||||
feed_dict={self.x: eval_state, self.is_training: False})
|
feed_dict={self.x: eval_state, self.is_training: False})
|
||||||
|
|
||||||
|
def check_latest_model(self):
    """
    Reload the black and white networks when a newer checkpoint is available.

    Does nothing unless ``self.training`` is true. Otherwise, for each colour,
    looks up the latest checkpoint under ``self.save_path + "<colour>/"`` and,
    if it differs from the checkpoint recorded in ``self.<colour>_ckpt_file``,
    restores it into ``self.sess`` with the matching saver and records the new
    path.

    A directory that holds no checkpoint yet is skipped, leaving the weights
    that are currently loaded untouched.
    """
    if not self.training:
        return

    savers = (('black', self.black_saver), ('white', self.white_saver))
    for color, saver in savers:
        attr_name = color + '_ckpt_file'
        latest = tf.train.latest_checkpoint(self.save_path + color + "/")

        # latest_checkpoint returns None when nothing has been saved yet;
        # handing None to Saver.restore raises, so keep the current weights.
        if latest is None:
            continue

        # The attribute is only set by the constructor when a checkpoint
        # path was supplied, so fall back to None instead of raising.
        if getattr(self, attr_name, None) == latest:
            continue

        print('Loading model from {}...'.format(latest))
        saver.restore(self.sess, latest)
        # Record the path only after a successful restore so that a failed
        # load is retried on the next call.
        setattr(self, attr_name, latest)
        print('{} Model Updated!'.format(color.capitalize()))
|
||||||
|
|
||||||
def _history2state(self, history, color):
|
def _history2state(self, history, color):
|
||||||
"""
|
"""
|
||||||
convert the history to the state we need
|
convert the history to the state we need
|
||||||
@ -215,6 +234,7 @@ class ResNet(object):
|
|||||||
:param target: a string, which to optimize, can only be "both", "black" and "white"
|
:param target: a string, which to optimize, can only be "both", "black" and "white"
|
||||||
:param mode: a string, how to optimize, can only be "memory" and "file"
|
:param mode: a string, how to optimize, can only be "memory" and "file"
|
||||||
"""
|
"""
|
||||||
|
self.training = True
|
||||||
if mode == 'memory':
|
if mode == 'memory':
|
||||||
pass
|
pass
|
||||||
if mode == 'file':
|
if mode == 'file':
|
||||||
|
@ -61,6 +61,7 @@ if __name__ == '__main__':
|
|||||||
while True:
|
while True:
|
||||||
#while game_num < evaluate_rounds:
|
#while game_num < evaluate_rounds:
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
game.model.check_latest_model()
|
||||||
num = 0
|
num = 0
|
||||||
pass_flag = [False, False]
|
pass_flag = [False, False]
|
||||||
print("Start game {}".format(game_num))
|
print("Start game {}".format(game_num))
|
||||||
|
Loading…
x
Reference in New Issue
Block a user