diff --git a/AlphaGo/network_small.py b/AlphaGo/network_small.py index 2542ec4..cfff6f3 100644 --- a/AlphaGo/network_small.py +++ b/AlphaGo/network_small.py @@ -192,13 +192,16 @@ class Network(object): # checkpoint_path = "/home/tongzheng/tianshou/AlphaGo/checkpoints/" # sess = multi_gpu.create_session() # sess.run(tf.global_variables_initializer()) - ckpt_file = tf.train.latest_checkpoint(checkpoint_path) - if ckpt_file is not None: - # print('Restoring model from {}...'.format(ckpt_file)) - self.saver.restore(self.sess, ckpt_file) - # print('Successfully loaded') + if checkpoint_path is None: + self.sess.run(tf.global_variables_initializer()) else: - raise ValueError("No model loaded") + ckpt_file = tf.train.latest_checkpoint(checkpoint_path) + if ckpt_file is not None: + # print('Restoring model from {}...'.format(ckpt_file)) + self.saver.restore(self.sess, ckpt_file) + # print('Successfully loaded') + else: + raise ValueError("No model loaded") # prior, value = sess.run([tf.nn.softmax(p), v], feed_dict={x: state, is_training: False}) # return prior, value return self.sess diff --git a/AlphaGo/self-play.py b/AlphaGo/self-play.py index 91c7a50..6d5650e 100644 --- a/AlphaGo/self-play.py +++ b/AlphaGo/self-play.py @@ -14,7 +14,7 @@ args = parser.parse_args() if not os.path.exists(args.result_path): os.makedirs(args.result_path) -game = Game() +game = Game(checkpoint_path="./checkpoints/") engine = GTPEngine(game_obj=game) history = deque(maxlen=8) for i in range(8):