diff --git a/.gitignore b/.gitignore index d697b92..8ee6691 100644 --- a/.gitignore +++ b/.gitignore @@ -4,8 +4,8 @@ leela-zero parameters *.swp *.sublime* -checkpoints -checkpoints_origin +checkpoint *.json .DS_Store data +.log diff --git a/AlphaGo/game.py b/AlphaGo/game.py index ff1faf5..90d0bf0 100644 --- a/AlphaGo/game.py +++ b/AlphaGo/game.py @@ -27,29 +27,30 @@ class Game: ''' def __init__(self, name="go", checkpoint_path=None): self.name = name - if "go" == name: + if self.name == "go": self.size = 9 self.komi = 3.75 self.board = [utils.EMPTY] * (self.size ** 2) self.history = [] + self.history_length = 8 self.latest_boards = deque(maxlen=8) for _ in range(8): self.latest_boards.append(self.board) - - self.evaluator = model.ResNet(self.size, self.size**2 + 1, history_length=8) self.game_engine = go.Go(size=self.size, komi=self.komi) - elif "reversi" == name: + elif self.name == "reversi": self.size = 8 - self.evaluator = model.ResNet(self.size, self.size**2 + 1, history_length=1) + self.history_length = 1 self.game_engine = reversi.Reversi() self.board = self.game_engine.get_board() else: - print(name + " is an unknown game...") + raise ValueError(name + " is an unknown game...") + + self.evaluator = model.ResNet(self.size, self.size ** 2 + 1, history_length=self.history_length) def clear(self): self.board = [utils.EMPTY] * (self.size ** 2) self.history = [] - for _ in range(8): + for _ in range(self.history_length): self.latest_boards.append(self.board) def set_size(self, n): @@ -76,9 +77,9 @@ class Game: if vertex == utils.PASS: return True # TODO this implementation is not very elegant - if "go" == self.name: + if self.name == "go": res = self.game_engine.executor_do_move(self.history, self.latest_boards, self.board, color, vertex) - elif "revsersi" == self.name: + elif self.name == "reversi": res = self.game_engine.executor_do_move(self.board, color, vertex) return res diff --git a/AlphaGo/player.py b/AlphaGo/player.py index 0e3daff..e848d2b 100644 --- a/AlphaGo/player.py +++ b/AlphaGo/player.py @@ -34,7 +34,7 @@ if __name__ == '__main__': daemon = Pyro4.Daemon() # make a Pyro daemon ns = Pyro4.locateNS() # find the name server - player = Player(role = args.role, engine = engine) + player = Player(role=args.role, engine=engine) print "Init " + args.role + " player finished" uri = daemon.register(player) # register the greeting maker as a Pyro object print "Start on name " + args.role