diff --git a/.gitignore b/.gitignore index d697b92..8ee6691 100644 --- a/.gitignore +++ b/.gitignore @@ -4,8 +4,8 @@ leela-zero parameters *.swp *.sublime* -checkpoints -checkpoints_origin +checkpoint *.json .DS_Store data +.log diff --git a/AlphaGo/.gitignore b/AlphaGo/.gitignore index 9c2fe16..e578e5a 100644 --- a/AlphaGo/.gitignore +++ b/AlphaGo/.gitignore @@ -1,3 +1,4 @@ data checkpoints checkpoints_origin +*.log diff --git a/AlphaGo/data_statistic.py b/AlphaGo/data_statistic.py new file mode 100644 index 0000000..6fedf1c --- /dev/null +++ b/AlphaGo/data_statistic.py @@ -0,0 +1,29 @@ +import os +import cPickle + +class Data(object): + def __init__(self): + self.boards = [] + self.probs = [] + self.winner = 0 + +def file_to_training_data(file_name): + with open(file_name, 'rb') as file: + try: + file.seek(0) + data = cPickle.load(file) + return data.winner + except Exception as e: + print(e) + return 0 + +if __name__ == "__main__": + win_count = [0, 0, 0] + file_list = os.listdir("./data") + #print file_list + for file in file_list: + win_count[file_to_training_data("./data/" + file)] += 1 + print "Total play : " + str(len(file_list)) + print "Black wins : " + str(win_count[1]) + print "White wins : " + str(win_count[-1]) + diff --git a/AlphaGo/engine.py b/AlphaGo/engine.py index 8b54470..5624a2f 100644 --- a/AlphaGo/engine.py +++ b/AlphaGo/engine.py @@ -6,6 +6,8 @@ # from game import Game +import copy +import numpy as np import utils @@ -183,10 +185,13 @@ class GTPEngine(): return 'unknown player', False def cmd_get_score(self, args, **kwargs): - return self._game.game_engine.executor_get_score(self._game.board, True), True + return self._game.game_engine.executor_get_score(self._game.board), True def cmd_show_board(self, args, **kwargs): - return self._game.board, True + board = copy.deepcopy(self._game.board) + if isinstance(board, np.ndarray): + board = board.flatten().tolist() + return board, True def cmd_get_prob(self, args, **kwargs): return 
self._game.prob, True @@ -194,4 +199,4 @@ class GTPEngine(): if __name__ == "main": game = Game() - engine = GTPEngine(game_obj=Game) + engine = GTPEngine(game_obj=game) diff --git a/AlphaGo/game.py b/AlphaGo/game.py index df08c0a..3a7959c 100644 --- a/AlphaGo/game.py +++ b/AlphaGo/game.py @@ -10,12 +10,15 @@ import copy import tensorflow as tf import numpy as np import sys, os -import go import model from collections import deque sys.path.append(os.path.join(os.path.dirname(__file__), os.path.pardir)) from tianshou.core.mcts.mcts import MCTS +import go +import reversi +import time + class Game: ''' Load the real game and trained weights. @@ -23,23 +26,38 @@ class Game: TODO : Maybe merge with the engine class in future, currently leave it untouched for interacting with Go UI. ''' - def __init__(self, size=9, komi=3.75, checkpoint_path=None): - self.size = size - self.komi = komi - self.board = [utils.EMPTY] * (self.size ** 2) - self.history = [] - self.latest_boards = deque(maxlen=8) - for _ in range(8): + def __init__(self, name="reversi", role="unknown", debug=False, checkpoint_path=None): + self.name = name + self.role = role + self.debug = debug + if self.name == "go": + self.size = 9 + self.komi = 3.75 + self.history = [] + self.history_length = 8 + self.game_engine = go.Go(size=self.size, komi=self.komi, role=self.role) + self.board = [utils.EMPTY] * (self.size ** 2) + elif self.name == "reversi": + self.size = 8 + self.history_length = 1 + self.history = [] + self.game_engine = reversi.Reversi(size=self.size) + self.board = self.game_engine.get_board() + else: + raise ValueError(name + " is an unknown game...") + + self.evaluator = model.ResNet(self.size, self.size ** 2 + 1, history_length=self.history_length) + self.latest_boards = deque(maxlen=self.history_length) + for _ in range(self.history_length): self.latest_boards.append(self.board) - self.evaluator = model.ResNet(self.size, self.size**2 + 1, history_length=8, checkpoint_path=checkpoint_path) - # 
self.evaluator = lambda state: self.sess.run([tf.nn.softmax(self.net.p), self.net.v], - # feed_dict={self.net.x: state, self.net.is_training: False}) - self.game_engine = go.Go(size=self.size, komi=self.komi) def clear(self): - self.board = [utils.EMPTY] * (self.size ** 2) - self.history = [] - for _ in range(8): + if self.name == "go": + self.board = [utils.EMPTY] * (self.size ** 2) + self.history = [] + if self.name == "reversi": + self.board = self.game_engine.get_board() + for _ in range(self.history_length): self.latest_boards.append(self.board) def set_size(self, n): @@ -50,8 +68,9 @@ class Game: self.komi = k def think(self, latest_boards, color): - mcts = MCTS(self.game_engine, self.evaluator, [latest_boards, color], self.size ** 2 + 1, inverse=True) - mcts.search(max_step=20) + mcts = MCTS(self.game_engine, self.evaluator, [latest_boards, color], + self.size ** 2 + 1, role=self.role, debug=self.debug, inverse=True) + mcts.search(max_step=100) temp = 1 prob = mcts.root.N ** temp / np.sum(mcts.root.N ** temp) choice = np.random.choice(self.size ** 2 + 1, 1, p=prob).tolist()[0] @@ -65,7 +84,11 @@ class Game: # this function can be called directly to play the opponent's move if vertex == utils.PASS: return True - res = self.game_engine.executor_do_move(self.history, self.latest_boards, self.board, color, vertex) + # TODO this implementation is not very elegant + if self.name == "go": + res = self.game_engine.executor_do_move(self.history, self.latest_boards, self.board, color, vertex) + elif self.name == "reversi": + res = self.game_engine.executor_do_move(self.history, self.latest_boards, self.board, color, vertex) return res def think_play_move(self, color): @@ -91,13 +114,14 @@ class Game: if row[i] < 10: print(' ', end='') for j in range(self.size): - print(self.status2symbol(self.board[self._flatten((j + 1, i + 1))]), end=' ') + print(self.status2symbol(self.board[self.game_engine._flatten((j + 1, i + 1))]), end=' ') print('') sys.stdout.flush() if 
__name__ == "__main__": - g = Game(checkpoint_path='./checkpoints/') - g.show_board() + g = Game("go") + print(g.board) + g.clear() g.think_play_move(1) #file = open("debug.txt", "a") #file.write("mcts check\n") diff --git a/AlphaGo/go.py b/AlphaGo/go.py index 661d918..55f5a4a 100644 --- a/AlphaGo/go.py +++ b/AlphaGo/go.py @@ -3,7 +3,7 @@ import utils import copy import numpy as np from collections import deque - +import time ''' Settings of the Go game. @@ -18,6 +18,7 @@ class Go: def __init__(self, **kwargs): self.size = kwargs['size'] self.komi = kwargs['komi'] + self.role = kwargs['role'] def _flatten(self, vertex): x, y = vertex @@ -98,7 +99,7 @@ class Go: def _check_global_isomorphous(self, history_boards, current_board, color, vertex): repeat = False - next_board = copy.copy(current_board) + next_board = copy.deepcopy(current_board) next_board[self._flatten(vertex)] = color self._process_board(next_board, color, vertex) if next_board in history_boards: @@ -157,7 +158,7 @@ class Go: vertex = self._deflatten(action) return vertex - def _is_valid(self, history_boards, current_board, color, vertex): + def _rule_check(self, history_boards, current_board, color, vertex): ### in board if not self._in_board(vertex): return False @@ -176,30 +177,30 @@ class Go: return True - def simulate_is_valid(self, state, action): + def _is_valid(self, state, action): history_boards, color = state vertex = self._action2vertex(action) current_board = history_boards[-1] - if not self._is_valid(history_boards, current_board, color, vertex): + if not self._rule_check(history_boards, current_board, color, vertex): return False if not self._knowledge_prunning(current_board, color, vertex): return False return True - def simulate_is_valid_list(self, state, action_set): + def simulate_get_mask(self, state, action_set): # find all the invalid actions - invalid_action_list = [] + invalid_action_mask = [] for action_candidate in action_set[:-1]: # go through all the actions excluding pass - 
if not self.simulate_is_valid(state, action_candidate): - invalid_action_list.append(action_candidate) - if len(invalid_action_list) < len(action_set) - 1: - invalid_action_list.append(action_set[-1]) + if not self._is_valid(state, action_candidate): + invalid_action_mask.append(action_candidate) + if len(invalid_action_mask) < len(action_set) - 1: + invalid_action_mask.append(action_set[-1]) # forbid pass, if we have other choices # TODO: In fact we should not do this. In some extreme cases, we should permit pass. - return invalid_action_list + return invalid_action_mask def _do_move(self, board, color, vertex): if vertex == utils.PASS: @@ -211,20 +212,23 @@ class Go: def simulate_step_forward(self, state, action): # initialize the simulate_board from state - history_boards, color = state - vertex = self._action2vertex(action) - new_board = self._do_move(copy.copy(history_boards[-1]), color, vertex) - history_boards.append(new_board) - new_color = -color - return [history_boards, new_color], 0 + history_boards, color = copy.deepcopy(state) + if history_boards[-1] == history_boards[-2] and action is utils.PASS: + return None, 2 * (float(self.simple_executor_get_score(history_boards[-1]) > 0)-0.5) * color + else: + vertex = self._action2vertex(action) + new_board = self._do_move(copy.deepcopy(history_boards[-1]), color, vertex) + history_boards.append(new_board) + new_color = -color + return [history_boards, new_color], 0 def executor_do_move(self, history, latest_boards, current_board, color, vertex): - if not self._is_valid(history, current_board, color, vertex): + if not self._rule_check(history, current_board, color, vertex): return False current_board[self._flatten(vertex)] = color self._process_board(current_board, color, vertex) - history.append(copy.copy(current_board)) - latest_boards.append(copy.copy(current_board)) + history.append(copy.deepcopy(current_board)) + latest_boards.append(copy.deepcopy(current_board)) return True def _find_empty(self, 
current_board): @@ -280,11 +284,8 @@ class Go: elif color_estimate < 0: return utils.WHITE - def executor_get_score(self, current_board, is_unknown_estimation=False): - ''' - is_unknown_estimation: whether use nearby stone to predict the unknown - return score from BLACK perspective. - ''' + def executor_get_score(self, current_board): + #return score from BLACK perspective. _board = copy.deepcopy(current_board) while utils.EMPTY in _board: vertex = self._find_empty(_board) @@ -294,10 +295,8 @@ class Go: _board[self._flatten(vertex)] = utils.BLACK elif boarder_color == {utils.WHITE}: _board[self._flatten(vertex)] = utils.WHITE - elif is_unknown_estimation: - _board[self._flatten(vertex)] = self._predict_from_nearby(_board, vertex) else: - _board[self._flatten(vertex)] =utils.UNKNOWN + _board[self._flatten(vertex)] = self._predict_from_nearby(_board, vertex) score = 0 for i in _board: if i == utils.BLACK: @@ -308,7 +307,46 @@ class Go: return score + + def simple_executor_get_score(self, current_board): + ''' + can only be used for the empty group only have one single stone + return score from BLACK perspective. 
+        '''
+        score = 0
+        for idx, color in enumerate(current_board):
+            if color == utils.EMPTY:
+                neighbors = self._neighbor(self._deflatten(idx))
+                color = current_board[self._flatten(neighbors[0])]
+            if color == utils.BLACK:
+                score += 1
+            elif color == utils.WHITE:
+                score -= 1
+        score -= self.komi
+        return score
+
+
 if __name__ == "__main__":
+    go = Go(size=9, komi=3.75, role = utils.BLACK)
+    endgame = [
+        1, 0, 1, 0, 1, 1, -1, 0, -1,
+        1, 1, 1, 1, 1, 1, -1, -1, -1,
+        0, 1, 1, 1, 1, -1, 0, -1, 0,
+        1, 1, 1, 1, 1, -1, -1, -1, -1,
+        1, -1, 1, -1, 1, 1, -1, -1, -1,
+        -1, -1, -1, -1, -1, 1, -1, 0, -1,
+        1, 1, 1, -1, -1, -1, -1, -1, -1,
+        1, 0, 1, 1, 1, 1, 1, -1, 0,
+        1, 1, 0, 1, -1, -1, -1, -1, -1
+    ]
+    time0 = time.time()
+    score = go.executor_get_score(endgame)
+    time1 = time.time()
+    print(score, time1 - time0)
+    score = go.simple_executor_get_score(endgame)
+    time2 = time.time()
+    print(score, time2 - time1)
+    '''
 ### do unit test for Go class
     pure_test = [
 0, 1, 0, 1, 0, 1, 0, 0, 0,
@@ -347,3 +385,4 @@ if __name__ == "__main__":
     for i in range(7):
         print (go._is_eye(opponent_test, utils.BLACK, ot_qry[i]))
     print("Test of eye surrend by opponents\n")
+    '''
diff --git a/AlphaGo/model.py b/AlphaGo/model.py
index 22e8626..0549f41 100644
--- a/AlphaGo/model.py
+++ b/AlphaGo/model.py
@@ -1,7 +1,6 @@
 import os
 import time
-import random
-import sys
+import copy
 import cPickle
 
 from collections import deque
@@ -102,7 +101,7 @@ class ResNet(object):
         self._build_network(residual_block_num, self.checkpoint_path)
 
         # training hyper-parameters:
-        self.window_length = 7000
+        self.window_length = 3
         self.save_freq = 5000
         self.training_data = {'states': deque(maxlen=self.window_length), 'probs': deque(maxlen=self.window_length),
                               'winner': deque(maxlen=self.window_length), 'length': deque(maxlen=self.window_length)}
@@ -153,6 +152,9 @@ class ResNet(object):
         :param color: a string, indicate which one to play
         :return: a list of tensor, the predicted value and policy given the history and
color """ + # Note : maybe we can use it for isolating test of MCTS + #prob = [1.0 / self.action_num] * self.action_num + #return [prob, np.random.uniform(-1, 1)] history, color = state if len(history) != self.history_length: raise ValueError( @@ -171,10 +173,10 @@ class ResNet(object): """ state = np.zeros([1, self.board_size, self.board_size, 2 * self.history_length + 1]) for i in range(self.history_length): - state[0, :, :, i] = np.array(np.array(history[i]) == np.ones(self.board_size ** 2)).reshape(self.board_size, + state[0, :, :, i] = np.array(np.array(history[i]).flatten() == np.ones(self.board_size ** 2)).reshape(self.board_size, self.board_size) state[0, :, :, i + self.history_length] = np.array( - np.array(history[i]) == -np.ones(self.board_size ** 2)).reshape(self.board_size, self.board_size) + np.array(history[i]).flatten() == -np.ones(self.board_size ** 2)).reshape(self.board_size, self.board_size) # TODO: need a config to specify the BLACK and WHITE if color == +1: state[0, :, :, 2 * self.history_length] = np.ones([self.board_size, self.board_size]) @@ -224,11 +226,19 @@ class ResNet(object): else: start_time = time.time() for i in range(batch_size): - game_num = random.randint(0, self.window_length-1) - state_num = random.randint(0, self.training_data['length'][game_num]-1) - training_data['states'].append(np.expand_dims(self.training_data['states'][game_num][state_num], 0)) - training_data['probs'].append(np.expand_dims(self.training_data['probs'][game_num][state_num], 0)) - training_data['winner'].append(np.expand_dims(self.training_data['winner'][game_num][state_num], 0)) + priority = np.array(self.training_data['length']) / (0.0 + np.sum(np.array(self.training_data['length']))) + game_num = np.random.choice(self.window_length, 1, p=priority)[0] + state_num = np.random.randint(self.training_data['length'][game_num]) + rotate_times = np.random.randint(4) + reflect_times = np.random.randint(2) + reflect_orientation = np.random.randint(2) + 
training_data['states'].append( + self._preprocession(self.training_data['states'][game_num][state_num], reflect_times, + reflect_orientation, rotate_times)) + training_data['probs'].append(np.concatenate( + [self._preprocession(self.training_data['probs'][game_num][state_num][:-1].reshape(self.board_size, self.board_size, 1), reflect_times, + reflect_orientation, rotate_times).reshape(1, self.board_size**2), self.training_data['probs'][game_num][state_num][-1].reshape(1,1)], axis=1)) + training_data['winner'].append(self.training_data['winner'][game_num][state_num].reshape(1, 1)) value_loss, policy_loss, reg, _ = self.sess.run( [self.value_loss, self.policy_loss, self.reg, self.train_op], feed_dict={self.x: np.concatenate(training_data['states'], axis=0), @@ -280,6 +290,55 @@ class ResNet(object): winner = np.concatenate(winner, axis=0) return states, probs, winner + def _preprocession(self, board, reflect_times=0, reflect_orientation=0, rotate_times=0): + """ + preprocessing for augmentation + + :param board: a ndarray, board to process + :param reflect_times: an integer, how many times to reflect + :param reflect_orientation: an integer, which orientation to reflect + :param rotate_times: an integer, how many times to rotate + :return: + """ + + new_board = copy.deepcopy(board) + if new_board.ndim == 3: + new_board = np.expand_dims(new_board, axis=0) + + new_board = self._board_reflection(new_board, reflect_times, reflect_orientation) + new_board = self._board_rotation(new_board, rotate_times) + + return new_board + + def _board_rotation(self, board, times): + """ + rotate the board for augmentation + note that board's shape should be [batch_size, board_size, board_size, channels] + + :param board: a ndarray, shape [batch_size, board_size, board_size, channels] + :param times: an integer, how many times to rotate + :return: + """ + return np.rot90(board, times, (1, 2)) + + def _board_reflection(self, board, times, orientation): + """ + reflect the board for 
augmentation
+        note that board's shape should be [batch_size, board_size, board_size, channels]
+
+        :param board: a ndarray, shape [batch_size, board_size, board_size, channels]
+        :param times: an integer, how many times to reflect
+        :param orientation: an integer, which orientation to reflect
+        :return:
+        """
+        new_board = copy.deepcopy(board)
+        for _ in range(times):
+            if orientation == 0:
+                new_board = new_board[:, ::-1]
+            if orientation == 1:
+                new_board = new_board[:, :, ::-1]
+        return new_board
+
 if __name__ == "__main__":
     model = ResNet(board_size=9, action_num=82, history_length=8)
diff --git a/AlphaGo/play.py b/AlphaGo/play.py
index 3681430..2731948 100644
--- a/AlphaGo/play.py
+++ b/AlphaGo/play.py
@@ -7,7 +7,6 @@ import time
 import os
 import cPickle
 
-
 class Data(object):
     def __init__(self):
         self.boards = []
@@ -29,6 +28,7 @@ if __name__ == '__main__':
     parser.add_argument("--black_weight_path", type=str, default=None)
     parser.add_argument("--white_weight_path", type=str, default=None)
     parser.add_argument("--id", type=int, default=0)
+    parser.add_argument("--debug", type=bool, default=False)
     args = parser.parse_args()
 
     if not os.path.exists(args.result_path):
@@ -61,11 +61,13 @@ if __name__ == '__main__':
     white_role_name = 'white' + str(args.id)
 
     agent_v0 = subprocess.Popen(
-        ['python', '-u', 'player.py', '--role=' + black_role_name, '--checkpoint_path=' + str(args.black_weight_path)],
+        ['python', '-u', 'player.py', '--role=' + black_role_name,
+         '--checkpoint_path=' + str(args.black_weight_path), '--debug=' + str(args.debug)],
         stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
 
     agent_v1 = subprocess.Popen(
-        ['python', '-u', 'player.py', '--role=' + white_role_name, '--checkpoint_path=' + str(args.white_weight_path)],
+        ['python', '-u', 'player.py', '--role=' + white_role_name,
+         '--checkpoint_path=' + str(args.white_weight_path), '--debug=' + str(args.debug)],
         stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
server_list = "" @@ -87,27 +89,29 @@ if __name__ == '__main__': pattern = "[A-Z]{1}[0-9]{1}" space = re.compile("\s+") - size = 9 + size = {"go":9, "reversi":8} show = ['.', 'X', 'O'] evaluate_rounds = 1 game_num = 0 try: - while True: + #while True: + while game_num < evaluate_rounds: start_time = time.time() num = 0 pass_flag = [False, False] print("Start game {}".format(game_num)) # end the game if both palyer chose to pass, or play too much turns - while not (pass_flag[0] and pass_flag[1]) and num < size ** 2 * 2: + while not (pass_flag[0] and pass_flag[1]) and num < size["reversi"] ** 2 * 2: turn = num % 2 board = player[turn].run_cmd(str(num) + ' show_board') board = eval(board[board.index('['):board.index(']') + 1]) - for i in range(size): - for j in range(size): - print show[board[i * size + j]] + " ", + for i in range(size["reversi"]): + for j in range(size["reversi"]): + print show[board[i * size["reversi"] + j]] + " ", print "\n", data.boards.append(board) + start_time = time.time() move = player[turn].run_cmd(str(num) + ' genmove ' + color[turn] + '\n') print role[turn] + " : " + str(move), num += 1 diff --git a/AlphaGo/player.py b/AlphaGo/player.py index 0e3daff..66a487f 100644 --- a/AlphaGo/player.py +++ b/AlphaGo/player.py @@ -25,16 +25,20 @@ if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument("--checkpoint_path", type=str, default=None) parser.add_argument("--role", type=str, default="unknown") + parser.add_argument("--debug", type=str, default=False) args = parser.parse_args() if args.checkpoint_path == 'None': args.checkpoint_path = None - game = Game(checkpoint_path=args.checkpoint_path) + debug = False + if args.debug == "True": + debug = True + game = Game(role=args.role, checkpoint_path=args.checkpoint_path, debug=debug) engine = GTPEngine(game_obj=game, name='tianshou', version=0) daemon = Pyro4.Daemon() # make a Pyro daemon ns = Pyro4.locateNS() # find the name server - player = Player(role = args.role, engine = 
engine) + player = Player(role=args.role, engine=engine) print "Init " + args.role + " player finished" uri = daemon.register(player) # register the greeting maker as a Pyro object print "Start on name " + args.role diff --git a/AlphaGo/reversi.py b/AlphaGo/reversi.py index cba91d9..c6c8a5b 100644 --- a/AlphaGo/reversi.py +++ b/AlphaGo/reversi.py @@ -1,264 +1,194 @@ -from __future__ import print_function -import numpy as np - -''' -Settings of the Go game. - -(1, 1) is considered as the upper left corner of the board, -(size, 1) is the lower left -''' - - -def find_correct_moves(own, enemy): - """return legal moves""" - left_right_mask = 0x7e7e7e7e7e7e7e7e # Both most left-right edge are 0, else 1 - top_bottom_mask = 0x00ffffffffffff00 # Both most top-bottom edge are 0, else 1 - mask = left_right_mask & top_bottom_mask - mobility = 0 - mobility |= search_offset_left(own, enemy, left_right_mask, 1) # Left - mobility |= search_offset_left(own, enemy, mask, 9) # Left Top - mobility |= search_offset_left(own, enemy, top_bottom_mask, 8) # Top - mobility |= search_offset_left(own, enemy, mask, 7) # Top Right - mobility |= search_offset_right(own, enemy, left_right_mask, 1) # Right - mobility |= search_offset_right(own, enemy, mask, 9) # Bottom Right - mobility |= search_offset_right(own, enemy, top_bottom_mask, 8) # Bottom - mobility |= search_offset_right(own, enemy, mask, 7) # Left bottom - return mobility - - -def calc_flip(pos, own, enemy): - """return flip stones of enemy by bitboard when I place stone at pos. - - :param pos: 0~63 - :param own: bitboard (0=top left, 63=bottom right) - :param enemy: bitboard - :return: flip stones of enemy when I place stone at pos. 
- """ - f1 = _calc_flip_half(pos, own, enemy) - f2 = _calc_flip_half(63 - pos, rotate180(own), rotate180(enemy)) - return f1 | rotate180(f2) - - -def _calc_flip_half(pos, own, enemy): - el = [enemy, enemy & 0x7e7e7e7e7e7e7e7e, enemy & 0x7e7e7e7e7e7e7e7e, enemy & 0x7e7e7e7e7e7e7e7e] - masks = [0x0101010101010100, 0x00000000000000fe, 0x0002040810204080, 0x8040201008040200] - masks = [b64(m << pos) for m in masks] - flipped = 0 - for e, mask in zip(el, masks): - outflank = mask & ((e | ~mask) + 1) & own - flipped |= (outflank - (outflank != 0)) & mask - return flipped - - -def search_offset_left(own, enemy, mask, offset): - e = enemy & mask - blank = ~(own | enemy) - t = e & (own >> offset) - t |= e & (t >> offset) - t |= e & (t >> offset) - t |= e & (t >> offset) - t |= e & (t >> offset) - t |= e & (t >> offset) # Up to six stones can be turned at once - return blank & (t >> offset) # Only the blank squares can be started - - -def search_offset_right(own, enemy, mask, offset): - e = enemy & mask - blank = ~(own | enemy) - t = e & (own << offset) - t |= e & (t << offset) - t |= e & (t << offset) - t |= e & (t << offset) - t |= e & (t << offset) - t |= e & (t << offset) # Up to six stones can be turned at once - return blank & (t << offset) # Only the blank squares can be started - - -def flip_vertical(x): - k1 = 0x00FF00FF00FF00FF - k2 = 0x0000FFFF0000FFFF - x = ((x >> 8) & k1) | ((x & k1) << 8) - x = ((x >> 16) & k2) | ((x & k2) << 16) - x = (x >> 32) | b64(x << 32) - return x - - -def b64(x): - return x & 0xFFFFFFFFFFFFFFFF - - -def bit_count(x): - return bin(x).count('1') - - -def bit_to_array(x, size): - """bit_to_array(0b0010, 4) -> array([0, 1, 0, 0])""" - return np.array(list(reversed((("0" * size) + bin(x)[2:])[-size:])), dtype=np.uint8) - - -def flip_diag_a1h8(x): - k1 = 0x5500550055005500 - k2 = 0x3333000033330000 - k4 = 0x0f0f0f0f00000000 - t = k4 & (x ^ b64(x << 28)) - x ^= t ^ (t >> 28) - t = k2 & (x ^ b64(x << 14)) - x ^= t ^ (t >> 14) - t = k1 & (x ^ 
b64(x << 7)) - x ^= t ^ (t >> 7) - return x - - -def rotate90(x): - return flip_diag_a1h8(flip_vertical(x)) - - -def rotate180(x): - return rotate90(rotate90(x)) - - -class Reversi: - def __init__(self, black=None, white=None): - self.black = black or (0b00001000 << 24 | 0b00010000 << 32) - self.white = white or (0b00010000 << 24 | 0b00001000 << 32) - self.board = None # 8 * 8 board with 1 for black, -1 for white and 0 for blank - self.color = None # 1 for black and -1 for white - self.action = None # number in 0~63 - # self.winner = None - self.black_win = None - - def get_board(self, black=None, white=None): - self.black = black or (0b00001000 << 24 | 0b00010000 << 32) - self.white = white or (0b00010000 << 24 | 0b00001000 << 32) - self.board = self.bitboard2board() - return self.board - - def simulate_is_valid(self, board, color): - self.board = board - self.color = color - self.board2bitboard() - own, enemy = self.get_own_and_enemy() - mobility = find_correct_moves(own, enemy) - valid_moves = bit_to_array(mobility, 64) - valid_moves = np.argwhere(valid_moves) - valid_moves = list(np.reshape(valid_moves, len(valid_moves))) - return valid_moves - - def simulate_step_forward(self, state, vertex): - self.board = state[0] - self.color = state[1] - self.board2bitboard() - self.vertex2action(vertex) - step_forward = self.step() - if step_forward: - new_board = self.bitboard2board() - return [new_board, 0 - self.color], 0 - - def executor_do_move(self, board, color, vertex): - self.board = board - self.color = color - self.board2bitboard() - self.vertex2action(vertex) - step_forward = self.step() - if step_forward: - new_board = self.bitboard2board() - for i in range(64): - board[i] = new_board[i] - - def executor_get_score(self, board): - self.board = board - self._game_over() - if self.black_win is not None: - return self.black_win - else: - ValueError("Game not finished!") - - def board2bitboard(self): - count = 1 - if self.board is None: - ValueError("None board!") 
- self.black = 0 - self.white = 0 - for i in range(64): - if self.board[i] == 1: - self.black |= count - elif self.board[i] == -1: - self.white |= count - count *= 2 - - def vertex2action(self, vertex): - x, y = vertex - if x == 0 and y == 0: - self.action = None - else: - self.action = 8 * (x - 1) + y - 1 - - def bitboard2board(self): - board = [] - black = bit_to_array(self.black, 64) - white = bit_to_array(self.white, 64) - for i in range(64): - if black[i]: - board.append(1) - elif white[i]: - board.append(-1) - else: - board.append(0) - return board - - def step(self): - if self.action < 0 or self.action > 63: - ValueError("Wrong action!") - if self.action is None: - return False - - own, enemy = self.get_own_and_enemy() - - flipped = calc_flip(self.action, own, enemy) - if bit_count(flipped) == 0: - self.illegal_move_to_lose(self.action) - return False - own ^= flipped - own |= 1 << self.action - enemy ^= flipped - - self.set_own_and_enemy(own, enemy) - return True - - def _game_over(self): - # self.done = True - ''' - if self.winner is None: - black_num, white_num = self.number_of_black_and_white - if black_num > white_num: - self.winner = 1 - elif black_num < white_num: - self.winner = -1 - else: - self.winner = 0 - ''' - if self.black_win is None: - black_num, white_num = self.number_of_black_and_white - self.black_win = black_num - white_num - - def illegal_move_to_lose(self, action): - self._game_over() - - def get_own_and_enemy(self): - if self.color == 1: - own, enemy = self.black, self.white - elif self.color == -1: - own, enemy = self.white, self.black - else: - own, enemy = None, None - return own, enemy - - def set_own_and_enemy(self, own, enemy): - if self.color == 1: - self.black, self.white = own, enemy - else: - self.white, self.black = own, enemy - - @property - def number_of_black_and_white(self): - return bit_count(self.black), bit_count(self.white) +import numpy as np +import copy +''' +Settings of the Reversi game. 
+ +(1, 1) is considered as the upper left corner of the board, +(size, 1) is the lower left +''' + + +class Reversi: + def __init__(self, **kwargs): + self.size = kwargs['size'] + + def _deflatten(self, idx): + x = idx // self.size + 1 + y = idx % self.size + 1 + return (x, y) + + def _flatten(self, vertex): + x, y = vertex + if (x == 0) and (y == 0): + return self.size ** 2 + return (x - 1) * self.size + (y - 1) + + def get_board(self): + board = np.zeros([self.size, self.size], dtype=np.int32) + board[self.size / 2 - 1, self.size / 2 - 1] = -1 + board[self.size / 2, self.size / 2] = -1 + board[self.size / 2 - 1, self.size / 2] = 1 + board[self.size / 2, self.size / 2 - 1] = 1 + return board + + def _find_correct_moves(self, board, color, is_next=False): + moves = [] + if is_next: + new_color = 0 - color + else: + new_color = color + for i in range(self.size ** 2): + x, y = self._deflatten(i) + valid = self._is_valid(board, x - 1, y - 1, new_color) + if valid: + moves.append(i) + return moves + + def _one_direction_valid(self, board, x, y, color): + if (x >= 0) and (x < self.size): + if (y >= 0) and (y < self.size): + if board[x, y] == color: + return True + return False + + def _is_valid(self, board, x, y, color): + if board[x, y]: + return False + for x_direction in [-1, 0, 1]: + for y_direction in [-1, 0, 1]: + new_x = x + new_y = y + flag = 0 + while True: + new_x += x_direction + new_y += y_direction + if self._one_direction_valid(board, new_x, new_y, 0 - color): + flag = 1 + else: + break + if self._one_direction_valid(board, new_x, new_y, color) and flag: + return True + return False + + def simulate_get_mask(self, state, action_set): + history_boards, color = copy.deepcopy(state) + board = copy.deepcopy(history_boards[-1]) + valid_moves = self._find_correct_moves(board, color) + if not len(valid_moves): + invalid_action_mask = action_set[0:-1] + else: + invalid_action_mask = [] + for action in action_set: + if action not in valid_moves: + 
invalid_action_mask.append(action) + return invalid_action_mask + + def simulate_step_forward(self, state, action): + history_boards, color = copy.deepcopy(state) + board = copy.deepcopy(history_boards[-1]) + if action == self.size ** 2: + valid_moves = self._find_correct_moves(board, color, is_next=True) + if not len(valid_moves): + winner = self._get_winner(board) + return None, winner * color + else: + return [history_boards, 0 - color], 0 + new_board = self._step(board, color, action) + history_boards.append(new_board) + return [history_boards, 0 - color], 0 + + def _get_winner(self, board): + black_num, white_num = self._number_of_black_and_white(board) + black_win = black_num - white_num + if black_win > 0: + winner = 1 + elif black_win < 0: + winner = -1 + else: + winner = 0 + return winner + + def _number_of_black_and_white(self, board): + black_num = 0 + white_num = 0 + board_list = np.reshape(board, self.size ** 2) + for i in range(len(board_list)): + if board_list[i] == 1: + black_num += 1 + elif board_list[i] == -1: + white_num += 1 + return black_num, white_num + + def _step(self, board, color, action): + if action < 0 or action > self.size ** 2 - 1: + raise ValueError("Action not in the range of [0,63]!") + if action is None: + raise ValueError("Action is None!") + x, y = self._deflatten(action) + new_board = self._flip(board, x - 1, y - 1, color) + return new_board + + def _flip(self, board, x, y, color): + valid = 0 + board[x, y] = color + for x_direction in [-1, 0, 1]: + for y_direction in [-1, 0, 1]: + new_x = x + new_y = y + flag = 0 + while True: + new_x += x_direction + new_y += y_direction + if self._one_direction_valid(board, new_x, new_y, 0 - color): + flag = 1 + else: + break + if self._one_direction_valid(board, new_x, new_y, color) and flag: + valid = 1 + flip_x = x + flip_y = y + while True: + flip_x += x_direction + flip_y += y_direction + if self._one_direction_valid(board, flip_x, flip_y, 0 - color): + board[flip_x, flip_y] = color + 
else: + break + if valid: + return board + else: + raise ValueError("Invalid action") + + def executor_do_move(self, history, latest_boards, board, color, vertex): + board = np.reshape(board, (self.size, self.size)) + color = color + action = self._flatten(vertex) + if action == self.size ** 2: + valid_moves = self._find_correct_moves(board, color, is_next=True) + if not len(valid_moves): + return False + else: + return True + else: + new_board = self._step(board, color, action) + history.append(new_board) + latest_boards.append(new_board) + return True + + def executor_get_score(self, board): + board = board + winner = self._get_winner(board) + return winner + + +if __name__ == "__main__": + reversi = Reversi() + # board = reversi.get_board() + # print(board) + # state, value = reversi.simulate_step_forward([board, -1], 20) + # print(state[0]) + # print("board") + # print(board) + # r = reversi.executor_get_score(board) + # print(r) + diff --git a/AlphaGo/self-play.py b/AlphaGo/self-play.py index 4387b24..dd03b13 100644 --- a/AlphaGo/self-play.py +++ b/AlphaGo/self-play.py @@ -79,7 +79,7 @@ while True: prob.append(np.array(game.prob).reshape(-1, game.size ** 2 + 1)) print("Finished") print("\n") - score = game.game_engine.executor_get_score(game.board, True) + score = game.game_engine.executor_get_score(game.board) if score > 0: winner = utils.BLACK else: diff --git a/README.md b/README.md index 9c3af16..fc7d494 100644 --- a/README.md +++ b/README.md @@ -41,6 +41,11 @@ Tianshou(天授) is a reinforcement learning platform. The following image illus +## examples + +During development, run examples under `./examples/` directory with, e.g. `python ppo_example.py`. +Running them under this directory with `python examples/ppo_example.py` will not work. 
+ ## About coding style diff --git a/examples/dqn_example.py b/examples/dqn_example.py index b676475..cf20d66 100644 --- a/examples/dqn_example.py +++ b/examples/dqn_example.py @@ -1,8 +1,6 @@ #!/usr/bin/env python import tensorflow as tf -import numpy as np -import time import gym # our lib imports here! @@ -10,7 +8,7 @@ import sys sys.path.append('..') import tianshou.core.losses as losses from tianshou.data.replay_buffer.utils import get_replay_buffer -import tianshou.core.policy as policy +import tianshou.core.policy.dqn as policy def policy_net(observation, action_dim): @@ -41,6 +39,8 @@ if __name__ == '__main__': # pass the observation variable to the replay buffer or find a more reasonable way to help replay buffer # access this observation variable. observation = tf.placeholder(tf.float32, shape=(None,) + observation_dim, name="dqn_observation") # network input + action = tf.placeholder(dtype=tf.int32, shape=(None,)) # batch of integer actions + with tf.variable_scope('q_net'): q_values = policy_net(observation, action_dim) @@ -48,10 +48,9 @@ if __name__ == '__main__': q_values_target = policy_net(observation, action_dim) # 2. 
build losses, optimizers - q_net = policy.DQN(q_values, observation_placeholder=observation) # YongRen: policy.DQN - target_net = policy.DQN(q_values_target, observation_placeholder=observation) + q_net = policy.DQNRefactor(q_values, observation_placeholder=observation, action_placeholder=action) # YongRen: policy.DQN + target_net = policy.DQNRefactor(q_values_target, observation_placeholder=observation, action_placeholder=action) - action = tf.placeholder(dtype=tf.int32, shape=[None]) # batch of integer actions target = tf.placeholder(dtype=tf.float32, shape=[None]) # target value for DQN dqn_loss = losses.dqn_loss(action, target, q_net) # TongzhengRen diff --git a/examples/ppo_example.py b/examples/ppo_example.py index 02ccb52..985c8f2 100755 --- a/examples/ppo_example.py +++ b/examples/ppo_example.py @@ -1,17 +1,16 @@ #!/usr/bin/env python +from __future__ import absolute_import import tensorflow as tf -import numpy as np -import time import gym # our lib imports here! import sys sys.path.append('..') -import tianshou.core.losses as losses +from tianshou.core import losses from tianshou.data.batch import Batch import tianshou.data.advantage_estimation as advantage_estimation -import tianshou.core.policy as policy +import tianshou.core.policy.stochastic as policy # TODO: fix imports as zhusuan so that only need to import to policy def policy_net(observation, action_dim, scope=None): diff --git a/tianshou/core/README.md b/tianshou/core/README.md index 3617525..a9cda58 100644 --- a/tianshou/core/README.md +++ b/tianshou/core/README.md @@ -21,4 +21,8 @@ referencing QValuePolicy in base.py, should have at least the listed methods. TongzhengRen -seems to be direct python functions. Though the management of placeholders may require some discussion. also may write it in a functional form. \ No newline at end of file +seems to be direct python functions. Though the management of placeholders may require some discussion. also may write it in a functional form. 
+ +# policy, value_function + +naming should be reconsidered. Perhaps use plural forms for all nouns \ No newline at end of file diff --git a/tianshou/core/losses.py b/tianshou/core/losses.py index 3461afb..5d5d2f3 100644 --- a/tianshou/core/losses.py +++ b/tianshou/core/losses.py @@ -35,17 +35,16 @@ def vanilla_policy_gradient(sampled_action, reward, pi, baseline="None"): # TODO: Different baseline methods like REINFORCE, etc. return vanilla_policy_gradient_loss -def dqn_loss(sampled_action, sampled_target, q_net): +def dqn_loss(sampled_action, sampled_target, policy): """ deep q-network :param sampled_action: placeholder of sampled actions during the interaction with the environment :param sampled_target: estimated Q(s,a) - :param q_net: current `policy` to be optimized + :param policy: current `policy` to be optimized :return: """ - action_num = q_net.values_tensor().get_shape()[1] - sampled_q = tf.reduce_sum(q_net.values_tensor() * tf.one_hot(sampled_action, action_num), axis=1) + sampled_q = policy.q_net.value_tensor return tf.reduce_mean(tf.square(sampled_target - sampled_q)) def deterministic_policy_gradient(sampled_state, critic): diff --git a/tianshou/core/mcts/mcts.py b/tianshou/core/mcts/mcts.py index 16890d7..1ba1145 100644 --- a/tianshou/core/mcts/mcts.py +++ b/tianshou/core/mcts/mcts.py @@ -25,8 +25,9 @@ class MCTSNode(object): def valid_mask(self, simulator): pass + class UCTNode(MCTSNode): - def __init__(self, parent, action, state, action_num, prior, inverse=False, c_puct = 5): + def __init__(self, parent, action, state, action_num, prior, debug=False, inverse=False, c_puct = 5): super(UCTNode, self).__init__(parent, action, state, action_num, prior, inverse) self.Q = np.zeros([action_num]) self.W = np.zeros([action_num]) @@ -34,9 +35,16 @@ class UCTNode(MCTSNode): self.c_puct = c_puct self.ucb = self.Q + self.c_puct * self.prior * math.sqrt(np.sum(self.N)) / (self.N + 1) self.mask = None + self.debug=debug + self.elapse_time = 0 + + def 
clear_elapse_time(self): + self.elapse_time = 0 def selection(self, simulator): + head = time.time() self.valid_mask(simulator) + self.elapse_time += time.time() - head action = np.argmax(self.ucb) if action in self.children.keys(): return self.children[action].selection(simulator) @@ -59,10 +67,13 @@ class UCTNode(MCTSNode): self.parent.backpropagation(self.children[action].reward) def valid_mask(self, simulator): - # let all invalid actions be illeagel in mcts - if self.mask is None: - self.mask = simulator.simulate_is_valid_list(self.state, range(self.action_num)) - self.ucb[self.mask] = -float("Inf") + # let all invalid actions be illegal in mcts + if not hasattr(simulator, 'simulate_get_mask'): + pass + else: + if self.mask is None: + self.mask = simulator.simulate_get_mask(self.state, range(self.action_num)) + self.ucb[self.mask] = -float("Inf") class TSNode(MCTSNode): @@ -87,15 +98,15 @@ class ActionNode(object): self.reward = 0 def type_conversion_to_tuple(self): - if type(self.next_state) is np.ndarray: + if isinstance(self.next_state, np.ndarray): self.next_state = self.next_state.tolist() - if type(self.next_state) is list: + if isinstance(self.next_state, list): self.next_state = list2tuple(self.next_state) def type_conversion_to_origin(self): - if self.state_type is np.ndarray: + if isinstance(self.state_type, np.ndarray): self.next_state = np.array(self.next_state) - if self.state_type is list: + if isinstance(self.state_type, np.ndarray): self.next_state = tuple2list(self.next_state) def selection(self, simulator): @@ -126,15 +137,18 @@ class ActionNode(object): class MCTS(object): - def __init__(self, simulator, evaluator, root, action_num, method="UCT", inverse=False): + def __init__(self, simulator, evaluator, root, action_num, method="UCT", + role="unknown", debug=False, inverse=False): self.simulator = simulator self.evaluator = evaluator + self.role = role + self.debug = debug prior, _ = self.evaluator(root) self.action_num = action_num if 
method == "": self.root = root if method == "UCT": - self.root = UCTNode(None, None, root, action_num, prior, inverse=inverse) + self.root = UCTNode(None, None, root, action_num, prior, self.debug, inverse=inverse) if method == "TS": self.root = TSNode(None, None, root, action_num, prior, inverse=inverse) self.inverse = inverse @@ -149,14 +163,36 @@ class MCTS(object): if max_step is None and max_time is None: raise ValueError("Need a stop criteria!") + selection_time = 0 + expansion_time = 0 + backprop_time = 0 + self.root.clear_elapse_time() while step < max_step and time.time() - start_time < max_step: - self._expand() + sel_time, exp_time, back_time = self._expand() + selection_time += sel_time + expansion_time += exp_time + backprop_time += back_time step += 1 + if (self.debug): + file = open("debug.txt", "a") + file.write("[" + str(self.role) + "]" + + " selection : " + str(selection_time) + "\t" + + " validmask : " + str(self.root.elapse_time) + "\t" + + " expansion : " + str(expansion_time) + "\t" + + " backprop : " + str(backprop_time) + "\t" + + "\n") + file.close() def _expand(self): + t0 = time.time() node, new_action = self.root.selection(self.simulator) + t1 = time.time() value = node.children[new_action].expansion(self.evaluator, self.action_num) + t2 = time.time() node.children[new_action].backpropagation(value + 0.) 
+ t3 = time.time() + return t1 - t0, t2 - t1, t3 - t2 + if __name__ == "__main__": pass diff --git a/tianshou/core/policy/__init__.py b/tianshou/core/policy/__init__.py index ccde775..e69de29 100644 --- a/tianshou/core/policy/__init__.py +++ b/tianshou/core/policy/__init__.py @@ -1,6 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -from .base import * -from .stochastic import * -from .dqn import * \ No newline at end of file diff --git a/tianshou/core/policy/base.py b/tianshou/core/policy/base.py index 025abd5..1c1e1c5 100644 --- a/tianshou/core/policy/base.py +++ b/tianshou/core/policy/base.py @@ -3,21 +3,26 @@ from __future__ import absolute_import from __future__ import division -import warnings import tensorflow as tf # from zhusuan.utils import add_name_scope -__all__ = [ - 'StochasticPolicy', - 'QValuePolicy', -] - # TODO: a even more "base" class for policy +class PolicyBase(object): + """ + base class for policy. only provides `act` method with exploration + """ + def __init__(self, observation_placeholder): + self._observation_placeholder = observation_placeholder + + def act(self, observation, exploration): + raise NotImplementedError() + + class QValuePolicy(object): """ The policy as in DQN @@ -25,14 +30,14 @@ class QValuePolicy(object): def __init__(self, observation_placeholder): self._observation_placeholder = observation_placeholder - def act(self, observation, exploration=None): # first implement no exploration + def act(self, observation, exploration=None): # first implement no exploration """ return the action (int) to be executed. no exploration when exploration=None. 
""" self._act(observation, exploration) - def _act(self, observation, exploration = None): + def _act(self, observation, exploration=None): raise NotImplementedError() def values(self, observation): @@ -48,7 +53,6 @@ class QValuePolicy(object): pass - class StochasticPolicy(object): """ The :class:`Distribution` class is the base class for various probabilistic @@ -118,7 +122,7 @@ class StochasticPolicy(object): param_dtype, is_continuous, observation_placeholder, - group_ndims=0, # maybe useful for repeat_action + group_ndims=0, # maybe useful for repeat_action **kwargs): self._act_dtype = act_dtype diff --git a/tianshou/core/policy/dqn.py b/tianshou/core/policy/dqn.py index d03dbd4..8533549 100644 --- a/tianshou/core/policy/dqn.py +++ b/tianshou/core/policy/dqn.py @@ -1,19 +1,34 @@ -from tianshou.core.policy.base import QValuePolicy +from __future__ import absolute_import + +from .base import PolicyBase import tensorflow as tf -import sys -sys.path.append('..') -import value_function.action_value as value_func +from ..value_function.action_value import DQN -class DQN_refactor(object): +class DQNRefactor(PolicyBase): """ use DQN from value_function as a member """ def __init__(self, value_tensor, observation_placeholder, action_placeholder): - self._network = value_func.DQN(value_tensor, observation_placeholder, action_placeholder) + self._q_net = DQN(value_tensor, observation_placeholder, action_placeholder) + self._argmax_action = tf.argmax(value_tensor, axis=1) + + super(DQNRefactor, self).__init__(observation_placeholder=observation_placeholder) + + def act(self, observation, exploration=None): + sess = tf.get_default_session() + if not exploration: # no exploration + action = sess.run(self._argmax_action, feed_dict={self._observation_placeholder: observation}) -class DQN(QValuePolicy): + return action + + @property + def q_net(self): + return self._q_net + + +class DQNOld(QValuePolicy): """ The policy as in DQN """ diff --git 
a/tianshou/core/policy/stochastic.py b/tianshou/core/policy/stochastic.py index 3ef463e..d7a75d7 100644 --- a/tianshou/core/policy/stochastic.py +++ b/tianshou/core/policy/stochastic.py @@ -10,12 +10,6 @@ import tensorflow as tf from .base import StochasticPolicy -__all__ = [ - 'OnehotCategorical', - 'OnehotDiscrete', -] - - class OnehotCategorical(StochasticPolicy): """ The class of one-hot Categorical distribution. diff --git a/tianshou/core/value_function/action_value.py b/tianshou/core/value_function/action_value.py index cb8acc8..c62dae6 100644 --- a/tianshou/core/value_function/action_value.py +++ b/tianshou/core/value_function/action_value.py @@ -1,4 +1,6 @@ -from base import ValueFunctionBase +from __future__ import absolute_import + +from .base import ValueFunctionBase import tensorflow as tf @@ -13,9 +15,8 @@ class ActionValue(ValueFunctionBase): observation_placeholder=observation_placeholder ) - def get_value(self, observation, action): + def eval_value(self, observation, action): """ - :param observation: numpy array of observations, of shape (batchsize, observation_dim). :param action: numpy array of actions, of shape (batchsize, action_dim) # TODO: Atari discrete action should have dim 1. Super Mario may should have, say, dim 5, where each can be 0/1 @@ -23,8 +24,8 @@ class ActionValue(ValueFunctionBase): # TODO: dealing with the last dim of 1 in V(s) and Q(s, a) """ sess = tf.get_default_session() - return sess.run(self.get_value_tensor(), feed_dict= - {self._observation_placeholder: observation, self._action_placeholder:action})[:, 0] + return sess.run(self.value_tensor, feed_dict= + {self._observation_placeholder: observation, self._action_placeholder: action}) class DQN(ActionValue): @@ -39,15 +40,24 @@ class DQN(ActionValue): :param action_placeholder: of shape (batchsize, ) """ self._value_tensor_all_actions = value_tensor - canonical_value_tensor = value_tensor[action_placeholder] # maybe a tf.map_fn. 
for now it's wrong + + batch_size = tf.shape(value_tensor)[0] + batch_dim_index = tf.range(batch_size) + indices = tf.stack([batch_dim_index, action_placeholder], axis=1) + canonical_value_tensor = tf.gather_nd(value_tensor, indices) super(DQN, self).__init__(value_tensor=canonical_value_tensor, observation_placeholder=observation_placeholder, action_placeholder=action_placeholder) - def get_value_all_actions(self, observation): + def eval_value_all_actions(self, observation): + """ + :param observation: + :return: numpy array of Q(s, *) given s, of shape (batchsize, num_actions) + """ sess = tf.get_default_session() return sess.run(self._value_tensor_all_actions, feed_dict={self._observation_placeholder: observation}) - def get_value_tensor_all_actions(self): + @property + def value_tensor_all_actions(self): return self._value_tensor_all_actions \ No newline at end of file diff --git a/tianshou/core/value_function/base.py b/tianshou/core/value_function/base.py index 0b27759..8ca9dd0 100644 --- a/tianshou/core/value_function/base.py +++ b/tianshou/core/value_function/base.py @@ -1,3 +1,6 @@ +from __future__ import absolute_import + +import tensorflow as tf # TODO: linear feature baseline also in tf? 
class ValueFunctionBase(object): @@ -6,16 +9,17 @@ class ValueFunctionBase(object): """ def __init__(self, value_tensor, observation_placeholder): self._observation_placeholder = observation_placeholder - self._value_tensor = value_tensor + self._value_tensor = tf.squeeze(value_tensor) # canonical values has shape (batchsize, ) - def get_value(self, **kwargs): + def eval_value(self, **kwargs): """ :return: batch of corresponding values in numpy array """ raise NotImplementedError() - def get_value_tensor(self): + @property + def value_tensor(self): """ :return: tensor of the corresponding values diff --git a/tianshou/core/value_function/state_value.py b/tianshou/core/value_function/state_value.py index 04fe442..02c12fe 100644 --- a/tianshou/core/value_function/state_value.py +++ b/tianshou/core/value_function/state_value.py @@ -1,4 +1,6 @@ -from base import ValueFunctionBase +from __future__ import absolute_import + +from .base import ValueFunctionBase import tensorflow as tf @@ -12,12 +14,12 @@ class StateValue(ValueFunctionBase): observation_placeholder=observation_placeholder ) - def get_value(self, observation): + def eval_value(self, observation): """ :param observation: numpy array of observations, of shape (batchsize, observation_dim). :return: numpy array of state values, of shape (batchsize, ) - # TODO: dealing with the last dim of 1 in V(s) and Q(s, a) + # TODO: dealing with the last dim of 1 in V(s) and Q(s, a), this should rely on the action shape returned by env """ sess = tf.get_default_session() - return sess.run(self.get_value_tensor(), feed_dict={self._observation_placeholder: observation})[:, 0] \ No newline at end of file + return sess.run(self.value_tensor, feed_dict={self._observation_placeholder: observation}) \ No newline at end of file