Add code for debugging and profiling

This commit is contained in:
Dong Yan 2017-12-24 01:07:46 +08:00
parent 162aa313b6
commit 426251e158
6 changed files with 60 additions and 11 deletions

View File

@ -17,6 +17,7 @@ from tianshou.core.mcts.mcts import MCTS
import go import go
import reversi import reversi
import time
class Game: class Game:
''' '''
@ -25,8 +26,10 @@ class Game:
TODO : Maybe merge with the engine class in future, TODO : Maybe merge with the engine class in future,
currently leave it untouched for interacting with Go UI. currently leave it untouched for interacting with Go UI.
''' '''
def __init__(self, name="go", checkpoint_path=None): def __init__(self, name="go", role="unknown", debug=False, checkpoint_path=None):
self.name = name self.name = name
self.role = role
self.debug = debug
if self.name == "go": if self.name == "go":
self.size = 9 self.size = 9
self.komi = 3.75 self.komi = 3.75
@ -36,7 +39,7 @@ class Game:
self.latest_boards = deque(maxlen=8) self.latest_boards = deque(maxlen=8)
for _ in range(8): for _ in range(8):
self.latest_boards.append(self.board) self.latest_boards.append(self.board)
self.game_engine = go.Go(size=self.size, komi=self.komi) self.game_engine = go.Go(size=self.size, komi=self.komi, role=self.role)
elif self.name == "reversi": elif self.name == "reversi":
self.size = 8 self.size = 8
self.history_length = 1 self.history_length = 1
@ -61,7 +64,8 @@ class Game:
self.komi = k self.komi = k
def think(self, latest_boards, color): def think(self, latest_boards, color):
mcts = MCTS(self.game_engine, self.evaluator, [latest_boards, color], self.size ** 2 + 1, inverse=True) mcts = MCTS(self.game_engine, self.evaluator, [latest_boards, color],
self.size ** 2 + 1, role=self.role, debug=self.debug, inverse=True)
mcts.search(max_step=100) mcts.search(max_step=100)
temp = 1 temp = 1
prob = mcts.root.N ** temp / np.sum(mcts.root.N ** temp) prob = mcts.root.N ** temp / np.sum(mcts.root.N ** temp)

View File

@ -18,6 +18,7 @@ class Go:
def __init__(self, **kwargs):
    """Store the board configuration passed as keyword arguments.

    Keyword Args:
        size: board size; required.
        komi: komi value; required.
        role: label of the player that owns this engine; optional.
            Falls back to "unknown" so that callers which only pass
            ``size`` and ``komi`` keep working.

    Raises:
        KeyError: if ``size`` or ``komi`` is missing.
    """
    self.size = kwargs['size']
    self.komi = kwargs['komi']
    # ``role`` is only a label; a missing value must not break construction.
    self.role = kwargs.get('role', 'unknown')
def _flatten(self, vertex): def _flatten(self, vertex):
x, y = vertex x, y = vertex

View File

