From 7a4c5c3c88acd4f989f14e910e8300401ceb541d Mon Sep 17 00:00:00 2001
From: rtz19970824
Date: Sun, 3 Dec 2017 19:16:21 +0800
Subject: [PATCH] minor fixed

---
Reviewer notes. They sit after the "---" separator, so they are not part of
the commit message and do not change any hunk below.

* Formatting: this copy of the patch had its line breaks collapsed and its
  indentation lost. The line structure was reconstructed with 4-space Python
  indentation; the number of context/added/removed lines in every hunk was
  checked against its @@ header and against the diffstat. The exact original
  whitespace (including trailing whitespace on blank lines) could not be
  recovered -- if a context line fails to match, retry with
  `git apply --ignore-whitespace`.
* The From: header in this copy carries no e-mail address; restore it from
  the original commit if `git am` complains.

* AlphaGo/strategy.py adds move-legality helpers to GoEnv:
    _bfs / _find_block  collect the connected same-colour block containing a
                        vertex; _find_block returns (True, block) when no
                        stone of the block has an EMPTY neighbour, otherwise
                        (False, block).
    _is_qi              True if the vertex has an EMPTY neighbour, or if
                        placing the stone leaves an adjacent opposing block
                        without EMPTY neighbours; False if the stone's own
                        block would have none ("can not suicide"). The stone
                        is placed temporarily and reset to EMPTY on every
                        return path.
    _check_global_isomorphous
                        plays the move on the board, removes captured
                        blocks via _process_board, tests whether the
                        resulting board is already in self.history, then
                        restores the saved board.
    _in_board / _neighbor
                        1-based bounds check and in-board neighbours
                        generated from DELTA.
    _process_board      clears every adjacent opposing block that
                        _find_block reports as capturable.
    is_valid            in-board, empty, _is_qi and not a repeated position.

* Review findings (not changed here, to keep the patch faithful):
    - `alive_break` is passed through _find_block and _bfs but never read,
      so `_find_block(n, alive_break=True)` behaves exactly like the
      default call.
    - `_bfs` is recursive (depth-first) despite its name; recursion depth
      grows with the number of stones in the block.
    - `copy`, `DELTA`, `self.history` and `utils.another_color` are used but
      not defined in this patch -- presumably they already exist in
      strategy.py / utils; verify before applying.
    - `_check_global_isomorphous` could assign
      `res = self.board in self.history` directly instead of the
      if/else pair.
    - `_is_qi` and `_check_global_isomorphous` mutate self.board while
      probing; they rely on restoring it before returning.

* AlphaGo/test.py: the scripted move changes from C3 to D4.
* tianshou/core/mcts/mcts.py: removes a TODO comment in
  ActionNode.expansion, lets MCTS accept method == "" (the given root is
  stored unchanged), and adds a commented-out TODO about the stop criteria.

 AlphaGo/strategy.py        | 100 +++++++++++++++++++++++++++++++++++++
 AlphaGo/test.py            |   2 +-
 tianshou/core/mcts/mcts.py |   6 ++-
 3 files changed, 106 insertions(+), 2 deletions(-)

diff --git a/AlphaGo/strategy.py b/AlphaGo/strategy.py
index 99a8e4d..91ed4e5 100644
--- a/AlphaGo/strategy.py
+++ b/AlphaGo/strategy.py
@@ -17,6 +17,106 @@ class GoEnv:
     def _flatten(self, vertex):
         x, y = vertex
         return (x - 1) * self.size + (y - 1)
