fix bug of game config and add profing functions to mcts
This commit is contained in:
parent
2d9aa32758
commit
f0074aa7ca
@ -198,5 +198,4 @@ class GTPEngine():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "main":
|
if __name__ == "main":
|
||||||
game = Game()
|
print ("test engine.py")
|
||||||
engine = GTPEngine(game_obj=game)
|
|
||||||
|
@ -26,7 +26,7 @@ class Game:
|
|||||||
TODO : Maybe merge with the engine class in future,
|
TODO : Maybe merge with the engine class in future,
|
||||||
currently leave it untouched for interacting with Go UI.
|
currently leave it untouched for interacting with Go UI.
|
||||||
'''
|
'''
|
||||||
def __init__(self, name="reversi", role="unknown", debug=False, checkpoint_path=None):
|
def __init__(self, name=None, role=None, debug=False, checkpoint_path=None):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.role = role
|
self.role = role
|
||||||
self.debug = debug
|
self.debug = debug
|
||||||
@ -119,10 +119,7 @@ class Game:
|
|||||||
sys.stdout.flush()
|
sys.stdout.flush()
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
g = Game("go")
|
print("test game.py")
|
||||||
print(g.board)
|
|
||||||
g.clear()
|
|
||||||
g.think_play_move(1)
|
|
||||||
#file = open("debug.txt", "a")
|
#file = open("debug.txt", "a")
|
||||||
#file.write("mcts check\n")
|
#file.write("mcts check\n")
|
||||||
#file.close()
|
#file.close()
|
||||||
|
@ -60,13 +60,14 @@ if __name__ == '__main__':
|
|||||||
black_role_name = 'black' + str(args.id)
|
black_role_name = 'black' + str(args.id)
|
||||||
white_role_name = 'white' + str(args.id)
|
white_role_name = 'white' + str(args.id)
|
||||||
|
|
||||||
|
game_name = 'go'
|
||||||
agent_v0 = subprocess.Popen(
|
agent_v0 = subprocess.Popen(
|
||||||
['python', '-u', 'player.py', '--role=' + black_role_name,
|
['python', '-u', 'player.py', '--game=' + game_name, '--role=' + black_role_name,
|
||||||
'--checkpoint_path=' + str(args.black_weight_path), '--debug=' + str(args.debug)],
|
'--checkpoint_path=' + str(args.black_weight_path), '--debug=' + str(args.debug)],
|
||||||
stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
|
stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
|
||||||
|
|
||||||
agent_v1 = subprocess.Popen(
|
agent_v1 = subprocess.Popen(
|
||||||
['python', '-u', 'player.py', '--role=' + white_role_name,
|
['python', '-u', 'player.py', '--game=' + game_name, '--role=' + white_role_name,
|
||||||
'--checkpoint_path=' + str(args.black_weight_path), '--debug=' + str(args.debug)],
|
'--checkpoint_path=' + str(args.black_weight_path), '--debug=' + str(args.debug)],
|
||||||
stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
|
stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
|
||||||
|
|
||||||
@ -102,13 +103,13 @@ if __name__ == '__main__':
|
|||||||
pass_flag = [False, False]
|
pass_flag = [False, False]
|
||||||
print("Start game {}".format(game_num))
|
print("Start game {}".format(game_num))
|
||||||
# end the game if both palyer chose to pass, or play too much turns
|
# end the game if both palyer chose to pass, or play too much turns
|
||||||
while not (pass_flag[0] and pass_flag[1]) and num < size["reversi"] ** 2 * 2:
|
while not (pass_flag[0] and pass_flag[1]) and num < size[game_name] ** 2 * 2:
|
||||||
turn = num % 2
|
turn = num % 2
|
||||||
board = player[turn].run_cmd(str(num) + ' show_board')
|
board = player[turn].run_cmd(str(num) + ' show_board')
|
||||||
board = eval(board[board.index('['):board.index(']') + 1])
|
board = eval(board[board.index('['):board.index(']') + 1])
|
||||||
for i in range(size["reversi"]):
|
for i in range(size[game_name]):
|
||||||
for j in range(size["reversi"]):
|
for j in range(size[game_name]):
|
||||||
print show[board[i * size["reversi"] + j]] + " ",
|
print show[board[i * size[game_name] + j]] + " ",
|
||||||
print "\n",
|
print "\n",
|
||||||
data.boards.append(board)
|
data.boards.append(board)
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
@ -26,6 +26,7 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument("--checkpoint_path", type=str, default=None)
|
parser.add_argument("--checkpoint_path", type=str, default=None)
|
||||||
parser.add_argument("--role", type=str, default="unknown")
|
parser.add_argument("--role", type=str, default="unknown")
|
||||||
parser.add_argument("--debug", type=str, default=False)
|
parser.add_argument("--debug", type=str, default=False)
|
||||||
|
parser.add_argument("--game", type=str, default=False)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.checkpoint_path == 'None':
|
if args.checkpoint_path == 'None':
|
||||||
@ -33,7 +34,7 @@ if __name__ == '__main__':
|
|||||||
debug = False
|
debug = False
|
||||||
if args.debug == "True":
|
if args.debug == "True":
|
||||||
debug = True
|
debug = True
|
||||||
game = Game(role=args.role, checkpoint_path=args.checkpoint_path, debug=debug)
|
game = Game(name=args.game, role=args.role, checkpoint_path=args.checkpoint_path, debug=debug)
|
||||||
engine = GTPEngine(game_obj=game, name='tianshou', version=0)
|
engine = GTPEngine(game_obj=game, name='tianshou', version=0)
|
||||||
|
|
||||||
daemon = Pyro4.Daemon() # make a Pyro daemon
|
daemon = Pyro4.Daemon() # make a Pyro daemon
|
||||||
|
@ -1,123 +0,0 @@
|
|||||||
import os
|
|
||||||
import numpy as np
|
|
||||||
import time
|
|
||||||
|
|
||||||
size = 9
|
|
||||||
path = "/raid/tongzheng/tianshou/AlphaGo/data/part1/"
|
|
||||||
save_path = "/raid/tongzheng/tianshou/AlphaGo/data/"
|
|
||||||
name = os.listdir(path)
|
|
||||||
print(len(name))
|
|
||||||
batch_size = 128
|
|
||||||
batch_num = 512
|
|
||||||
|
|
||||||
block_size = batch_size * batch_num
|
|
||||||
slots_num = 16
|
|
||||||
|
|
||||||
|
|
||||||
class block(object):
|
|
||||||
def __init__(self, block_size, block_id):
|
|
||||||
self.boards = []
|
|
||||||
self.wins = []
|
|
||||||
self.ps = []
|
|
||||||
self.block_size = block_size
|
|
||||||
self.block_id = block_id
|
|
||||||
|
|
||||||
def concat(self, board, p, win):
|
|
||||||
board = board.reshape(-1, size, size, 17)
|
|
||||||
win = win.reshape(-1, 1)
|
|
||||||
p = p.reshape(-1, size ** 2 + 1)
|
|
||||||
self.boards.append(board)
|
|
||||||
self.wins.append(win)
|
|
||||||
self.ps.append(p)
|
|
||||||
|
|
||||||
def isfull(self):
|
|
||||||
assert len(self.boards) == len(self.wins)
|
|
||||||
assert len(self.boards) == len(self.ps)
|
|
||||||
return len(self.boards) == self.block_size
|
|
||||||
|
|
||||||
def save_and_reset(self, block_id):
|
|
||||||
self.boards = np.concatenate(self.boards, axis=0)
|
|
||||||
self.wins = np.concatenate(self.wins, axis=0)
|
|
||||||
self.ps = np.concatenate(self.ps, axis=0)
|
|
||||||
print ("Block {}, Boards shape {}, Wins Shape {}, Ps Shape {}".format(self.block_id, self.boards.shape[0],
|
|
||||||
self.wins.shape[0], self.ps.shape[0]))
|
|
||||||
np.savez(save_path + "block" + str(self.block_id), boards=self.boards, wins=self.wins, ps=self.ps)
|
|
||||||
self.boards = []
|
|
||||||
self.wins = []
|
|
||||||
self.ps = []
|
|
||||||
self.block_id = block_id
|
|
||||||
|
|
||||||
def store_num(self):
|
|
||||||
assert len(self.boards) == len(self.wins)
|
|
||||||
assert len(self.boards) == len(self.ps)
|
|
||||||
return len(self.boards)
|
|
||||||
|
|
||||||
|
|
||||||
def concat(block_list, board, win, p):
|
|
||||||
global index
|
|
||||||
seed = np.random.randint(slots_num)
|
|
||||||
block_list[seed].concat(board, win, p)
|
|
||||||
if block_list[seed].isfull():
|
|
||||||
block_list[seed].save_and_reset(index)
|
|
||||||
index = index + 1
|
|
||||||
|
|
||||||
|
|
||||||
block_list = []
|
|
||||||
for index in range(slots_num):
|
|
||||||
block_list.append(block(block_size, index))
|
|
||||||
index = index + 1
|
|
||||||
for n in name:
|
|
||||||
data = np.load(path + n)
|
|
||||||
board = data["boards"]
|
|
||||||
win = data["win"]
|
|
||||||
p = data["p"]
|
|
||||||
print("Start {}".format(n))
|
|
||||||
print("Shape {}".format(board.shape[0]))
|
|
||||||
start = -time.time()
|
|
||||||
for i in range(board.shape[0]):
|
|
||||||
board_ori = board[i].reshape(-1, size, size, 17)
|
|
||||||
win_ori = win[i].reshape(-1, 1)
|
|
||||||
p_ori = p[i].reshape(-1, size ** 2 + 1)
|
|
||||||
concat(block_list, board_ori, p_ori, win_ori)
|
|
||||||
|
|
||||||
for t in range(1, 4):
|
|
||||||
board_aug = np.rot90(board_ori, t, (1, 2))
|
|
||||||
p_aug = np.concatenate(
|
|
||||||
[np.rot90(p_ori[:, :-1].reshape(-1, size, size), t, (1, 2)).reshape(-1, size ** 2), p_ori[:, -1].reshape(-1, 1)],
|
|
||||||
axis=1)
|
|
||||||
concat(block_list, board_aug, p_aug, win_ori)
|
|
||||||
|
|
||||||
board_aug = board_ori[:, ::-1]
|
|
||||||
p_aug = np.concatenate(
|
|
||||||
[p_ori[:, :-1].reshape(-1, size, size)[:, ::-1].reshape(-1, size ** 2), p_ori[:, -1].reshape(-1, 1)],
|
|
||||||
axis=1)
|
|
||||||
concat(block_list, board_aug, p_aug, win_ori)
|
|
||||||
|
|
||||||
board_aug = board_ori[:, :, ::-1]
|
|
||||||
p_aug = np.concatenate(
|
|
||||||
[p_ori[:, :-1].reshape(-1, size, size)[:, :, ::-1].reshape(-1, size ** 2), p_ori[:, -1].reshape(-1, 1)],
|
|
||||||
axis=1)
|
|
||||||
concat(block_list, board_aug, p_aug, win_ori)
|
|
||||||
|
|
||||||
board_aug = np.rot90(board_ori[:, ::-1], 1, (1, 2))
|
|
||||||
p_aug = np.concatenate(
|
|
||||||
[np.rot90(p_ori[:, :-1].reshape(-1, size, size)[:, ::-1], 1, (1, 2)).reshape(-1, size ** 2),
|
|
||||||
p_ori[:, -1].reshape(-1, 1)],
|
|
||||||
axis=1)
|
|
||||||
concat(block_list, board_aug, p_aug, win_ori)
|
|
||||||
|
|
||||||
board_aug = np.rot90(board_ori[:, :, ::-1], 1, (1, 2))
|
|
||||||
p_aug = np.concatenate(
|
|
||||||
[np.rot90(p_ori[:, :-1].reshape(-1, size, size)[:, :, ::-1], 1, (1, 2)).reshape(-1, size ** 2),
|
|
||||||
p_ori[:, -1].reshape(-1, 1)],
|
|
||||||
axis=1)
|
|
||||||
concat(block_list, board_aug, p_aug, win_ori)
|
|
||||||
print ("Finished {} with time {}".format(n, time.time() + start))
|
|
||||||
data_num = 0
|
|
||||||
for i in range(slots_num):
|
|
||||||
print("Block {} ".format(block_list[i].block_id) + "Size {}".format(block_list[i].store_num()))
|
|
||||||
data_num = data_num + block_list[i].store_num()
|
|
||||||
print ("Total data {}".format(data_num))
|
|
||||||
|
|
||||||
for i in range(slots_num):
|
|
||||||
block_list[i].save_and_reset(block_list[i].block_id)
|
|
@ -1,103 +0,0 @@
|
|||||||
from game import Game
|
|
||||||
from engine import GTPEngine
|
|
||||||
import re
|
|
||||||
import numpy as np
|
|
||||||
import os
|
|
||||||
from collections import deque
|
|
||||||
import utils
|
|
||||||
import argparse
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument('--result_path', type=str, default='./part1')
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
if not os.path.exists(args.result_path):
|
|
||||||
os.makedirs(args.result_path)
|
|
||||||
|
|
||||||
game = Game()
|
|
||||||
engine = GTPEngine(game_obj=game)
|
|
||||||
history = deque(maxlen=8)
|
|
||||||
for i in range(8):
|
|
||||||
history.append(game.board)
|
|
||||||
state = []
|
|
||||||
prob = []
|
|
||||||
winner = []
|
|
||||||
pattern = "[A-Z]{1}[0-9]{1}"
|
|
||||||
game.show_board()
|
|
||||||
|
|
||||||
|
|
||||||
def history2state(history, color):
|
|
||||||
state = np.zeros([1, game.size, game.size, 17])
|
|
||||||
for i in range(8):
|
|
||||||
state[0, :, :, i] = np.array(np.array(history[i]) == np.ones(game.size ** 2)).reshape(game.size, game.size)
|
|
||||||
state[0, :, :, i + 8] = np.array(np.array(history[i]) == -np.ones(game.size ** 2)).reshape(game.size, game.size)
|
|
||||||
if color == utils.BLACK:
|
|
||||||
state[0, :, :, 16] = np.ones([game.size, game.size])
|
|
||||||
if color == utils.WHITE:
|
|
||||||
state[0, :, :, 16] = np.zeros([game.size, game.size])
|
|
||||||
return state
|
|
||||||
|
|
||||||
|
|
||||||
num = 0
|
|
||||||
game_num = 0
|
|
||||||
black_pass = False
|
|
||||||
white_pass = False
|
|
||||||
while True:
|
|
||||||
print("Start game {}".format(game_num))
|
|
||||||
while not (black_pass and white_pass) and num < game.size ** 2 * 2:
|
|
||||||
if num % 2 == 0:
|
|
||||||
color = utils.BLACK
|
|
||||||
new_state = history2state(history, color)
|
|
||||||
state.append(new_state)
|
|
||||||
result = engine.run_cmd(str(num) + " genmove BLACK")
|
|
||||||
num += 1
|
|
||||||
match = re.search(pattern, result)
|
|
||||||
if match is not None:
|
|
||||||
print(match.group())
|
|
||||||
else:
|
|
||||||
print("pass")
|
|
||||||
if re.search("pass", result) is not None:
|
|
||||||
black_pass = True
|
|
||||||
else:
|
|
||||||
black_pass = False
|
|
||||||
else:
|
|
||||||
color = utils.WHITE
|
|
||||||
new_state = history2state(history, color)
|
|
||||||
state.append(new_state)
|
|
||||||
result = engine.run_cmd(str(num) + " genmove WHITE")
|
|
||||||
num += 1
|
|
||||||
match = re.search(pattern, result)
|
|
||||||
if match is not None:
|
|
||||||
print(match.group())
|
|
||||||
else:
|
|
||||||
print("pass")
|
|
||||||
if re.search("pass", result) is not None:
|
|
||||||
white_pass = True
|
|
||||||
else:
|
|
||||||
white_pass = False
|
|
||||||
game.show_board()
|
|
||||||
prob.append(np.array(game.prob).reshape(-1, game.size ** 2 + 1))
|
|
||||||
print("Finished")
|
|
||||||
print("\n")
|
|
||||||
score = game.game_engine.executor_get_score(game.board)
|
|
||||||
if score > 0:
|
|
||||||
winner = utils.BLACK
|
|
||||||
else:
|
|
||||||
winner = utils.WHITE
|
|
||||||
state = np.concatenate(state, axis=0)
|
|
||||||
prob = np.concatenate(prob, axis=0)
|
|
||||||
winner = np.ones([num, 1]) * winner
|
|
||||||
assert state.shape[0] == prob.shape[0]
|
|
||||||
assert state.shape[0] == winner.shape[0]
|
|
||||||
np.savez(args.result_path + "/game" + str(game_num), state=state, prob=prob, winner=winner)
|
|
||||||
state = []
|
|
||||||
prob = []
|
|
||||||
winner = []
|
|
||||||
num = 0
|
|
||||||
black_pass = False
|
|
||||||
white_pass = False
|
|
||||||
engine.run_cmd(str(num) + " clear_board")
|
|
||||||
history.clear()
|
|
||||||
for _ in range(8):
|
|
||||||
history.append(game.board)
|
|
||||||
game_num += 1
|
|
@ -40,28 +40,27 @@ class MCTSNode(object):
|
|||||||
|
|
||||||
|
|
||||||
class UCTNode(MCTSNode):
|
class UCTNode(MCTSNode):
|
||||||
def __init__(self, parent, action, state, action_num, prior, debug=False, inverse=False):
|
def __init__(self, parent, action, state, action_num, prior, mcts, inverse=False):
|
||||||
super(UCTNode, self).__init__(parent, action, state, action_num, prior, inverse)
|
super(UCTNode, self).__init__(parent, action, state, action_num, prior, inverse)
|
||||||
self.Q = np.zeros([action_num])
|
self.Q = np.zeros([action_num])
|
||||||
self.W = np.zeros([action_num])
|
self.W = np.zeros([action_num])
|
||||||
self.N = np.zeros([action_num])
|
self.N = np.zeros([action_num])
|
||||||
self.ucb = self.Q + c_puct * self.prior * math.sqrt(np.sum(self.N)) / (self.N + 1)
|
self.ucb = self.Q + c_puct * self.prior * math.sqrt(np.sum(self.N)) / (self.N + 1)
|
||||||
self.mask = None
|
self.mask = None
|
||||||
self.debug=debug
|
|
||||||
self.elapse_time = 0
|
|
||||||
|
|
||||||
def clear_elapse_time(self):
|
|
||||||
self.elapse_time = 0
|
self.elapse_time = 0
|
||||||
|
self.mcts = mcts
|
||||||
|
|
||||||
def selection(self, simulator):
|
def selection(self, simulator):
|
||||||
head = time.time()
|
head = time.time()
|
||||||
self.valid_mask(simulator)
|
self.valid_mask(simulator)
|
||||||
self.elapse_time += time.time() - head
|
self.mcts.valid_mask_time += time.time() - head
|
||||||
action = np.argmax(self.ucb)
|
action = np.argmax(self.ucb)
|
||||||
if action in self.children.keys():
|
if action in self.children.keys():
|
||||||
|
self.mcts.state_selection_time += time.time() - head
|
||||||
return self.children[action].selection(simulator)
|
return self.children[action].selection(simulator)
|
||||||
else:
|
else:
|
||||||
self.children[action] = ActionNode(self, action)
|
self.children[action] = ActionNode(self, action, mcts=self.mcts)
|
||||||
|
self.mcts.state_selection_time += time.time() - head
|
||||||
return self.children[action].selection(simulator)
|
return self.children[action].selection(simulator)
|
||||||
|
|
||||||
def backpropagation(self, action):
|
def backpropagation(self, action):
|
||||||
@ -100,7 +99,7 @@ class TSNode(MCTSNode):
|
|||||||
|
|
||||||
|
|
||||||
class ActionNode(object):
|
class ActionNode(object):
|
||||||
def __init__(self, parent, action):
|
def __init__(self, parent, action, mcts):
|
||||||
self.parent = parent
|
self.parent = parent
|
||||||
self.action = action
|
self.action = action
|
||||||
self.children = {}
|
self.children = {}
|
||||||
@ -108,12 +107,18 @@ class ActionNode(object):
|
|||||||
self.origin_state = None
|
self.origin_state = None
|
||||||
self.state_type = None
|
self.state_type = None
|
||||||
self.reward = 0
|
self.reward = 0
|
||||||
|
self.mcts = mcts
|
||||||
|
|
||||||
def type_conversion_to_tuple(self):
|
def type_conversion_to_tuple(self):
|
||||||
|
t0 = time.time()
|
||||||
if isinstance(self.next_state, np.ndarray):
|
if isinstance(self.next_state, np.ndarray):
|
||||||
self.next_state = self.next_state.tolist()
|
self.next_state = self.next_state.tolist()
|
||||||
|
t1 = time.time()
|
||||||
if isinstance(self.next_state, list):
|
if isinstance(self.next_state, list):
|
||||||
self.next_state = list2tuple(self.next_state)
|
self.next_state = list2tuple(self.next_state)
|
||||||
|
t2 = time.time()
|
||||||
|
self.mcts.ndarray2list_time += t1 - t0
|
||||||
|
self.mcts.list2tuple_time += t2 - t1
|
||||||
|
|
||||||
def type_conversion_to_origin(self):
|
def type_conversion_to_origin(self):
|
||||||
if isinstance(self.state_type, np.ndarray):
|
if isinstance(self.state_type, np.ndarray):
|
||||||
@ -122,23 +127,28 @@ class ActionNode(object):
|
|||||||
self.next_state = tuple2list(self.next_state)
|
self.next_state = tuple2list(self.next_state)
|
||||||
|
|
||||||
def selection(self, simulator):
|
def selection(self, simulator):
|
||||||
|
head = time.time()
|
||||||
self.next_state, self.reward = simulator.simulate_step_forward(self.parent.state, self.action)
|
self.next_state, self.reward = simulator.simulate_step_forward(self.parent.state, self.action)
|
||||||
|
self.mcts.simulate_sf_time += time.time() - head
|
||||||
self.origin_state = self.next_state
|
self.origin_state = self.next_state
|
||||||
self.state_type = type(self.next_state)
|
self.state_type = type(self.next_state)
|
||||||
self.type_conversion_to_tuple()
|
self.type_conversion_to_tuple()
|
||||||
if self.next_state is not None:
|
if self.next_state is not None:
|
||||||
if self.next_state in self.children.keys():
|
if self.next_state in self.children.keys():
|
||||||
|
self.mcts.action_selection_time += time.time() - head
|
||||||
return self.children[self.next_state].selection(simulator)
|
return self.children[self.next_state].selection(simulator)
|
||||||
else:
|
else:
|
||||||
|
self.mcts.action_selection_time += time.time() - head
|
||||||
return self.parent, self.action
|
return self.parent, self.action
|
||||||
else:
|
else:
|
||||||
|
self.mcts.action_selection_time += time.time() - head
|
||||||
return self.parent, self.action
|
return self.parent, self.action
|
||||||
|
|
||||||
def expansion(self, evaluator, action_num):
|
def expansion(self, evaluator, action_num):
|
||||||
if self.next_state is not None:
|
if self.next_state is not None:
|
||||||
prior, value = evaluator(self.next_state)
|
prior, value = evaluator(self.next_state)
|
||||||
self.children[self.next_state] = UCTNode(self, self.action, self.origin_state, action_num, prior,
|
self.children[self.next_state] = UCTNode(self, self.action, self.origin_state, action_num, prior,
|
||||||
self.parent.inverse)
|
mcts=self.mcts, inverse=self.parent.inverse)
|
||||||
return value
|
return value
|
||||||
else:
|
else:
|
||||||
return 0.
|
return 0.
|
||||||
@ -160,11 +170,23 @@ class MCTS(object):
|
|||||||
if method == "":
|
if method == "":
|
||||||
self.root = root
|
self.root = root
|
||||||
if method == "UCT":
|
if method == "UCT":
|
||||||
self.root = UCTNode(None, None, root, action_num, prior, self.debug, inverse=inverse)
|
self.root = UCTNode(None, None, root, action_num, prior, mcts=self, inverse=inverse)
|
||||||
if method == "TS":
|
if method == "TS":
|
||||||
self.root = TSNode(None, None, root, action_num, prior, inverse=inverse)
|
self.root = TSNode(None, None, root, action_num, prior, inverse=inverse)
|
||||||
self.inverse = inverse
|
self.inverse = inverse
|
||||||
|
|
||||||
|
# time spend on each step
|
||||||
|
self.selection_time = 0
|
||||||
|
self.expansion_time = 0
|
||||||
|
self.backpropagation_time = 0
|
||||||
|
self.action_selection_time = 0
|
||||||
|
self.state_selection_time = 0
|
||||||
|
self.simulate_sf_time = 0
|
||||||
|
self.valid_mask_time = 0
|
||||||
|
self.ndarray2list_time = 0
|
||||||
|
self.list2tuple_time = 0
|
||||||
|
self.check = 0
|
||||||
|
|
||||||
def search(self, max_step=None, max_time=None):
|
def search(self, max_step=None, max_time=None):
|
||||||
step = 0
|
step = 0
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
@ -175,23 +197,25 @@ class MCTS(object):
|
|||||||
if max_step is None and max_time is None:
|
if max_step is None and max_time is None:
|
||||||
raise ValueError("Need a stop criteria!")
|
raise ValueError("Need a stop criteria!")
|
||||||
|
|
||||||
selection_time = 0
|
|
||||||
expansion_time = 0
|
|
||||||
backprop_time = 0
|
|
||||||
self.root.clear_elapse_time()
|
|
||||||
while step < max_step and time.time() - start_time < max_step:
|
while step < max_step and time.time() - start_time < max_step:
|
||||||
sel_time, exp_time, back_time = self._expand()
|
sel_time, exp_time, back_time = self._expand()
|
||||||
selection_time += sel_time
|
self.selection_time += sel_time
|
||||||
expansion_time += exp_time
|
self.expansion_time += exp_time
|
||||||
backprop_time += back_time
|
self.backpropagation_time += back_time
|
||||||
step += 1
|
step += 1
|
||||||
if (self.debug):
|
if (self.debug):
|
||||||
file = open("debug.txt", "a")
|
file = open("mcts_profiling.txt", "a")
|
||||||
file.write("[" + str(self.role) + "]"
|
file.write("[" + str(self.role) + "]"
|
||||||
+ " selection : " + str(selection_time) + "\t"
|
+ " sel " + '%.3f' % self.selection_time + " "
|
||||||
+ " validmask : " + str(self.root.elapse_time) + "\t"
|
+ " sel_sta " + '%.3f' % self.state_selection_time + " "
|
||||||
+ " expansion : " + str(expansion_time) + "\t"
|
+ " valid " + '%.3f' % self.valid_mask_time + " "
|
||||||
+ " backprop : " + str(backprop_time) + "\t"
|
+ " sel_act " + '%.3f' % self.action_selection_time + " "
|
||||||
|
+ " array2list " + '%.4f' % self.ndarray2list_time + " "
|
||||||
|
+ " check " + str(self.check) + " "
|
||||||
|
+ " list2tuple " + '%.4f' % self.list2tuple_time + " \t"
|
||||||
|
+ " forward " + '%.3f' % self.simulate_sf_time + " "
|
||||||
|
+ " exp " + '%.3f' % self.expansion_time + " "
|
||||||
|
+ " bak " + '%.3f' % self.backpropagation_time + " "
|
||||||
+ "\n")
|
+ "\n")
|
||||||
file.close()
|
file.close()
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user