diff --git a/AlphaGo/model.py b/AlphaGo/model.py index 9106828..5037173 100644 --- a/AlphaGo/model.py +++ b/AlphaGo/model.py @@ -99,19 +99,19 @@ class ResNet(object): self.sess = multi_gpu.create_session() self.sess.run(tf.global_variables_initializer()) if black_checkpoint_path is not None: - ckpt_file = tf.train.latest_checkpoint(black_checkpoint_path) - if ckpt_file is not None: - print('Restoring model from {}...'.format(ckpt_file)) - self.black_saver.restore(self.sess, ckpt_file) + self.black_ckpt_file = tf.train.latest_checkpoint(black_checkpoint_path) + if self.black_ckpt_file is not None: + print('Restoring model from {}...'.format(self.black_ckpt_file)) + self.black_saver.restore(self.sess, self.black_ckpt_file) print('Successfully loaded') else: raise ValueError("No model in path {}".format(black_checkpoint_path)) if white_checkpoint_path is not None: - ckpt_file = tf.train.latest_checkpoint(white_checkpoint_path) - if ckpt_file is not None: - print('Restoring model from {}...'.format(ckpt_file)) - self.white_saver.restore(self.sess, ckpt_file) + self.white_ckpt_file = tf.train.latest_checkpoint(white_checkpoint_path) + if self.white_ckpt_file is not None: + print('Restoring model from {}...'.format(self.white_ckpt_file)) + self.white_saver.restore(self.sess, self.white_ckpt_file) print('Successfully loaded') else: raise ValueError("No model in path {}".format(white_checkpoint_path)) @@ -124,6 +124,9 @@ class ResNet(object): self.training_data = {'states': deque(maxlen=self.window_length), 'probs': deque(maxlen=self.window_length), 'winner': deque(maxlen=self.window_length), 'length': deque(maxlen=self.window_length)} + # training or not + self.training = False + def _build_network(self, scope, residual_block_num): """ build the network @@ -184,6 +187,22 @@ class ResNet(object): return self.sess.run([self.white_prob, self.white_v], feed_dict={self.x: eval_state, self.is_training: False}) + def check_latest_model(self): + if self.training: + black_ckpt_file = tf.train.latest_checkpoint(self.save_path + "black/") + if self.black_ckpt_file != black_ckpt_file: + self.black_ckpt_file = black_ckpt_file + print('Loading model from {}...'.format(self.black_ckpt_file)) + self.black_saver.restore(self.sess, self.black_ckpt_file) + print('Black Model Updated!') + + white_ckpt_file = tf.train.latest_checkpoint(self.save_path + "white/") + if self.white_ckpt_file != white_ckpt_file: + self.white_ckpt_file = white_ckpt_file + print('Loading model from {}...'.format(self.white_ckpt_file)) + self.white_saver.restore(self.sess, self.white_ckpt_file) + print('White Model Updated!') + def _history2state(self, history, color): """ convert the history to the state we need @@ -215,6 +234,7 @@ class ResNet(object): :param target: a string, which to optimize, can only be "both", "black" and "white" :param mode: a string, how to optimize, can only be "memory" and "file" """ + self.training = True if mode == 'memory': pass if mode == 'file': diff --git a/AlphaGo/play.py b/AlphaGo/play.py index 5977f06..a000b78 100644 --- a/AlphaGo/play.py +++ b/AlphaGo/play.py @@ -61,6 +61,7 @@ if __name__ == '__main__': while True: #while game_num < evaluate_rounds: start_time = time.time() + game.model.check_latest_model() num = 0 pass_flag = [False, False] print("Start game {}".format(game_num))