Track previously seen board positions as a set of hashes.

* AlphaGo/game.py: reset history_set next to history in the board-reset code.
* AlphaGo/go.py: build one set of board hashes per simulate_get_mask call and
  pass it through _is_valid -> _rule_check -> _check_global_isomorphous, where
  a raw list of boards was passed before although the check tests membership
  of hash(tuple(board)).  executor_do_move now prints the board and raises
  ValueError when the rule check fails instead of returning False.
* AlphaGo/model.py: the commented-out uniform-prior stub returns an ndarray.

Note: the blob ids on the index lines are carried over from the original
revision of this patch.

diff --git a/AlphaGo/game.py b/AlphaGo/game.py
index 420f8d1..428bd5e 100644
--- a/AlphaGo/game.py
+++ b/AlphaGo/game.py
@@ -61,6 +61,7 @@ class Game:
         del self.board[:]
         self.board = [utils.EMPTY] * (self.size ** 2)
         del self.history[:]
+        self.history_set = set()
         if self.name == "reversi":
             self.board = self.game_engine.get_board()
         for _ in range(self.history_length):
diff --git a/AlphaGo/go.py b/AlphaGo/go.py
index 18a1a08..6ddc05a 100644
--- a/AlphaGo/go.py
+++ b/AlphaGo/go.py
@@ -96,12 +96,12 @@ class Go:
                     for b in group:
                         current_board[self._flatten(b)] = utils.EMPTY
 
-    def _check_global_isomorphous(self, history_boards_set, current_board, color, vertex):
+    def _check_global_isomorphous(self, history_hashtable, current_board, color, vertex):
         repeat = False
         next_board = copy.deepcopy(current_board)
         next_board[self._flatten(vertex)] = color
         self._process_board(next_board, color, vertex)
-        if hash(tuple(next_board)) in history_boards_set:
+        if hash(tuple(next_board)) in history_hashtable:
             repeat = True
         return repeat
 
@@ -157,7 +157,7 @@ class Go:
             vertex = self._deflatten(action)
         return vertex
 
-    def _rule_check(self, history_boards_set, current_board, color, vertex):
+    def _rule_check(self, history_hashtable, current_board, color, vertex):
         ### in board
         if not self._in_board(vertex):
             return False
@@ -171,17 +171,17 @@ class Go:
             return False
 
         ### forbid global isomorphous
-        if self._check_global_isomorphous(history_boards_set, current_board, color, vertex):
+        if self._check_global_isomorphous(history_hashtable, current_board, color, vertex):
             return False
 
         return True
 
-    def _is_valid(self, state, action):
+    def _is_valid(self, state, action, history_hashtable):
         history_boards, color = state
         vertex = self._action2vertex(action)
         current_board = history_boards[-1]
 
-        if not self._rule_check(history_boards, current_board, color, vertex):
+        if not self._rule_check(history_hashtable, current_board, color, vertex):
             return False
 
         if not self._knowledge_prunning(current_board, color, vertex):
@@ -191,14 +191,17 @@ class Go:
     def simulate_get_mask(self, state, action_set):
         # find all the invalid actions
         invalid_action_mask = []
+        # hash every history board once; each candidate's repeated-position test is then a set lookup
+        history_boards, _ = state
+        history_hashtable = {hash(tuple(board)) for board in history_boards}
         for action_candidate in action_set[:-1]:
             # go through all the actions excluding pass
-            if not self._is_valid(state, action_candidate):
+            if not self._is_valid(state, action_candidate, history_hashtable):
                 invalid_action_mask.append(action_candidate)
         if len(invalid_action_mask) < len(action_set) - 1:
             invalid_action_mask.append(action_set[-1])
             # forbid pass, if we have other choices
             # TODO: In fact we should not do this. In some extreme cases, we should permit pass.
         return invalid_action_mask
 
     def _do_move(self, board, color, vertex):
@@ -227,6 +230,7 @@ class Go:
 
     def executor_do_move(self, history, history_set, latest_boards, current_board, color, vertex):
         if not self._rule_check(history_set, current_board, color, vertex):
-            return False
+            print(current_board)
+            raise ValueError("!!! We have more than four ko at the same time !!!")
         current_board[self._flatten(vertex)] = color
         self._process_board(current_board, color, vertex)
diff --git a/AlphaGo/model.py b/AlphaGo/model.py
index 0f38af6..9106828 100644
--- a/AlphaGo/model.py
+++ b/AlphaGo/model.py
@@ -170,7 +170,7 @@ class ResNet(object):
         """
         # Note : maybe we can use it for isolating test of MCTS
         # prob = [1.0 / self.action_num] * self.action_num
-        # return [prob, np.random.uniform(-1, 1)]
+        # return [np.array(prob), np.random.uniform(-1, 1)]
         history, color = state
         if len(history) != self.history_length:
             raise ValueError(