pass the checkpoint path to the model

This commit is contained in:
rtz19970824 2017-12-26 13:17:46 +08:00
parent 76f641a0f1
commit 725fc2c04e

View File

@ -22,8 +22,8 @@ import time
class Game:
'''
Load the real game and trained weights.
TODO : Maybe merge with the engine class in future,
TODO : Maybe merge with the engine class in future,
currently leave it untouched for interacting with Go UI.
'''
def __init__(self, name=None, role=None, debug=False, checkpoint_path=None):
@ -46,7 +46,7 @@ class Game:
else:
raise ValueError(name + " is an unknown game...")
self.evaluator = model.ResNet(self.size, self.size ** 2 + 1, history_length=self.history_length)
self.evaluator = model.ResNet(self.size, self.size ** 2 + 1, history_length=self.history_length, checkpoint_path=checkpoint_path)
self.latest_boards = deque(maxlen=self.history_length)
for _ in range(self.history_length):
self.latest_boards.append(self.board)