diff --git a/AlphaGo/game.py b/AlphaGo/game.py index 82cf254..60e09f0 100644 --- a/AlphaGo/game.py +++ b/AlphaGo/game.py @@ -35,6 +35,7 @@ class Game: self.komi = 3.75 self.history_length = 8 self.history = [] + self.history_set = set() self.game_engine = go.Go(size=self.size, komi=self.komi, role=self.role) self.board = [utils.EMPTY] * (self.size ** 2) elif self.name == "reversi": @@ -92,7 +93,10 @@ class Game: # this function can be called directly to play the opponent's move if vertex == utils.PASS: return True - res = self.game_engine.executor_do_move(self.history, self.latest_boards, self.board, color, vertex) + if self.name == "reversi": + res = self.game_engine.executor_do_move(self.history, self.latest_boards, self.board, color, vertex) + if self.name == "go": + res = self.game_engine.executor_do_move(self.history, self.history_set, self.latest_boards, self.board, color, vertex) return res def think_play_move(self, color): @@ -124,6 +128,6 @@ class Game: if __name__ == "__main__": game = Game(name="reversi", checkpoint_path=None) - game.debug = True + game.debug = False game.think_play_move(utils.BLACK) diff --git a/AlphaGo/go.py b/AlphaGo/go.py index 987fe93..cf6b7aa 100644 --- a/AlphaGo/go.py +++ b/AlphaGo/go.py @@ -97,12 +97,12 @@ class Go: for b in group: current_board[self._flatten(b)] = utils.EMPTY - def _check_global_isomorphous(self, history_boards, current_board, color, vertex): + def _check_global_isomorphous(self, history_boards_set, current_board, color, vertex): repeat = False next_board = copy.deepcopy(current_board) next_board[self._flatten(vertex)] = color self._process_board(next_board, color, vertex) - if next_board in history_boards: + if hash(tuple(next_board)) in history_boards_set: repeat = True return repeat @@ -158,7 +158,7 @@ class Go: vertex = self._deflatten(action) return vertex - def _rule_check(self, history_boards, current_board, color, vertex): + def _rule_check(self, history_boards_set, current_board, color, vertex): ### in board if not self._in_board(vertex): return False @@ -172,7 +172,7 @@ class Go: return False ### forbid global isomorphous - if self._check_global_isomorphous(history_boards, current_board, color, vertex): + if self._check_global_isomorphous(history_boards_set, current_board, color, vertex): return False return True @@ -226,13 +226,14 @@ class Go: # since go is MDP, we only need the last board for hashing return tuple(state[0][-1]) - def executor_do_move(self, history, latest_boards, current_board, color, vertex): - if not self._rule_check(history, current_board, color, vertex): + def executor_do_move(self, history, history_set, latest_boards, current_board, color, vertex): + if not self._rule_check(history_set, current_board, color, vertex): return False current_board[self._flatten(vertex)] = color self._process_board(current_board, color, vertex) history.append(copy.deepcopy(current_board)) latest_boards.append(copy.deepcopy(current_board)) + history_set.add(hash(tuple(current_board))) return True def _find_empty(self, current_board):