@ -152,6 +152,9 @@ class ResNet(object):
:param color: a string, indicate which one to play :param color: a string, indicate which one to play
:return: a list of tensor, the predicted value and policy given the history and color :return: a list of tensor, the predicted value and policy given the history and color
""" """
# NOTE: uncommenting the two lines below returns a uniform prior and a random value, which may be useful for testing MCTS in isolation
#prob = [1.0 / self.action_num] * self.action_num
#return [prob, np.random.uniform(-1, 1)]
history, color = state history, color = state
if len(history) != self.history_length: if len(history) != self.history_length:
raise ValueError( raise ValueError(

View File

@ -28,6 +28,7 @@ if __name__ == '__main__':
parser.add_argument("--black_weight_path", type=str, default=None) parser.add_argument("--black_weight_path", type=str, default=None)
parser.add_argument("--white_weight_path", type=str, default=None) parser.add_argument("--white_weight_path", type=str, default=None)
parser.add_argument("--id", type=int, default=0) parser.add_argument("--id", type=int, default=0)
# argparse's type=bool calls bool() on the raw text, so any non-empty string
# (including "False") would enable debugging; parse the text explicitly instead.
parser.add_argument("--debug", type=lambda text: str(text).lower() in ("true", "1", "yes"), default=False)
args = parser.parse_args() args = parser.parse_args()
if not os.path.exists(args.result_path): if not os.path.exists(args.result_path):
@ -60,11 +61,13 @@ if __name__ == '__main__':
white_role_name = 'white' + str(args.id) white_role_name = 'white' + str(args.id)
agent_v0 = subprocess.Popen( agent_v0 = subprocess.Popen(
['python', '-u', 'player.py', '--role=' + black_role_name, '--checkpoint_path=' + str(args.black_weight_path)], ['python', '-u', 'player.py', '--role=' + black_role_name,
'--checkpoint_path=' + str(args.black_weight_path), '--debug=' + str(args.debug)],
stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
# Launch the white player; it must be given the white checkpoint path,
# not the black one.
agent_v1 = subprocess.Popen(
    ['python', '-u', 'player.py', '--role=' + white_role_name,
     '--checkpoint_path=' + str(args.white_weight_path), '--debug=' + str(args.debug)],
    stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
server_list = "" server_list = ""
@ -92,7 +95,8 @@ if __name__ == '__main__':
evaluate_rounds = 1 evaluate_rounds = 1
game_num = 0 game_num = 0
try: try:
while True: #while True:
while game_num < evaluate_rounds:
start_time = time.time() start_time = time.time()
num = 0 num = 0
pass_flag = [False, False] pass_flag = [False, False]
@ -107,6 +111,7 @@ if __name__ == '__main__':
print show[board[i * size + j]] + " ", print show[board[i * size + j]] + " ",
print "\n", print "\n",
data.boards.append(board) data.boards.append(board)
start_time = time.time()
move = player[turn].run_cmd(str(num) + ' genmove ' + color[turn] + '\n') move = player[turn].run_cmd(str(num) + ' genmove ' + color[turn] + '\n')
print role[turn] + " : " + str(move), print role[turn] + " : " + str(move),
num += 1 num += 1

View File

@ -25,11 +25,15 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint_path", type=str, default=None) parser.add_argument("--checkpoint_path", type=str, default=None)
parser.add_argument("--role", type=str, default="unknown") parser.add_argument("--role", type=str, default="unknown")
parser.add_argument("--debug", type=str, default=False)
args = parser.parse_args() args = parser.parse_args()
if args.checkpoint_path == 'None': if args.checkpoint_path == 'None':
args.checkpoint_path = None args.checkpoint_path = None
game = Game(checkpoint_path=args.checkpoint_path) debug = False
if args.debug == "True":
debug = True
game = Game(role=args.role, checkpoint_path=args.checkpoint_path, debug=debug)
engine = GTPEngine(game_obj=game, name='tianshou', version=0) engine = GTPEngine(game_obj=game, name='tianshou', version=0)
daemon = Pyro4.Daemon() # make a Pyro daemon daemon = Pyro4.Daemon() # make a Pyro daemon

View File

@ -40,16 +40,23 @@ class MCTSNode(object):
class UCTNode(MCTSNode): class UCTNode(MCTSNode):
def __init__(self, parent, action, state, action_num, prior, debug=False, inverse=False):
    """Create a UCT node with zeroed per-action statistics.

    ``parent``, ``action``, ``state``, ``action_num``, ``prior`` and
    ``inverse`` are forwarded unchanged to the base-class constructor;
    ``debug`` is only stored on the node.
    """
    super(UCTNode, self).__init__(parent, action, state, action_num, prior, inverse)
    stat_shape = [action_num]
    self.Q = np.zeros(stat_shape)
    self.W = np.zeros(stat_shape)
    self.N = np.zeros(stat_shape)
    # Initial upper-confidence scores, computed from the still-zero counters.
    total_visits = np.sum(self.N)
    exploration = c_puct * self.prior * math.sqrt(total_visits) / (self.N + 1)
    self.ucb = self.Q + exploration
    self.mask = None
    self.debug = debug
    # Seconds accumulated by selection() around its valid_mask() call.
    self.elapse_time = 0
def clear_elapse_time(self):
    """Reset this node's accumulated ``elapse_time`` counter to zero."""
    self.elapse_time = 0
def selection(self, simulator): def selection(self, simulator):
head = time.time()
self.valid_mask(simulator) self.valid_mask(simulator)
self.elapse_time += time.time() - head
action = np.argmax(self.ucb) action = np.argmax(self.ucb)
if action in self.children.keys(): if action in self.children.keys():
return self.children[action].selection(simulator) return self.children[action].selection(simulator)
@ -142,15 +149,18 @@ class ActionNode(object):
class MCTS(object): class MCTS(object):
def __init__(self, simulator, evaluator, root, action_num, method="UCT", inverse=False): def __init__(self, simulator, evaluator, root, action_num, method="UCT",
role="unknown", debug=False, inverse=False):
self.simulator = simulator self.simulator = simulator
self.evaluator = evaluator self.evaluator = evaluator
self.role = role
self.debug = debug
prior, _ = self.evaluator(root) prior, _ = self.evaluator(root)
self.action_num = action_num self.action_num = action_num
if method == "": if method == "":
self.root = root self.root = root
if method == "UCT": if method == "UCT":
self.root = UCTNode(None, None, root, action_num, prior, inverse=inverse) self.root = UCTNode(None, None, root, action_num, prior, self.debug, inverse=inverse)
if method == "TS": if method == "TS":
self.root = TSNode(None, None, root, action_num, prior, inverse=inverse) self.root = TSNode(None, None, root, action_num, prior, inverse=inverse)
self.inverse = inverse self.inverse = inverse
@ -165,14 +175,36 @@ class MCTS(object):
if max_step is None and max_time is None: if max_step is None and max_time is None:
raise ValueError("Need a stop criteria!") raise ValueError("Need a stop criteria!")
selection_time = 0
expansion_time = 0
backprop_time = 0
self.root.clear_elapse_time()
while step < max_step and time.time() - start_time < max_step: while step < max_step and time.time() - start_time < max_step:
self._expand() sel_time, exp_time, back_time = self._expand()
selection_time += sel_time
expansion_time += exp_time
backprop_time += back_time
step += 1 step += 1
if (self.debug):
file = open("debug.txt", "a")
file.write("[" + str(self.role) + "]"
+ " selection : " + str(selection_time) + "\t"
+ " validmask : " + str(self.root.elapse_time) + "\t"
+ " expansion : " + str(expansion_time) + "\t"
+ " backprop : " + str(backprop_time) + "\t"
+ "\n")
file.close()
def _expand(self):
    """Run one selection / expansion / backpropagation pass and time each phase.

    Returns:
        A 3-tuple ``(selection, expansion, backpropagation)`` holding the
        wall-clock seconds, measured with ``time.time()``, spent in each phase.
    """
    begin = time.time()
    node, new_action = self.root.selection(self.simulator)
    after_selection = time.time()
    value = node.children[new_action].expansion(self.evaluator, self.action_num)
    after_expansion = time.time()
    # The child is looked up again on purpose, exactly as before the timing was added.
    node.children[new_action].backpropagation(value + 0.)
    after_backprop = time.time()
    selection_cost = after_selection - begin
    expansion_cost = after_expansion - after_selection
    backprop_cost = after_backprop - after_expansion
    return selection_cost, expansion_cost, backprop_cost
if __name__ == "__main__": if __name__ == "__main__":
pass pass