add some code for debug and profiling
This commit is contained in:
parent
162aa313b6
commit
426251e158
@ -17,6 +17,7 @@ from tianshou.core.mcts.mcts import MCTS
|
||||
|
||||
import go
|
||||
import reversi
|
||||
import time
|
||||
|
||||
class Game:
|
||||
'''
|
||||
@ -25,8 +26,10 @@ class Game:
|
||||
TODO : Maybe merge with the engine class in future,
|
||||
currently leave it untouched for interacting with Go UI.
|
||||
'''
|
||||
def __init__(self, name="go", checkpoint_path=None):
|
||||
def __init__(self, name="go", role="unknown", debug=False, checkpoint_path=None):
|
||||
self.name = name
|
||||
self.role = role
|
||||
self.debug = debug
|
||||
if self.name == "go":
|
||||
self.size = 9
|
||||
self.komi = 3.75
|
||||
@ -36,7 +39,7 @@ class Game:
|
||||
self.latest_boards = deque(maxlen=8)
|
||||
for _ in range(8):
|
||||
self.latest_boards.append(self.board)
|
||||
self.game_engine = go.Go(size=self.size, komi=self.komi)
|
||||
self.game_engine = go.Go(size=self.size, komi=self.komi, role=self.role)
|
||||
elif self.name == "reversi":
|
||||
self.size = 8
|
||||
self.history_length = 1
|
||||
@ -61,7 +64,8 @@ class Game:
|
||||
self.komi = k
|
||||
|
||||
def think(self, latest_boards, color):
|
||||
mcts = MCTS(self.game_engine, self.evaluator, [latest_boards, color], self.size ** 2 + 1, inverse=True)
|
||||
mcts = MCTS(self.game_engine, self.evaluator, [latest_boards, color],
|
||||
self.size ** 2 + 1, role=self.role, debug=self.debug, inverse=True)
|
||||
mcts.search(max_step=100)
|
||||
temp = 1
|
||||
prob = mcts.root.N ** temp / np.sum(mcts.root.N ** temp)
|
||||
|
@ -18,6 +18,7 @@ class Go:
|
||||
def __init__(self, **kwargs):
    """
    Store the game configuration passed as keyword arguments.

    :param size: stored as ``self.size`` (required)
    :param komi: stored as ``self.komi`` (required)
    :param role: stored as ``self.role``; optional, defaults to "unknown"
    :raise KeyError: if ``size`` or ``komi`` is not supplied
    """
    self.size = kwargs['size']
    self.komi = kwargs['komi']
    # ``role`` is a later addition: fall back to a default so that callers
    # which only pass size/komi keep working instead of raising KeyError.
    self.role = kwargs.get('role', 'unknown')
