Add some code for debugging and profiling
This commit is contained in:
parent
162aa313b6
commit
426251e158
@ -17,6 +17,7 @@ from tianshou.core.mcts.mcts import MCTS
|
|||||||
|
|
||||||
import go
|
import go
|
||||||
import reversi
|
import reversi
|
||||||
|
import time
|
||||||
|
|
||||||
class Game:
|
class Game:
|
||||||
'''
|
'''
|
||||||
@ -25,8 +26,10 @@ class Game:
|
|||||||
TODO : Maybe merge with the engine class in future,
|
TODO : Maybe merge with the engine class in future,
|
||||||
currently leave it untouched for interacting with Go UI.
|
currently leave it untouched for interacting with Go UI.
|
||||||
'''
|
'''
|
||||||
def __init__(self, name="go", checkpoint_path=None):
|
def __init__(self, name="go", role="unknown", debug=False, checkpoint_path=None):
|
||||||
self.name = name
|
self.name = name
|
||||||
|
self.role = role
|
||||||
|
self.debug = debug
|
||||||
if self.name == "go":
|
if self.name == "go":
|
||||||
self.size = 9
|
self.size = 9
|
||||||
self.komi = 3.75
|
self.komi = 3.75
|
||||||
@ -36,7 +39,7 @@ class Game:
|
|||||||
self.latest_boards = deque(maxlen=8)
|
self.latest_boards = deque(maxlen=8)
|
||||||
for _ in range(8):
|
for _ in range(8):
|
||||||
self.latest_boards.append(self.board)
|
self.latest_boards.append(self.board)
|
||||||
self.game_engine = go.Go(size=self.size, komi=self.komi)
|
self.game_engine = go.Go(size=self.size, komi=self.komi, role=self.role)
|
||||||
elif self.name == "reversi":
|
elif self.name == "reversi":
|
||||||
self.size = 8
|
self.size = 8
|
||||||
self.history_length = 1
|
self.history_length = 1
|
||||||
@ -61,7 +64,8 @@ class Game:
|
|||||||
self.komi = k
|
self.komi = k
|
||||||
|
|
||||||
def think(self, latest_boards, color):
|
def think(self, latest_boards, color):
|
||||||
mcts = MCTS(self.game_engine, self.evaluator, [latest_boards, color], self.size ** 2 + 1, inverse=True)
|
mcts = MCTS(self.game_engine, self.evaluator, [latest_boards, color],
|
||||||
|
self.size ** 2 + 1, role=self.role, debug=self.debug, inverse=True)
|
||||||
mcts.search(max_step=100)
|
mcts.search(max_step=100)
|
||||||
temp = 1
|
temp = 1
|
||||||
prob = mcts.root.N ** temp / np.sum(mcts.root.N ** temp)
|
prob = mcts.root.N ** temp / np.sum(mcts.root.N ** temp)
|
||||||
|
@ -18,6 +18,7 @@ class Go:
|
|||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
self.size = kwargs['size']
|
self.size = kwargs['size']
|
||||||
self.komi = kwargs['komi']
|
self.komi = kwargs['komi']
|
||||||
|
self.role = kwargs['role']
|
||||||
|
|
||||||
def _flatten(self, vertex):
|
def _flatten(self, vertex):
|
||||||
x, y = vertex
|
x, y = vertex
|
||||||
|
@ -152,6 +152,9 @@ class ResNet(object):
|
|||||||
:param color: a string, indicate which one to play
|
:param color: a string, indicate which one to play
|
||||||
:return: a list of tensor, the predicted value and policy given the history and color
|
:return: a list of tensor, the predicted value and policy given the history and color
|
||||||
"""
|
"""
|
||||||
|
# Note : maybe we can use it for isolating test of MCTS
|
||||||
|
#prob = [1.0 / self.action_num] * self.action_num
|
||||||
|
#return [prob, np.random.uniform(-1, 1)]
|
||||||
history, color = state
|
history, color = state
|
||||||
if len(history) != self.history_length:
|
if len(history) != self.history_length:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -28,6 +28,7 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument("--black_weight_path", type=str, default=None)
|
parser.add_argument("--black_weight_path", type=str, default=None)
|
||||||
parser.add_argument("--white_weight_path", type=str, default=None)
|
parser.add_argument("--white_weight_path", type=str, default=None)
|
||||||
parser.add_argument("--id", type=int, default=0)
|
parser.add_argument("--id", type=int, default=0)
|
||||||
|
parser.add_argument("--debug", type=bool, default=False)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if not os.path.exists(args.result_path):
|
if not os.path.exists(args.result_path):
|
||||||
@ -60,11 +61,13 @@ if __name__ == '__main__':
|
|||||||
white_role_name = 'white' + str(args.id)
|
white_role_name = 'white' + str(args.id)
|
||||||
|
|
||||||
agent_v0 = subprocess.Popen(
|
agent_v0 = subprocess.Popen(
|
||||||
['python', '-u', 'player.py', '--role=' + black_role_name, '--checkpoint_path=' + str(args.black_weight_path)],
|
['python', '-u', 'player.py', '--role=' + black_role_name,
|
||||||
|
'--checkpoint_path=' + str(args.black_weight_path), '--debug=' + str(args.debug)],
|
||||||
stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
|
stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
|
||||||
|
|
||||||
agent_v1 = subprocess.Popen(
|
agent_v1 = subprocess.Popen(
|
||||||
['python', '-u', 'player.py', '--role=' + white_role_name, '--checkpoint_path=' + str(args.white_weight_path)],
|
['python', '-u', 'player.py', '--role=' + white_role_name,
|
||||||
|
'--checkpoint_path=' + str(args.black_weight_path), '--debug=' + str(args.debug)],
|
||||||
stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
|
stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
|
||||||
|
|
||||||
server_list = ""
|
server_list = ""
|
||||||
@ -92,7 +95,8 @@ if __name__ == '__main__':
|
|||||||
evaluate_rounds = 1
|
evaluate_rounds = 1
|
||||||
game_num = 0
|
game_num = 0
|
||||||
try:
|
try:
|
||||||
while True:
|
#while True:
|
||||||
|
while game_num < evaluate_rounds:
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
num = 0
|
num = 0
|
||||||
pass_flag = [False, False]
|
pass_flag = [False, False]
|
||||||
@ -107,6 +111,7 @@ if __name__ == '__main__':
|
|||||||
print show[board[i * size + j]] + " ",
|
print show[board[i * size + j]] + " ",
|
||||||
print "\n",
|
print "\n",
|
||||||
data.boards.append(board)
|
data.boards.append(board)
|
||||||
|
start_time = time.time()
|
||||||
move = player[turn].run_cmd(str(num) + ' genmove ' + color[turn] + '\n')
|
move = player[turn].run_cmd(str(num) + ' genmove ' + color[turn] + '\n')
|
||||||
print role[turn] + " : " + str(move),
|
print role[turn] + " : " + str(move),
|
||||||
num += 1
|
num += 1
|
||||||
|
@ -25,11 +25,15 @@ if __name__ == '__main__':
|
|||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--checkpoint_path", type=str, default=None)
|
parser.add_argument("--checkpoint_path", type=str, default=None)
|
||||||
parser.add_argument("--role", type=str, default="unknown")
|
parser.add_argument("--role", type=str, default="unknown")
|
||||||
|
parser.add_argument("--debug", type=str, default=False)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.checkpoint_path == 'None':
|
if args.checkpoint_path == 'None':
|
||||||
args.checkpoint_path = None
|
args.checkpoint_path = None
|
||||||
game = Game(checkpoint_path=args.checkpoint_path)
|
debug = False
|
||||||
|
if args.debug == "True":
|
||||||
|
debug = True
|
||||||
|
game = Game(role=args.role, checkpoint_path=args.checkpoint_path, debug=debug)
|
||||||
engine = GTPEngine(game_obj=game, name='tianshou', version=0)
|
engine = GTPEngine(game_obj=game, name='tianshou', version=0)
|
||||||
|
|
||||||
daemon = Pyro4.Daemon() # make a Pyro daemon
|
daemon = Pyro4.Daemon() # make a Pyro daemon
|
||||||
|
@ -40,16 +40,23 @@ class MCTSNode(object):
|
|||||||
|
|
||||||
|
|
||||||
class UCTNode(MCTSNode):
|
class UCTNode(MCTSNode):
|
||||||
def __init__(self, parent, action, state, action_num, prior, inverse=False):
|
def __init__(self, parent, action, state, action_num, prior, debug=False, inverse=False):
|
||||||
super(UCTNode, self).__init__(parent, action, state, action_num, prior, inverse)
|
super(UCTNode, self).__init__(parent, action, state, action_num, prior, inverse)
|
||||||
self.Q = np.zeros([action_num])
|
self.Q = np.zeros([action_num])
|
||||||
self.W = np.zeros([action_num])
|
self.W = np.zeros([action_num])
|
||||||
self.N = np.zeros([action_num])
|
self.N = np.zeros([action_num])
|
||||||
self.ucb = self.Q + c_puct * self.prior * math.sqrt(np.sum(self.N)) / (self.N + 1)
|
self.ucb = self.Q + c_puct * self.prior * math.sqrt(np.sum(self.N)) / (self.N + 1)
|
||||||
self.mask = None
|
self.mask = None
|
||||||
|
self.debug=debug
|
||||||
|
self.elapse_time = 0
|
||||||
|
|
||||||
|
def clear_elapse_time(self):
|
||||||
|
self.elapse_time = 0
|
||||||
|
|
||||||
def selection(self, simulator):
|
def selection(self, simulator):
|
||||||
|
head = time.time()
|
||||||
self.valid_mask(simulator)
|
self.valid_mask(simulator)
|
||||||
|
self.elapse_time += time.time() - head
|
||||||
action = np.argmax(self.ucb)
|
action = np.argmax(self.ucb)
|
||||||
if action in self.children.keys():
|
if action in self.children.keys():
|
||||||
return self.children[action].selection(simulator)
|
return self.children[action].selection(simulator)
|
||||||
@ -142,15 +149,18 @@ class ActionNode(object):
|
|||||||
|
|
||||||
|
|
||||||
class MCTS(object):
|
class MCTS(object):
|
||||||
def __init__(self, simulator, evaluator, root, action_num, method="UCT", inverse=False):
|
def __init__(self, simulator, evaluator, root, action_num, method="UCT",
|
||||||
|
role="unknown", debug=False, inverse=False):
|
||||||
self.simulator = simulator
|
self.simulator = simulator
|
||||||
self.evaluator = evaluator
|
self.evaluator = evaluator
|
||||||
|
self.role = role
|
||||||
|
self.debug = debug
|
||||||
prior, _ = self.evaluator(root)
|
prior, _ = self.evaluator(root)
|
||||||
self.action_num = action_num
|
self.action_num = action_num
|
||||||
if method == "":
|
if method == "":
|
||||||
self.root = root
|
self.root = root
|
||||||
if method == "UCT":
|
if method == "UCT":
|
||||||
self.root = UCTNode(None, None, root, action_num, prior, inverse=inverse)
|
self.root = UCTNode(None, None, root, action_num, prior, self.debug, inverse=inverse)
|
||||||
if method == "TS":
|
if method == "TS":
|
||||||
self.root = TSNode(None, None, root, action_num, prior, inverse=inverse)
|
self.root = TSNode(None, None, root, action_num, prior, inverse=inverse)
|
||||||
self.inverse = inverse
|
self.inverse = inverse
|
||||||
@ -165,14 +175,36 @@ class MCTS(object):
|
|||||||
if max_step is None and max_time is None:
|
if max_step is None and max_time is None:
|
||||||
raise ValueError("Need a stop criteria!")
|
raise ValueError("Need a stop criteria!")
|
||||||
|
|
||||||
|
selection_time = 0
|
||||||
|
expansion_time = 0
|
||||||
|
backprop_time = 0
|
||||||
|
self.root.clear_elapse_time()
|
||||||
while step < max_step and time.time() - start_time < max_step:
|
while step < max_step and time.time() - start_time < max_step:
|
||||||
self._expand()
|
sel_time, exp_time, back_time = self._expand()
|
||||||
|
selection_time += sel_time
|
||||||
|
expansion_time += exp_time
|
||||||
|
backprop_time += back_time
|
||||||
step += 1
|
step += 1
|
||||||
|
if (self.debug):
|
||||||
|
file = open("debug.txt", "a")
|
||||||
|
file.write("[" + str(self.role) + "]"
|
||||||
|
+ " selection : " + str(selection_time) + "\t"
|
||||||
|
+ " validmask : " + str(self.root.elapse_time) + "\t"
|
||||||
|
+ " expansion : " + str(expansion_time) + "\t"
|
||||||
|
+ " backprop : " + str(backprop_time) + "\t"
|
||||||
|
+ "\n")
|
||||||
|
file.close()
|
||||||
|
|
||||||
def _expand(self):
|
def _expand(self):
|
||||||
|
t0 = time.time()
|
||||||
node, new_action = self.root.selection(self.simulator)
|
node, new_action = self.root.selection(self.simulator)
|
||||||
|
t1 = time.time()
|
||||||
value = node.children[new_action].expansion(self.evaluator, self.action_num)
|
value = node.children[new_action].expansion(self.evaluator, self.action_num)
|
||||||
|
t2 = time.time()
|
||||||
node.children[new_action].backpropagation(value + 0.)
|
node.children[new_action].backpropagation(value + 0.)
|
||||||
|
t3 = time.time()
|
||||||
|
return t1 - t0, t2 - t1, t3 - t2
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
pass
|
pass
|
||||||
|
Loading…
x
Reference in New Issue
Block a user