diff --git a/.gitignore b/.gitignore index d697b92..8ee6691 100644 --- a/.gitignore +++ b/.gitignore @@ -4,8 +4,8 @@ leela-zero parameters *.swp *.sublime* -checkpoints -checkpoints_origin +checkpoint *.json .DS_Store data +.log diff --git a/AlphaGo/engine.py b/AlphaGo/engine.py index 8b54470..98e5e61 100644 --- a/AlphaGo/engine.py +++ b/AlphaGo/engine.py @@ -183,7 +183,7 @@ class GTPEngine(): return 'unknown player', False def cmd_get_score(self, args, **kwargs): - return self._game.game_engine.executor_get_score(self._game.board, True), True + return self._game.game_engine.executor_get_score(self._game.board), True def cmd_show_board(self, args, **kwargs): return self._game.board, True @@ -194,4 +194,4 @@ class GTPEngine(): if __name__ == "main": game = Game() - engine = GTPEngine(game_obj=Game) + engine = GTPEngine(game_obj=game) diff --git a/AlphaGo/game.py b/AlphaGo/game.py index 8706572..90d0bf0 100644 --- a/AlphaGo/game.py +++ b/AlphaGo/game.py @@ -10,12 +10,14 @@ import copy import tensorflow as tf import numpy as np import sys, os -import go import model from collections import deque sys.path.append(os.path.join(os.path.dirname(__file__), os.path.pardir)) from tianshou.core.mcts.mcts import MCTS +import go +import reversi + class Game: ''' Load the real game and trained weights. @@ -23,23 +25,32 @@ class Game: TODO : Maybe merge with the engine class in future, currently leave it untouched for interacting with Go UI. ''' - def __init__(self, size=9, komi=3.75, checkpoint_path=None): - self.size = size - self.komi = komi - self.board = [utils.EMPTY] * (self.size ** 2) - self.history = [] - self.latest_boards = deque(maxlen=8) - for _ in range(8): - self.latest_boards.append(self.board) - self.evaluator = model.ResNet(self.size, self.size**2 + 1, history_length=8) - # self.evaluator = lambda state: self.sess.run([tf.nn.softmax(self.net.p), self.net.v], - # feed_dict={self.net.x: state, self.net.is_training: False}) - self.game_engine = go.Go(size=self.size, komi=self.komi) + def __init__(self, name="go", checkpoint_path=None): + self.name = name + if self.name == "go": + self.size = 9 + self.komi = 3.75 + self.board = [utils.EMPTY] * (self.size ** 2) + self.history = [] + self.history_length = 8 + self.latest_boards = deque(maxlen=8) + for _ in range(8): + self.latest_boards.append(self.board) + self.game_engine = go.Go(size=self.size, komi=self.komi) + elif self.name == "reversi": + self.size = 8 + self.history_length = 1 + self.game_engine = reversi.Reversi() + self.board = self.game_engine.get_board() + else: + raise ValueError(name + " is an unknown game...") + + self.evaluator = model.ResNet(self.size, self.size ** 2 + 1, history_length=self.history_length) def clear(self): self.board = [utils.EMPTY] * (self.size ** 2) self.history = [] - for _ in range(8): + for _ in range(self.history_length): self.latest_boards.append(self.board) def set_size(self, n): @@ -65,7 +76,11 @@ class Game: # this function can be called directly to play the opponent's move if vertex == utils.PASS: return True - res = self.game_engine.executor_do_move(self.history, self.latest_boards, self.board, color, vertex) + # TODO this implementation is not very elegant + if self.name == "go": + res = self.game_engine.executor_do_move(self.history, self.latest_boards, self.board, color, vertex) + elif self.name == "reversi": + res = self.game_engine.executor_do_move(self.board, color, vertex) return res def think_play_move(self, color): diff --git a/AlphaGo/go.py b/AlphaGo/go.py index 9b7e21f..b819c08 100644 --- a/AlphaGo/go.py +++ b/AlphaGo/go.py @@ -157,7 +157,7 @@ class Go: vertex = self._deflatten(action) return vertex - def _is_valid(self, history_boards, current_board, color, vertex): + def _rule_check(self, history_boards, current_board, color, vertex): ### in board if not self._in_board(vertex): return False @@ -176,30 +176,30 @@ class Go: return True - def simulate_is_valid(self, state, action): + def _is_valid(self, state, action): history_boards, color = state vertex = self._action2vertex(action) current_board = history_boards[-1] - if not self._is_valid(history_boards, current_board, color, vertex): + if not self._rule_check(history_boards, current_board, color, vertex): return False if not self._knowledge_prunning(current_board, color, vertex): return False return True - def simulate_is_valid_list(self, state, action_set): + def simulate_get_mask(self, state, action_set): # find all the invalid actions - invalid_action_list = [] + invalid_action_mask = [] for action_candidate in action_set[:-1]: # go through all the actions excluding pass - if not self.simulate_is_valid(state, action_candidate): - invalid_action_list.append(action_candidate) - if len(invalid_action_list) < len(action_set) - 1: - invalid_action_list.append(action_set[-1]) + if not self._is_valid(state, action_candidate): + invalid_action_mask.append(action_candidate) + if len(invalid_action_mask) < len(action_set) - 1: + invalid_action_mask.append(action_set[-1]) # forbid pass, if we have other choices # TODO: In fact we should not do this. In some extreme cases, we should permit pass. - return invalid_action_list + return invalid_action_mask def _do_move(self, board, color, vertex): if vertex == utils.PASS: @@ -219,7 +219,7 @@ class Go: return [history_boards, new_color], 0 def executor_do_move(self, history, latest_boards, current_board, color, vertex): - if not self._is_valid(history, current_board, color, vertex): + if not self._rule_check(history, current_board, color, vertex): return False current_board[self._flatten(vertex)] = color self._process_board(current_board, color, vertex) @@ -280,7 +280,7 @@ class Go: elif color_estimate < 0: return utils.WHITE - def executor_get_score(self, current_board, is_unknown_estimation=False): + def executor_get_score(self, current_board): ''' is_unknown_estimation: whether use nearby stone to predict the unknown return score from BLACK perspective. @@ -294,10 +294,8 @@ class Go: _board[self._flatten(vertex)] = utils.BLACK elif boarder_color == {utils.WHITE}: _board[self._flatten(vertex)] = utils.WHITE - elif is_unknown_estimation: - _board[self._flatten(vertex)] = self._predict_from_nearby(_board, vertex) else: - _board[self._flatten(vertex)] =utils.UNKNOWN + _board[self._flatten(vertex)] = self._predict_from_nearby(_board, vertex) score = 0 for i in _board: if i == utils.BLACK: @@ -308,3 +306,42 @@ class Go: return score +if __name__ == "__main__": + ### do unit test for Go class + pure_test = [ + 0, 1, 0, 1, 0, 1, 0, 0, 0, + 1, 0, 1, 0, 1, 0, 0, 0, 0, + 0, 1, 0, 1, 0, 0, 1, 0, 0, + 0, 0, 1, 0, 0, 1, 0, 1, 0, + 0, 0, 0, 0, 0, 1, 1, 1, 0, + 1, 1, 1, 0, 0, 0, 0, 0, 0, + 1, 0, 1, 0, 0, 1, 1, 0, 0, + 1, 1, 1, 0, 1, 0, 1, 0, 0, + 0, 0, 0, 0, 1, 1, 1, 0, 0 + ] + + pt_qry = [(1, 1), (1, 5), (3, 3), (4, 7), (7, 2), (8, 6)] + pt_ans = [True, True, True, True, True, True] + + opponent_test = [ + 0, 1, 0, 1, 0, 1, 0,-1, 1, + 1,-1, 0,-1, 1,-1, 0, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 1, + 1, 1,-1, 0, 1,-1, 1, 0, 0, + 1, 0, 1, 0, 1, 0, 1, 0, 0, + -1,1, 1, 0, 1, 1, 1, 0, 0, + 0, 1,-1, 0,-1,-1,-1, 0, 0, + 1, 0, 1, 0,-1, 0,-1, 0, 0, + 0, 1, 0, 0,-1,-1,-1, 0, 0 + ] + ot_qry = [(1, 1), (1, 5), (2, 9), (5, 2), (5, 6), (8, 6), (8, 2)] + ot_ans = [False, False, False, False, False, False, True] + + go = Go(size=9, komi=3.75) + for i in range(6): + print (go._is_eye(pure_test, utils.BLACK, pt_qry[i])) + print("Test of pure eye\n") + + for i in range(7): + print (go._is_eye(opponent_test, utils.BLACK, ot_qry[i])) + print("Test of eye surrend by opponents\n") diff --git a/AlphaGo/model.py b/AlphaGo/model.py index 3cfb900..22e8626 100644 --- a/AlphaGo/model.py +++ b/AlphaGo/model.py @@ -1,5 +1,6 @@ import os import time +import random import sys import cPickle from collections import deque @@ -104,7 +105,7 @@ class ResNet(object): self.window_length = 7000 self.save_freq = 5000 self.training_data = {'states': deque(maxlen=self.window_length), 'probs': deque(maxlen=self.window_length), - 'winner': deque(maxlen=self.window_length)} + 'winner': deque(maxlen=self.window_length), 'length': deque(maxlen=self.window_length)} def _build_network(self, residual_block_num, checkpoint_path): """ @@ -199,15 +200,15 @@ class ResNet(object): new_file_list = [] all_file_list = [] - training_data = {} + training_data = {'states': [], 'probs': [], 'winner': []} + iters = 0 while True: new_file_list = list(set(os.listdir(data_path)).difference(all_file_list)) - if new_file_list: + while new_file_list: all_file_list = os.listdir(data_path) - new_file_list.sort( - key=lambda file: os.path.getmtime(data_path + file) if not os.path.isdir(data_path + file) else 0) - if new_file_list: + new_file_list.sort( + key=lambda file: os.path.getmtime(data_path + file) if not os.path.isdir(data_path + file) else 0) for file in new_file_list: states, probs, winner = self._file_to_training_data(data_path + file) assert states.shape[0] == probs.shape[0] @@ -215,32 +216,36 @@ class ResNet(object): self.training_data['states'].append(states) self.training_data['probs'].append(probs) self.training_data['winner'].append(winner) - if len(self.training_data['states']) == self.window_length: - training_data['states'] = np.concatenate(self.training_data['states'], axis=0) - training_data['probs'] = np.concatenate(self.training_data['probs'], axis=0) - training_data['winner'] = np.concatenate(self.training_data['winner'], axis=0) + self.training_data['length'].append(states.shape[0]) + new_file_list = list(set(os.listdir(data_path)).difference(all_file_list)) if len(self.training_data['states']) != self.window_length: continue else: - data_num = training_data['states'].shape[0] - index = np.arange(data_num) - np.random.shuffle(index) start_time = time.time() + for i in range(batch_size): + game_num = random.randint(0, self.window_length-1) + state_num = random.randint(0, self.training_data['length'][game_num]-1) + training_data['states'].append(np.expand_dims(self.training_data['states'][game_num][state_num], 0)) + training_data['probs'].append(np.expand_dims(self.training_data['probs'][game_num][state_num], 0)) + training_data['winner'].append(np.expand_dims(self.training_data['winner'][game_num][state_num], 0)) value_loss, policy_loss, reg, _ = self.sess.run( [self.value_loss, self.policy_loss, self.reg, self.train_op], - feed_dict={self.x: training_data['states'][index[:batch_size]], - self.z: training_data['winner'][index[:batch_size]], - self.pi: training_data['probs'][index[:batch_size]], + feed_dict={self.x: np.concatenate(training_data['states'], axis=0), + self.z: np.concatenate(training_data['winner'], axis=0), + self.pi: np.concatenate(training_data['probs'], axis=0), self.is_training: True}) + print("Iteration: {}, Time: {}, Value Loss: {}, Policy Loss: {}, Reg: {}".format(iters, time.time() - start_time, value_loss, policy_loss, reg)) - iters += 1 if iters % self.save_freq == 0: save_path = "Iteration{}.ckpt".format(iters) self.saver.save(self.sess, self.checkpoint_path + save_path) + for key in training_data.keys(): + training_data[key] = [] + iters += 1 def _file_to_training_data(self, file_name): read = False @@ -250,7 +255,7 @@ class ResNet(object): file.seek(0) data = cPickle.load(file) read = True - print("{} Loaded".format(file_name)) + print("{} Loaded!".format(file_name)) except Exception as e: print(e) time.sleep(1) @@ -276,6 +281,6 @@ class ResNet(object): return states, probs, winner -if __name__=="__main__": - model = ResNet(board_size=9, action_num=82) +if __name__ == "__main__": + model = ResNet(board_size=9, action_num=82, history_length=8) model.train("file", data_path="./data/", batch_size=128, checkpoint_path="./checkpoint/") diff --git a/AlphaGo/play.py b/AlphaGo/play.py index 3681430..b601ada 100644 --- a/AlphaGo/play.py +++ b/AlphaGo/play.py @@ -7,7 +7,6 @@ import time import os import cPickle - class Data(object): def __init__(self): self.boards = [] diff --git a/AlphaGo/player.py b/AlphaGo/player.py index 0e3daff..e848d2b 100644 --- a/AlphaGo/player.py +++ b/AlphaGo/player.py @@ -34,7 +34,7 @@ if __name__ == '__main__': daemon = Pyro4.Daemon() # make a Pyro daemon ns = Pyro4.locateNS() # find the name server - player = Player(role = args.role, engine = engine) + player = Player(role=args.role, engine=engine) print "Init " + args.role + " player finished" uri = daemon.register(player) # register the greeting maker as a Pyro object print "Start on name " + args.role diff --git a/AlphaGo/reversi.py b/AlphaGo/reversi.py index 49d0e9a..c086a2c 100644 --- a/AlphaGo/reversi.py +++ b/AlphaGo/reversi.py @@ -25,7 +25,6 @@ def find_correct_moves(own, enemy): mobility |= search_offset_right(own, enemy, mask, 7) # Left bottom return mobility - def calc_flip(pos, own, enemy): """return flip stones of enemy by bitboard when I place stone at pos. @@ -34,7 +33,6 @@ def calc_flip(pos, own, enemy): :param enemy: bitboard :return: flip stones of enemy when I place stone at pos. """ - assert 0 <= pos <= 63, f"pos={pos}" f1 = _calc_flip_half(pos, own, enemy) f2 = _calc_flip_half(63 - pos, rotate180(own), rotate180(enemy)) return f1 | rotate180(f2) @@ -125,27 +123,42 @@ class Reversi: self.board = None # 8 * 8 board with 1 for black, -1 for white and 0 for blank self.color = None # 1 for black and -1 for white self.action = None # number in 0~63 - self.winner = None + # self.winner = None + self.black_win = None - def simulate_is_valid(self, board, color): + def get_board(self, black=None, white=None): + self.black = black or (0b00001000 << 24 | 0b00010000 << 32) + self.white = white or (0b00010000 << 24 | 0b00001000 << 32) + self.board = self.bitboard2board() + return self.board + + def simulate_get_mask(self, state, action_set): + history_boards, color = state + board = history_boards[-1] self.board = board self.color = color self.board2bitboard() own, enemy = self.get_own_and_enemy() mobility = find_correct_moves(own, enemy) valid_moves = bit_to_array(mobility, 64) + valid_moves = np.argwhere(valid_moves) valid_moves = list(np.reshape(valid_moves, len(valid_moves))) - return valid_moves + # TODO it seems that the pass move is not considered + invalid_action_mask = [] + for action in action_set: + if action not in valid_moves: + invalid_action_mask.append(action) + return invalid_action_mask - def simulate_step_forward(self, board, color, vertex): - self.board = board - self.color = color + def simulate_step_forward(self, state, action): + self.board = state[0] + self.color = state[1] self.board2bitboard() - self.vertex2action(vertex) + self.action = action step_forward = self.step() if step_forward: new_board = self.bitboard2board() - return new_board + return [new_board, 0 - self.color], 0 def executor_do_move(self, board, color, vertex): self.board = board @@ -155,20 +168,21 @@ class Reversi: step_forward = self.step() if step_forward: new_board = self.bitboard2board() - return new_board + for i in range(64): + board[i] = new_board[i] def executor_get_score(self, board): self.board = board self._game_over() - if self.winner is not None: - return self.winner, 0 - self.winner + if self.black_win is not None: + return self.black_win else: - ValueError("Game not finished!") + raise ValueError("Game not finished!") def board2bitboard(self): count = 1 if self.board is None: - ValueError("None board!") + raise ValueError("None board!") self.black = 0 self.white = 0 for i in range(64): @@ -200,7 +214,7 @@ class Reversi: def step(self): if self.action < 0 or self.action > 63: - ValueError("Wrong action!") + raise ValueError("Wrong action!") if self.action is None: return False @@ -219,6 +233,7 @@ class Reversi: def _game_over(self): # self.done = True + ''' if self.winner is None: black_num, white_num = self.number_of_black_and_white if black_num > white_num: @@ -227,9 +242,12 @@ class Reversi: self.winner = -1 else: self.winner = 0 + ''' + if self.black_win is None: + black_num, white_num = self.number_of_black_and_white + self.black_win = black_num - white_num def illegal_move_to_lose(self, action): - logger.warning(f"Illegal action={action}, No Flipped!") self._game_over() def get_own_and_enemy(self): diff --git a/AlphaGo/self-play.py b/AlphaGo/self-play.py index 4387b24..dd03b13 100644 --- a/AlphaGo/self-play.py +++ b/AlphaGo/self-play.py @@ -79,7 +79,7 @@ while True: prob.append(np.array(game.prob).reshape(-1, game.size ** 2 + 1)) print("Finished") print("\n") - score = game.game_engine.executor_get_score(game.board, True) + score = game.game_engine.executor_get_score(game.board) if score > 0: winner = utils.BLACK else: diff --git a/AlphaGo/unit_test.py b/AlphaGo/unit_test.py deleted file mode 100644 index 7a33b8e..0000000 --- a/AlphaGo/unit_test.py +++ /dev/null @@ -1,266 +0,0 @@ -import numpy as np -import sys -from game import Game -from engine import GTPEngine -import utils -import time -import copy -import network_small -import tensorflow as tf -from collections import deque -from tianshou.core.mcts.mcts import MCTS - -DELTA = [[1, 0], [-1, 0], [0, -1], [0, 1]] -CORNER_OFFSET = [[-1, -1], [-1, 1], [1, 1], [1, -1]] - -class GoEnv: - def __init__(self, size=9, komi=6.5): - self.size = size - self.komi = komi - self.board = [utils.EMPTY] * (self.size * self.size) - self.history = deque(maxlen=8) - - def _set_board(self, board): - self.board = board - - def _flatten(self, vertex): - x, y = vertex - return (x - 1) * self.size + (y - 1) - - def _bfs(self, vertex, color, block, status, alive_break): - block.append(vertex) - status[self._flatten(vertex)] = True - nei = self._neighbor(vertex) - for n in nei: - if not status[self._flatten(n)]: - if self.board[self._flatten(n)] == color: - self._bfs(n, color, block, status, alive_break) - - def _find_block(self, vertex, alive_break=False): - block = [] - status = [False] * (self.size * self.size) - color = self.board[self._flatten(vertex)] - self._bfs(vertex, color, block, status, alive_break) - - for b in block: - for n in self._neighbor(b): - if self.board[self._flatten(n)] == utils.EMPTY: - return False, block - return True, block - - def _is_qi(self, color, vertex): - nei = self._neighbor(vertex) - for n in nei: - if self.board[self._flatten(n)] == utils.EMPTY: - return True - - self.board[self._flatten(vertex)] = color - for n in nei: - if self.board[self._flatten(n)] == utils.another_color(color): - can_kill, block = self._find_block(n) - if can_kill: - self.board[self._flatten(vertex)] = utils.EMPTY - return True - - ### avoid suicide - can_kill, block = self._find_block(vertex) - if can_kill: - self.board[self._flatten(vertex)] = utils.EMPTY - return False - - self.board[self._flatten(vertex)] = utils.EMPTY - return True - - def _check_global_isomorphous(self, color, vertex): - ##backup - _board = copy.copy(self.board) - self.board[self._flatten(vertex)] = color - self._process_board(color, vertex) - if self.board in self.history: - res = True - else: - res = False - - self.board = _board - return res - - def _in_board(self, vertex): - x, y = vertex - if x < 1 or x > self.size: return False - if y < 1 or y > self.size: return False - return True - - def _neighbor(self, vertex): - x, y = vertex - nei = [] - for d in DELTA: - _x = x + d[0] - _y = y + d[1] - if self._in_board((_x, _y)): - nei.append((_x, _y)) - return nei - - def _corner(self, vertex): - x, y = vertex - corner = [] - for d in CORNER_OFFSET: - _x = x + d[0] - _y = y + d[1] - if self._in_board((_x, _y)): - corner.append((_x, _y)) - return corner - - def _process_board(self, color, vertex): - nei = self._neighbor(vertex) - for n in nei: - if self.board[self._flatten(n)] == utils.another_color(color): - can_kill, block = self._find_block(n, alive_break=True) - if can_kill: - for b in block: - self.board[self._flatten(b)] = utils.EMPTY - - def _find_group(self, start): - color = self.board[self._flatten(start)] - #print ("color : ", color) - chain = set() - frontier = [start] - while frontier: - current = frontier.pop() - #print ("current : ", current) - chain.add(current) - for n in self._neighbor(current): - #print n, self._flatten(n), self.board[self._flatten(n)], - if self.board[self._flatten(n)] == color and not n in chain: - frontier.append(n) - return chain - - def _is_eye(self, color, vertex): - nei = self._neighbor(vertex) - cor = self._corner(vertex) - ncolor = {color == self.board[self._flatten(n)] for n in nei} - if False in ncolor: - #print "not all neighbors are in same color with us" - return False - if set(nei) < self._find_group(nei[0]): - #print "all neighbors are in same group and same color with us" - return True - else: - opponent_number = [self.board[self._flatten(c)] for c in cor].count(-color) - opponent_propotion = float(opponent_number) / float(len(cor)) - if opponent_propotion < 0.5: - #print "few opponents, real eye" - return True - else: - #print "many opponents, fake eye" - return False - - # def is_valid(self, color, vertex): - def is_valid(self, state, action): - # state is the play board, the shape is [1, 9, 9, 17] - if action == self.size * self.size: - vertex = (0, 0) - else: - vertex = (action / self.size + 1, action % self.size + 1) - if state[0, 0, 0, -1] == utils.BLACK: - color = utils.BLACK - else: - color = utils.WHITE - self.history.clear() - for i in range(8): - self.history.append((state[:, :, :, i] - state[:, :, :, i + 8]).reshape(-1).tolist()) - self.board = copy.copy(self.history[-1]) - ### in board - if not self._in_board(vertex): - return False - - ### already have stone - if not self.board[self._flatten(vertex)] == utils.EMPTY: - # print(np.array(self.board).reshape(9, 9)) - # print(vertex) - return False - - ### check if it is qi - if not self._is_qi(color, vertex): - return False - - ### check if it is an eye of yourself - ### assumptions : notice that this judgement requires that the state is an endgame - #if self._is_eye(color, vertex): - # return False - - if self._check_global_isomorphous(color, vertex): - return False - - return True - - def do_move(self, color, vertex): - if vertex == utils.PASS: - return True - - id_ = self._flatten(vertex) - if self.board[id_] == utils.EMPTY: - self.board[id_] = color - self.history.append(copy.copy(self.board)) - return True - else: - return False - - def step_forward(self, state, action): - if state[0, 0, 0, -1] == 1: - color = 1 - else: - color = -1 - if action == 81: - vertex = (0, 0) - else: - vertex = (action % 9 + 1, action / 9 + 1) - # print(vertex) - # print(self.board) - self.board = (state[:, :, :, 7] - state[:, :, :, 15]).reshape(-1).tolist() - self.do_move(color, vertex) - new_state = np.concatenate( - [state[:, :, :, 1:8], (np.array(self.board) == 1).reshape(1, 9, 9, 1), - state[:, :, :, 9:16], (np.array(self.board) == -1).reshape(1, 9, 9, 1), - np.array(1 - state[:, :, :, -1]).reshape(1, 9, 9, 1)], - axis=3) - return new_state, 0 - - -pure_test = [ - 0, 1, 0, 1, 0, 1, 0, 0, 0, - 1, 0, 1, 0, 1, 0, 0, 0, 0, - 0, 1, 0, 1, 0, 0, 1, 0, 0, - 0, 0, 1, 0, 0, 1, 0, 1, 0, - 0, 0, 0, 0, 0, 1, 1, 1, 0, - 1, 1, 1, 0, 0, 0, 0, 0, 0, - 1, 0, 1, 0, 0, 1, 1, 0, 0, - 1, 1, 1, 0, 1, 0, 1, 0, 0, - 0, 0, 0, 0, 1, 1, 1, 0, 0 -] - -pt_qry = [(1, 1), (1, 5), (3, 3), (4, 7), (7, 2), (8, 6)] -pt_ans = [True, True, True, True, True, True] - -opponent_test = [ - 0, 1, 0, 1, 0, 1, 0,-1, 1, - 1,-1, 0,-1, 1,-1, 0, 1, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 1, - 1, 1,-1, 0, 1,-1, 1, 0, 0, - 1, 0, 1, 0, 1, 0, 1, 0, 0, - -1, 1, 1, 0, 1, 1, 1, 0, 0, - 0, 1,-1, 0,-1,-1,-1, 0, 0, - 1, 0, 1, 0,-1, 0,-1, 0, 0, - 0, 1, 0, 0,-1,-1,-1, 0, 0 -] -ot_qry = [(1, 1), (1, 5), (2, 9), (5, 2), (5, 6), (8, 2), (8, 6)] -ot_ans = [False, False, False, False, False, True, False] - -#print (ge._find_group((6, 1))) -#print ge._is_eye(utils.BLACK, pt_qry[0]) -ge = GoEnv() -ge._set_board(pure_test) -for i in range(6): - print (ge._is_eye(utils.BLACK, pt_qry[i])) -ge._set_board(opponent_test) -for i in range(7): - print (ge._is_eye(utils.BLACK, ot_qry[i])) diff --git a/tianshou/core/mcts/mcts.py b/tianshou/core/mcts/mcts.py index 8bb5f06..e8f3709 100644 --- a/tianshou/core/mcts/mcts.py +++ b/tianshou/core/mcts/mcts.py @@ -73,7 +73,7 @@ class UCTNode(MCTSNode): def valid_mask(self, simulator): # let all invalid actions be illeagel in mcts if self.mask is None: - self.mask = simulator.simulate_is_valid_list(self.state, range(self.action_num)) + self.mask = simulator.simulate_get_mask(self.state, range(self.action_num)) self.ucb[self.mask] = -float("Inf")