use hash table for check_global_isomorphous

This commit is contained in:
Wenbo Hu 2017-12-29 03:30:09 +08:00
parent da156ed88e
commit 63a0d32b34
2 changed files with 13 additions and 8 deletions

View File

@ -35,6 +35,7 @@ class Game:
self.komi = 3.75 self.komi = 3.75
self.history_length = 8 self.history_length = 8
self.history = [] self.history = []
self.history_set = set()
self.game_engine = go.Go(size=self.size, komi=self.komi, role=self.role) self.game_engine = go.Go(size=self.size, komi=self.komi, role=self.role)
self.board = [utils.EMPTY] * (self.size ** 2) self.board = [utils.EMPTY] * (self.size ** 2)
elif self.name == "reversi": elif self.name == "reversi":
@ -92,7 +93,10 @@ class Game:
# this function can be called directly to play the opponent's move # this function can be called directly to play the opponent's move
if vertex == utils.PASS: if vertex == utils.PASS:
return True return True
res = self.game_engine.executor_do_move(self.history, self.latest_boards, self.board, color, vertex) if self.name == "reversi":
res = self.game_engine.executor_do_move(self.history, self.latest_boards, self.board, color, vertex)
if self.name == "go":
res = self.game_engine.executor_do_move(self.history, self.history_set, self.latest_boards, self.board, color, vertex)
return res return res
def think_play_move(self, color): def think_play_move(self, color):
@ -124,6 +128,6 @@ class Game:
if __name__ == "__main__": if __name__ == "__main__":
game = Game(name="reversi", checkpoint_path=None) game = Game(name="reversi", checkpoint_path=None)
game.debug = True game.debug = False
game.think_play_move(utils.BLACK) game.think_play_move(utils.BLACK)

View File

@ -97,12 +97,12 @@ class Go:
for b in group: for b in group:
current_board[self._flatten(b)] = utils.EMPTY current_board[self._flatten(b)] = utils.EMPTY
def _check_global_isomorphous(self, history_boards, current_board, color, vertex): def _check_global_isomorphous(self, history_boards_set, current_board, color, vertex):
repeat = False repeat = False
next_board = copy.deepcopy(current_board) next_board = copy.deepcopy(current_board)
next_board[self._flatten(vertex)] = color next_board[self._flatten(vertex)] = color
self._process_board(next_board, color, vertex) self._process_board(next_board, color, vertex)
if next_board in history_boards: if hash(tuple(next_board)) in history_boards_set:
repeat = True repeat = True
return repeat return repeat
@ -158,7 +158,7 @@ class Go:
vertex = self._deflatten(action) vertex = self._deflatten(action)
return vertex return vertex
def _rule_check(self, history_boards, current_board, color, vertex): def _rule_check(self, history_boards_set, current_board, color, vertex):
### in board ### in board
if not self._in_board(vertex): if not self._in_board(vertex):
return False return False
@ -172,7 +172,7 @@ class Go:
return False return False
### forbid global isomorphous ### forbid global isomorphous
if self._check_global_isomorphous(history_boards, current_board, color, vertex): if self._check_global_isomorphous(history_boards_set, current_board, color, vertex):
return False return False
return True return True
@ -226,13 +226,14 @@ class Go:
# since go is MDP, we only need the last board for hashing # since go is MDP, we only need the last board for hashing
return tuple(state[0][-1]) return tuple(state[0][-1])
def executor_do_move(self, history, latest_boards, current_board, color, vertex): def executor_do_move(self, history, history_set, latest_boards, current_board, color, vertex):
if not self._rule_check(history, current_board, color, vertex): if not self._rule_check(history_set, current_board, color, vertex):
return False return False
current_board[self._flatten(vertex)] = color current_board[self._flatten(vertex)] = color
self._process_board(current_board, color, vertex) self._process_board(current_board, color, vertex)
history.append(copy.deepcopy(current_board)) history.append(copy.deepcopy(current_board))
latest_boards.append(copy.deepcopy(current_board)) latest_boards.append(copy.deepcopy(current_board))
history_set.add(hash(tuple(current_board)))
return True return True
def _find_empty(self, current_board): def _find_empty(self, current_board):