+
+    def _bfs(self, vertex, color, block, status, alive_break):
+        block.append(vertex)
+        status[self._flatten(vertex)] = True
+        nei = self._neighbor(vertex)
+        for n in nei:
+            if not status[self._flatten(n)]:
+                if self.board[self._flatten(n)] == color:
+                    self._bfs(n, color, block, status, alive_break)
+
+    def _find_block(self, vertex, alive_break=False):
+        block = []
+        status = [False] * (self.size * self.size)
+        color = self.board[self._flatten(vertex)]
+        self._bfs(vertex, color, block, status, alive_break)
+
+        for b in block:
+            for n in self._neighbor(b):
+                if self.board[self._flatten(n)] == utils.EMPTY:
+                    return False, block
+        return True, block
+
+    def _is_qi(self, color, vertex):
+        nei = self._neighbor(vertex)
+        for n in nei:
+            if self.board[self._flatten(n)] == utils.EMPTY:
+                return True
+
+        self.board[self._flatten(vertex)] = color
+        for n in nei:
+            if self.board[self._flatten(n)] == utils.another_color(color):
+                can_kill, block = self._find_block(n)
+                if can_kill:
+                    self.board[self._flatten(vertex)] = utils.EMPTY
+                    return True
+
+        ### can not suicide
+        can_kill, block = self._find_block(vertex)
+        if can_kill:
+            self.board[self._flatten(vertex)] = utils.EMPTY
+            return False
+
+        self.board[self._flatten(vertex)] = utils.EMPTY
+        return True
+
+    def _check_global_isomorphous(self, color, vertex):
+        ##backup
+        _board = copy.copy(self.board)
+        self.board[self._flatten(vertex)] = color
+        self._process_board(color, vertex)
+        if self.board in self.history:
+            res = True
+        else:
+            res = False
+
+        self.board = _board
+        return res
+
+    def _in_board(self, vertex):
+        x, y = vertex
+        if x < 1 or x > self.size: return False
+        if y < 1 or y > self.size: return False
+        return True
+
+    def _neighbor(self, vertex):
+        x, y = vertex
+        nei = []
+        for d in DELTA:
+            _x = x + d[0]
+            _y = y + d[1]
+            if self._in_board((_x, _y)):
+                nei.append((_x, _y))
+        return nei
+
+    def _process_board(self, color, vertex):
+        nei = self._neighbor(vertex)
+        for n in nei:
+            if self.board[self._flatten(n)] == utils.another_color(color):
+                can_kill, block = self._find_block(n, alive_break=True)
+                if can_kill:
+                    for b in block:
+                        self.board[self._flatten(b)] = utils.EMPTY
+
+    def is_valid(self, color, vertex):
+        ### in board
+        if not self._in_board(vertex):
+            return False
+
+        ### already have stone
+        if not self.board[self._flatten(vertex)] == utils.EMPTY:
+            return False
+
+        ### check if it is qi
+        if not self._is_qi(color, vertex):
+            return False
+
+        if self._check_global_isomorphous(color, vertex):
+            return False
+
+        return True
 
     def do_move(self, color, vertex):
         if vertex == utils.PASS:
diff --git a/AlphaGo/test.py b/AlphaGo/test.py
index 59c5a26..6699f5a 100644
--- a/AlphaGo/test.py
+++ b/AlphaGo/test.py
@@ -30,7 +30,7 @@ print(res)
 res = e.run_cmd('6 komi 6')
 print(res)
 
-res = e.run_cmd('7 play BLACK C3')
+res = e.run_cmd('7 play BLACK D4')
 print(res)
 
 # res = e.run_cmd('play BLACK C4')
diff --git a/tianshou/core/mcts/mcts.py b/tianshou/core/mcts/mcts.py
index c4080bb..37fc2a8 100644
--- a/tianshou/core/mcts/mcts.py
+++ b/tianshou/core/mcts/mcts.py
@@ -114,7 +114,6 @@ class ActionNode:
         return self.parent, self.action
 
     def expansion(self, evaluator, action_num):
-        # TODO: Let users/evaluator give the prior
         if self.next_state is not None:
             prior, value = evaluator(self.next_state)
             self.children[self.next_state] = UCTNode(self, self.action, self.origin_state, action_num, prior,
@@ -134,6 +133,8 @@ class MCTS:
         self.simulator = simulator
         self.evaluator = evaluator
         self.action_num = action_num
+        if method == "":
+            self.root = root
         if method == "UCT":
             self.root = UCTNode(None, None, root, action_num, prior, inverse)
         if method == "TS":
@@ -142,6 +143,9 @@ class MCTS:
         if max_step is not None:
             self.step = 0
             self.max_step = max_step
+        # TODO: Optimize the stop criteria
+        # else:
+        #     self.max_step = 0
         if max_time is not None:
             self.start_time = time.time()
             self.max_time = max_time