|
||||
|
||||
def _flatten(self, vertex):
|
||||
x, y = vertex
|
||||
|
@ -152,6 +152,9 @@ class ResNet(object):
|
||||
:param color: a string, indicate which one to play
|
||||
:return: a list of tensor, the predicted value and policy given the history and color
|
||||
"""
|
||||
# Note : maybe we can use it for isolating test of MCTS
|
||||
#prob = [1.0 / self.action_num] * self.action_num
|
||||
#return [prob, np.random.uniform(-1, 1)]
|
||||
history, color = state
|
||||
if len(history) != self.history_length:
|
||||
raise ValueError(
|
||||
|
@ -28,6 +28,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument("--black_weight_path", type=str, default=None)
parser.add_argument("--white_weight_path", type=str, default=None)
parser.add_argument("--id", type=int, default=0)
# argparse applies ``type`` to the raw command-line string, and
# bool("False") is True, so ``type=bool`` would switch debugging on for every
# non-empty value. Parse the text explicitly instead.
parser.add_argument("--debug", default=False,
                    type=lambda text: text.strip().lower() in ("1", "true", "yes", "on"))
args = parser.parse_args()
|
||||
|
||||
if not os.path.exists(args.result_path):
|
||||
@ -60,11 +61,13 @@ if __name__ == '__main__':
|
||||
white_role_name = 'white' + str(args.id)
|
||||
|
||||
# Start one player process per colour. Each one gets its role name, its own
# checkpoint path and the debug flag on the command line; stderr is merged
# into stdout.
agent_v0 = subprocess.Popen(
    ['python', '-u', 'player.py', '--role=' + black_role_name,
     '--checkpoint_path=' + str(args.black_weight_path), '--debug=' + str(args.debug)],
    stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)

agent_v1 = subprocess.Popen(
    ['python', '-u', 'player.py', '--role=' + white_role_name,
     # The white player must load the white weights, not the black ones.
     '--checkpoint_path=' + str(args.white_weight_path), '--debug=' + str(args.debug)],
    stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
|
||||
|
||||
server_list = ""
|
||||
@ -92,7 +95,8 @@ if __name__ == '__main__':
|
||||
evaluate_rounds = 1
|
||||
game_num = 0
|
||||
try:
|
||||
while True:
|
||||
#while True:
|
||||
while game_num < evaluate_rounds:
|
||||
start_time = time.time()
|
||||
num = 0
|
||||
pass_flag = [False, False]
|
||||
@ -107,6 +111,7 @@ if __name__ == '__main__':
|
||||
print show[board[i * size + j]] + " ",
|
||||
print "\n",
|
||||
data.boards.append(board)
|
||||
start_time = time.time()
|
||||
move = player[turn].run_cmd(str(num) + ' genmove ' + color[turn] + '\n')
|
||||
print role[turn] + " : " + str(move),
|
||||
num += 1
|
||||
|
@ -25,11 +25,15 @@ if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--checkpoint_path", type=str, default=None)
|
||||
parser.add_argument("--role", type=str, default="unknown")
|
||||
parser.add_argument("--debug", type=str, default=False)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.checkpoint_path == 'None':
|
||||
args.checkpoint_path = None
|
||||
game = Game(checkpoint_path=args.checkpoint_path)
|
||||
debug = False
|
||||
if args.debug == "True":
|
||||
debug = True
|
||||
game = Game(role=args.role, checkpoint_path=args.checkpoint_path, debug=debug)
|
||||
engine = GTPEngine(game_obj=game, name='tianshou', version=0)
|
||||
|
||||
daemon = Pyro4.Daemon() # make a Pyro daemon
|
||||
|
@ -40,16 +40,23 @@ class MCTSNode(object):
|
||||
|
||||
|
||||
class UCTNode(MCTSNode):
|
||||
def __init__(self, parent, action, state, action_num, prior, debug=False, inverse=False):
    """
    Build a UCT node with zeroed per-action statistics.

    :param parent: forwarded to the base-class constructor
    :param action: forwarded to the base-class constructor
    :param state: forwarded to the base-class constructor
    :param action_num: number of actions; length of the Q, W and N arrays
    :param prior: forwarded to the base-class constructor and read back as
        ``self.prior`` when computing ``ucb``
    :param debug: stored as ``self.debug``
    :param inverse: forwarded to the base-class constructor
    """
    super(UCTNode, self).__init__(parent, action, state, action_num, prior, inverse)
    # Three independent zero arrays, one entry per action.
    self.Q, self.W, self.N = (np.zeros([action_num]) for _ in range(3))
    exploration = c_puct * self.prior * math.sqrt(np.sum(self.N)) / (self.N + 1)
    self.ucb = self.Q + exploration
    self.mask = None
    self.debug = debug
    # Profiling counter, starts at zero.
    self.elapse_time = 0
|
||||
|
||||
def clear_elapse_time(self):
    """
    Reset the ``elapse_time`` profiling counter of this node to zero.
    """
    self.elapse_time = 0
|
||||
|
||||
def selection(self, simulator):
|
||||
head = time.time()
|
||||
self.valid_mask(simulator)
|
||||
self.elapse_time += time.time() - head
|
||||
action = np.argmax(self.ucb)
|
||||
if action in self.children.keys():
|
||||
return self.children[action].selection(simulator)
|
||||
@ -142,15 +149,18 @@ class ActionNode(object):
|
||||
|
||||
|
||||
class MCTS(object):
|
||||
def __init__(self, simulator, evaluator, root, action_num, method="UCT", inverse=False):
|
||||
def __init__(self, simulator, evaluator, root, action_num, method="UCT",
|
||||
role="unknown", debug=False, inverse=False):
|
||||
self.simulator = simulator
|
||||
self.evaluator = evaluator
|
||||
self.role = role
|
||||
self.debug = debug
|
||||
prior, _ = self.evaluator(root)
|
||||
self.action_num = action_num
|
||||
if method == "":
|
||||
self.root = root
|
||||
if method == "UCT":
|
||||
self.root = UCTNode(None, None, root, action_num, prior, inverse=inverse)
|
||||
self.root = UCTNode(None, None, root, action_num, prior, self.debug, inverse=inverse)
|
||||
if method == "TS":
|
||||
self.root = TSNode(None, None, root, action_num, prior, inverse=inverse)
|
||||
self.inverse = inverse
|
||||
@ -165,14 +175,36 @@ class MCTS(object):
|
||||
if max_step is None and max_time is None:
|
||||
raise ValueError("Need a stop criteria!")
|
||||
|
||||
# Wall-clock totals per search phase; only reported when debugging.
selection_time = 0
expansion_time = 0
backprop_time = 0
self.root.clear_elapse_time()
# A criterion left as None means "no limit" for that criterion.
step_limit = float("inf") if max_step is None else max_step
time_limit = float("inf") if max_time is None else max_time
# Stop on whichever budget runs out first. Elapsed seconds are compared with
# the time budget (they were compared with max_step by mistake).
while step < step_limit and time.time() - start_time < time_limit:
    sel_time, exp_time, back_time = self._expand()
    selection_time += sel_time
    expansion_time += exp_time
    backprop_time += back_time
    step += 1
if self.debug:
    # Append one line per search, tagged with the role. ``with`` closes the
    # file even if the write raises.
    with open("debug.txt", "a") as log_file:
        log_file.write("[" + str(self.role) + "]"
                       + " selection : " + str(selection_time) + "\t"
                       + " validmask : " + str(self.root.elapse_time) + "\t"
                       + " expansion : " + str(expansion_time) + "\t"
                       + " backprop : " + str(backprop_time) + "\t"
                       + "\n")
|
||||
|
||||
def _expand(self):
    """
    Run one selection / expansion / backpropagation pass from the root.

    :return: a tuple of three floats, the wall-clock seconds spent in
        selection, expansion and backpropagation, in that order
    """
    stamps = [time.time()]
    node, new_action = self.root.selection(self.simulator)
    stamps.append(time.time())
    value = node.children[new_action].expansion(self.evaluator, self.action_num)
    stamps.append(time.time())
    # ``+ 0.`` presumably coerces the value to floating point before it is
    # propagated -- confirm against the evaluator's return type.
    node.children[new_action].backpropagation(value + 0.)
    stamps.append(time.time())
    began, selected, expanded, propagated = stamps
    return selected - began, expanded - selected, propagated - expanded
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pass
|
||||
|
Loading…
x
Reference in New Issue
Block a user