Fix game config bug and add profiling functions to MCTS
This commit is contained in:
parent
2d9aa32758
commit
f0074aa7ca
@ -198,5 +198,4 @@ class GTPEngine():
|
||||
|
||||
|
||||
# Script entry point: the guard must compare against "__main__" (the value
# Python assigns to __name__ when a file is executed directly); the previous
# literal "main" never matched, so this smoke test could not run.
if __name__ == "__main__":
    game = Game()
    engine = GTPEngine(game_obj=game)
    print ("test engine.py")
|
||||
|
@ -26,7 +26,7 @@ class Game:
|
||||
TODO : Maybe merge with the engine class in future,
|
||||
currently leave it untouched for interacting with Go UI.
|
||||
'''
|
||||
def __init__(self, name="reversi", role="unknown", debug=False, checkpoint_path=None):
|
||||
def __init__(self, name=None, role=None, debug=False, checkpoint_path=None):
|
||||
self.name = name
|
||||
self.role = role
|
||||
self.debug = debug
|
||||
@ -119,10 +119,7 @@ class Game:
|
||||
sys.stdout.flush()
|
||||
|
||||
if __name__ == "__main__":
|
||||
g = Game("go")
|
||||
print(g.board)
|
||||
g.clear()
|
||||
g.think_play_move(1)
|
||||
print("test game.py")
|
||||
#file = open("debug.txt", "a")
|
||||
#file.write("mcts check\n")
|
||||
#file.close()
|
||||
|
@ -60,13 +60,14 @@ if __name__ == '__main__':
|
||||
black_role_name = 'black' + str(args.id)
|
||||
white_role_name = 'white' + str(args.id)
|
||||
|
||||
game_name = 'go'
|
||||
agent_v0 = subprocess.Popen(
|
||||
['python', '-u', 'player.py', '--role=' + black_role_name,
|
||||
['python', '-u', 'player.py', '--game=' + game_name, '--role=' + black_role_name,
|
||||
'--checkpoint_path=' + str(args.black_weight_path), '--debug=' + str(args.debug)],
|
||||
stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
|
||||
|
||||
agent_v1 = subprocess.Popen(
|
||||
['python', '-u', 'player.py', '--role=' + white_role_name,
|
||||
['python', '-u', 'player.py', '--game=' + game_name, '--role=' + white_role_name,
|
||||
'--checkpoint_path=' + str(args.black_weight_path), '--debug=' + str(args.debug)],
|
||||
stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
|
||||
|
||||
@ -102,13 +103,13 @@ if __name__ == '__main__':
|
||||
pass_flag = [False, False]
|
||||
print("Start game {}".format(game_num))
|
||||
# end the game if both players chose to pass, or after too many turns
|
||||
while not (pass_flag[0] and pass_flag[1]) and num < size["reversi"] ** 2 * 2:
|
||||
while not (pass_flag[0] and pass_flag[1]) and num < size[game_name] ** 2 * 2:
|
||||
turn = num % 2
|
||||
board = player[turn].run_cmd(str(num) + ' show_board')
|
||||
board = eval(board[board.index('['):board.index(']') + 1])
|
||||
for i in range(size["reversi"]):
|
||||
for j in range(size["reversi"]):
|
||||
print show[board[i * size["reversi"] + j]] + " ",
|
||||
for i in range(size[game_name]):
|
||||
for j in range(size[game_name]):
|
||||
print show[board[i * size[game_name] + j]] + " ",
|
||||
print "\n",
|
||||
data.boards.append(board)
|
||||
start_time = time.time()
|
||||
|
@ -26,6 +26,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument("--checkpoint_path", type=str, default=None)
|
||||
parser.add_argument("--role", type=str, default="unknown")
|
||||
parser.add_argument("--debug", type=str, default=False)
|
||||
parser.add_argument("--game", type=str, default=False)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.checkpoint_path == 'None':
|
||||
@ -33,7 +34,7 @@ if __name__ == '__main__':
|
||||
debug = False
|
||||
if args.debug == "True":
|
||||
debug = True
|
||||
game = Game(role=args.role, checkpoint_path=args.checkpoint_path, debug=debug)
|
||||
game = Game(name=args.game, role=args.role, checkpoint_path=args.checkpoint_path, debug=debug)
|
||||
engine = GTPEngine(game_obj=game, name='tianshou', version=0)
|
||||
|
||||
daemon = Pyro4.Daemon() # make a Pyro daemon
|
||||
|
@ -1,123 +0,0 @@
|
||||
import os
|
||||
import numpy as np
|
||||
import time
|
||||
|
||||
size = 9
|
||||
path = "/raid/tongzheng/tianshou/AlphaGo/data/part1/"
|
||||
save_path = "/raid/tongzheng/tianshou/AlphaGo/data/"
|
||||
name = os.listdir(path)
|
||||
print(len(name))
|
||||
batch_size = 128
|
||||
batch_num = 512
|
||||
|
||||
block_size = batch_size * batch_num
|
||||
slots_num = 16
|
||||
|
||||
|
||||
class block(object):
    """In-memory buffer of training samples that is flushed to an ``.npz`` file.

    Three parallel lists are kept: ``boards``, ``wins`` and ``ps``. Every
    call to :meth:`concat` appends one entry to each list, and the block is
    considered full once the lists hold ``block_size`` entries.

    Relies on the module-level names ``size`` (board edge length) and
    ``save_path`` (output directory prefix).
    """

    def __init__(self, block_size, block_id):
        """Create an empty block.

        :param block_size: number of stored entries at which the block is full.
        :param block_id: identifier used in the name of the saved file.
        """
        self.boards = []
        self.wins = []
        self.ps = []
        self.block_size = block_size
        self.block_id = block_id

    def concat(self, board, p, win):
        """Append one entry after reshaping it to the stored layout.

        :param board: array reshaped to ``(-1, size, size, 17)``.
        :param p: array reshaped to ``(-1, size ** 2 + 1)``.
        :param win: array reshaped to ``(-1, 1)``.
        """
        self.boards.append(board.reshape(-1, size, size, 17))
        self.wins.append(win.reshape(-1, 1))
        self.ps.append(p.reshape(-1, size ** 2 + 1))

    def isfull(self):
        """Return True when the number of stored entries equals ``block_size``."""
        assert len(self.boards) == len(self.wins)
        assert len(self.boards) == len(self.ps)
        return len(self.boards) == self.block_size

    def save_and_reset(self, block_id):
        """Write the buffered entries to disk, then empty the block.

        The data is saved with ``np.savez`` under
        ``save_path + "block" + str(self.block_id)`` using the keys
        ``boards``, ``wins`` and ``ps``. An empty block is not written:
        ``np.concatenate`` raises ``ValueError`` on an empty list, which
        previously crashed the final flush whenever a slot held no data.

        :param block_id: identifier the block takes on after the reset.
        """
        if self.boards:
            boards = np.concatenate(self.boards, axis=0)
            wins = np.concatenate(self.wins, axis=0)
            ps = np.concatenate(self.ps, axis=0)
            print ("Block {}, Boards shape {}, Wins Shape {}, Ps Shape {}".format(
                self.block_id, boards.shape[0], wins.shape[0], ps.shape[0]))
            np.savez(save_path + "block" + str(self.block_id), boards=boards, wins=wins, ps=ps)
        # Reset regardless of whether anything was written.
        self.boards = []
        self.wins = []
        self.ps = []
        self.block_id = block_id

    def store_num(self):
        """Return the number of entries currently buffered."""
        assert len(self.boards) == len(self.wins)
        assert len(self.boards) == len(self.ps)
        return len(self.boards)
|
||||
|
||||
|
||||
def concat(block_list, board, win, p):
    """Add one sample to a randomly chosen block, flushing that block if it fills up.

    A slot index is drawn with ``np.random.randint(slots_num)``. When the
    chosen block reports full it is saved and reset with the module-level
    ``index`` as its new id, and ``index`` is then incremented.

    NOTE: the arguments are forwarded positionally to ``block.concat``,
    whose parameters are ordered ``(board, p, win)``. The callers in this
    file pass ``(board, p, win)`` as well, so the parameter names ``win``
    and ``p`` here are swapped relative to the values they actually carry.
    """
    global index
    slot = block_list[np.random.randint(slots_num)]
    slot.concat(board, win, p)
    if not slot.isfull():
        return
    slot.save_and_reset(index)
    index += 1
|
||||
|
||||
|
||||
block_list = []
|
||||
for index in range(slots_num):
|
||||
block_list.append(block(block_size, index))
|
||||
index = index + 1
|
||||
for n in name:
|
||||
data = np.load(path + n)
|
||||
board = data["boards"]
|
||||
win = data["win"]
|
||||
p = data["p"]
|
||||
print("Start {}".format(n))
|
||||
print("Shape {}".format(board.shape[0]))
|
||||
start = -time.time()
|
||||
for i in range(board.shape[0]):
|
||||
board_ori = board[i].reshape(-1, size, size, 17)
|
||||
win_ori = win[i].reshape(-1, 1)
|
||||
p_ori = p[i].reshape(-1, size ** 2 + 1)
|
||||
concat(block_list, board_ori, p_ori, win_ori)
|
||||
|
||||
for t in range(1, 4):
|
||||
board_aug = np.rot90(board_ori, t, (1, 2))
|
||||
p_aug = np.concatenate(
|
||||
[np.rot90(p_ori[:, :-1].reshape(-1, size, size), t, (1, 2)).reshape(-1, size ** 2), p_ori[:, -1].reshape(-1, 1)],
|
||||
axis=1)
|
||||
concat(block_list, board_aug, p_aug, win_ori)
|
||||
|
||||
board_aug = board_ori[:, ::-1]
|
||||
p_aug = np.concatenate(
|
||||
[p_ori[:, :-1].reshape(-1, size, size)[:, ::-1].reshape(-1, size ** 2), p_ori[:, -1].reshape(-1, 1)],
|
||||
axis=1)
|
||||
concat(block_list, board_aug, p_aug, win_ori)
|
||||
|
||||
board_aug = board_ori[:, :, ::-1]
|
||||
p_aug = np.concatenate(
|
||||
[p_ori[:, :-1].reshape(-1, size, size)[:, :, ::-1].reshape(-1, size ** 2), p_ori[:, -1].reshape(-1, 1)],
|
||||
axis=1)
|
||||
concat(block_list, board_aug, p_aug, win_ori)
|
||||
|
||||
board_aug = np.rot90(board_ori[:, ::-1], 1, (1, 2))
|
||||
p_aug = np.concatenate(
|
||||
[np.rot90(p_ori[:, :-1].reshape(-1, size, size)[:, ::-1], 1, (1, 2)).reshape(-1, size ** 2),
|
||||
p_ori[:, -1].reshape(-1, 1)],
|
||||
axis=1)
|
||||
concat(block_list, board_aug, p_aug, win_ori)
|
||||
|
||||
board_aug = np.rot90(board_ori[:, :, ::-1], 1, (1, 2))
|
||||
p_aug = np.concatenate(
|
||||
[np.rot90(p_ori[:, :-1].reshape(-1, size, size)[:, :, ::-1], 1, (1, 2)).reshape(-1, size ** 2),
|
||||
p_ori[:, -1].reshape(-1, 1)],
|
||||
axis=1)
|
||||
concat(block_list, board_aug, p_aug, win_ori)
|
||||
print ("Finished {} with time {}".format(n, time.time() + start))
|
||||
data_num = 0
|
||||
for i in range(slots_num):
|
||||
print("Block {} ".format(block_list[i].block_id) + "Size {}".format(block_list[i].store_num()))
|
||||
data_num = data_num + block_list[i].store_num()
|
||||
print ("Total data {}".format(data_num))
|
||||
|
||||
for i in range(slots_num):
|
||||
block_list[i].save_and_reset(block_list[i].block_id)
|
@ -1,103 +0,0 @@
|
||||
from game import Game
|
||||
from engine import GTPEngine
|
||||
import re
|
||||
import numpy as np
|
||||
import os
|
||||
from collections import deque
|
||||
import utils
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--result_path', type=str, default='./part1')
|
||||
args = parser.parse_args()
|
||||
|
||||
if not os.path.exists(args.result_path):
|
||||
os.makedirs(args.result_path)
|
||||
|
||||
game = Game()
|
||||
engine = GTPEngine(game_obj=game)
|
||||
history = deque(maxlen=8)
|
||||
for i in range(8):
|
||||
history.append(game.board)
|
||||
state = []
|
||||
prob = []
|
||||
winner = []
|
||||
pattern = "[A-Z]{1}[0-9]{1}"
|
||||
game.show_board()
|
||||
|
||||
|
||||
def history2state(history, color):
    """Build a ``(1, game.size, game.size, 17)`` state array from board history.

    For each of the first eight entries of ``history``, plane ``i`` marks the
    positions whose value equals 1 and plane ``i + 8`` marks the positions
    whose value equals -1. Plane 16 is filled with ones when ``color`` is
    ``utils.BLACK`` and with zeros when it is ``utils.WHITE``.

    :param history: indexable collection of at least eight boards, each
        holding ``game.size ** 2`` values.
    :param color: the colour used to fill plane 16.
    :return: the assembled state array.
    """
    edge = game.size
    cells = edge ** 2
    state = np.zeros([1, edge, edge, 17])
    for idx in range(8):
        snapshot = np.array(history[idx])
        ones_mask = np.array(snapshot == np.ones(cells))
        minus_ones_mask = np.array(snapshot == -np.ones(cells))
        state[0, :, :, idx] = ones_mask.reshape(edge, edge)
        state[0, :, :, idx + 8] = minus_ones_mask.reshape(edge, edge)
    if color == utils.BLACK:
        state[0, :, :, 16] = np.ones([edge, edge])
    if color == utils.WHITE:
        state[0, :, :, 16] = np.zeros([edge, edge])
    return state
|
||||
|
||||
|
||||
num = 0
|
||||
game_num = 0
|
||||
black_pass = False
|
||||
white_pass = False
|
||||
while True:
|
||||
print("Start game {}".format(game_num))
|
||||
while not (black_pass and white_pass) and num < game.size ** 2 * 2:
|
||||
if num % 2 == 0:
|
||||
color = utils.BLACK
|
||||
new_state = history2state(history, color)
|
||||
state.append(new_state)
|
||||
result = engine.run_cmd(str(num) + " genmove BLACK")
|
||||
num += 1
|
||||
match = re.search(pattern, result)
|
||||
if match is not None:
|
||||
print(match.group())
|
||||
else:
|
||||
print("pass")
|
||||
if re.search("pass", result) is not None:
|
||||
black_pass = True
|
||||
else:
|
||||
black_pass = False
|
||||
else:
|
||||
color = utils.WHITE
|
||||
new_state = history2state(history, color)
|
||||
state.append(new_state)
|
||||
result = engine.run_cmd(str(num) + " genmove WHITE")
|
||||
num += 1
|
||||
match = re.search(pattern, result)
|
||||
if match is not None:
|
||||
print(match.group())
|
||||
else:
|
||||
print("pass")
|
||||
if re.search("pass", result) is not None:
|
||||
white_pass = True
|
||||
else:
|
||||
white_pass = False
|
||||
game.show_board()
|
||||
prob.append(np.array(game.prob).reshape(-1, game.size ** 2 + 1))
|
||||
print("Finished")
|
||||
print("\n")
|
||||
score = game.game_engine.executor_get_score(game.board)
|
||||
if score > 0:
|
||||
winner = utils.BLACK
|
||||
else:
|
||||
winner = utils.WHITE
|
||||
state = np.concatenate(state, axis=0)
|
||||
prob = np.concatenate(prob, axis=0)
|
||||
winner = np.ones([num, 1]) * winner
|
||||
assert state.shape[0] == prob.shape[0]
|
||||
assert state.shape[0] == winner.shape[0]
|
||||
np.savez(args.result_path + "/game" + str(game_num), state=state, prob=prob, winner=winner)
|
||||
state = []
|
||||
prob = []
|
||||
winner = []
|
||||
num = 0
|
||||
black_pass = False
|
||||
white_pass = False
|
||||
engine.run_cmd(str(num) + " clear_board")
|
||||
history.clear()
|
||||
for _ in range(8):
|
||||
history.append(game.board)
|
||||
game_num += 1
|
@ -40,28 +40,27 @@ class MCTSNode(object):
|
||||
|
||||
|
||||
class UCTNode(MCTSNode):
|
||||
def __init__(self, parent, action, state, action_num, prior, debug=False, inverse=False):
|
||||
def __init__(self, parent, action, state, action_num, prior, mcts, inverse=False):
|
||||
super(UCTNode, self).__init__(parent, action, state, action_num, prior, inverse)
|
||||
self.Q = np.zeros([action_num])
|
||||
self.W = np.zeros([action_num])
|
||||
self.N = np.zeros([action_num])
|
||||
self.ucb = self.Q + c_puct * self.prior * math.sqrt(np.sum(self.N)) / (self.N + 1)
|
||||
self.mask = None
|
||||
self.debug=debug
|
||||
self.elapse_time = 0
|
||||
|
||||
def clear_elapse_time(self):
|
||||
self.elapse_time = 0
|
||||
self.mcts = mcts
|
||||
|
||||
def selection(self, simulator):
|
||||
head = time.time()
|
||||
self.valid_mask(simulator)
|
||||
self.elapse_time += time.time() - head
|
||||
self.mcts.valid_mask_time += time.time() - head
|
||||
action = np.argmax(self.ucb)
|
||||
if action in self.children.keys():
|
||||
self.mcts.state_selection_time += time.time() - head
|
||||
return self.children[action].selection(simulator)
|
||||
else:
|
||||
self.children[action] = ActionNode(self, action)
|
||||
self.children[action] = ActionNode(self, action, mcts=self.mcts)
|
||||
self.mcts.state_selection_time += time.time() - head
|
||||
return self.children[action].selection(simulator)
|
||||
|
||||
def backpropagation(self, action):
|
||||
@ -100,7 +99,7 @@ class TSNode(MCTSNode):
|
||||
|
||||
|
||||
class ActionNode(object):
|
||||
def __init__(self, parent, action):
|
||||
def __init__(self, parent, action, mcts):
|
||||
self.parent = parent
|
||||
self.action = action
|
||||
self.children = {}
|
||||
@ -108,12 +107,18 @@ class ActionNode(object):
|
||||
self.origin_state = None
|
||||
self.state_type = None
|
||||
self.reward = 0
|
||||
self.mcts = mcts
|
||||
|
||||
def type_conversion_to_tuple(self):
|
||||
t0 = time.time()
|
||||
if isinstance(self.next_state, np.ndarray):
|
||||
self.next_state = self.next_state.tolist()
|
||||
t1 = time.time()
|
||||
if isinstance(self.next_state, list):
|
||||
self.next_state = list2tuple(self.next_state)
|
||||
t2 = time.time()
|
||||
self.mcts.ndarray2list_time += t1 - t0
|
||||
self.mcts.list2tuple_time += t2 - t1
|
||||
|
||||
def type_conversion_to_origin(self):
|
||||
if isinstance(self.state_type, np.ndarray):
|
||||
@ -122,23 +127,28 @@ class ActionNode(object):
|
||||
self.next_state = tuple2list(self.next_state)
|
||||
|
||||
def selection(self, simulator):
|
||||
head = time.time()
|
||||
self.next_state, self.reward = simulator.simulate_step_forward(self.parent.state, self.action)
|
||||
self.mcts.simulate_sf_time += time.time() - head
|
||||
self.origin_state = self.next_state
|
||||
self.state_type = type(self.next_state)
|
||||
self.type_conversion_to_tuple()
|
||||
if self.next_state is not None:
|
||||
if self.next_state in self.children.keys():
|
||||
self.mcts.action_selection_time += time.time() - head
|
||||
return self.children[self.next_state].selection(simulator)
|
||||
else:
|
||||
self.mcts.action_selection_time += time.time() - head
|
||||
return self.parent, self.action
|
||||
else:
|
||||
self.mcts.action_selection_time += time.time() - head
|
||||
return self.parent, self.action
|
||||
|
||||
def expansion(self, evaluator, action_num):
|
||||
if self.next_state is not None:
|
||||
prior, value = evaluator(self.next_state)
|
||||
self.children[self.next_state] = UCTNode(self, self.action, self.origin_state, action_num, prior,
|
||||
self.parent.inverse)
|
||||
mcts=self.mcts, inverse=self.parent.inverse)
|
||||
return value
|
||||
else:
|
||||
return 0.
|
||||
@ -160,11 +170,23 @@ class MCTS(object):
|
||||
if method == "":
|
||||
self.root = root
|
||||
if method == "UCT":
|
||||
self.root = UCTNode(None, None, root, action_num, prior, self.debug, inverse=inverse)
|
||||
self.root = UCTNode(None, None, root, action_num, prior, mcts=self, inverse=inverse)
|
||||
if method == "TS":
|
||||
self.root = TSNode(None, None, root, action_num, prior, inverse=inverse)
|
||||
self.inverse = inverse
|
||||
|
||||
# time spent on each step
|
||||
self.selection_time = 0
|
||||
self.expansion_time = 0
|
||||
self.backpropagation_time = 0
|
||||
self.action_selection_time = 0
|
||||
self.state_selection_time = 0
|
||||
self.simulate_sf_time = 0
|
||||
self.valid_mask_time = 0
|
||||
self.ndarray2list_time = 0
|
||||
self.list2tuple_time = 0
|
||||
self.check = 0
|
||||
|
||||
def search(self, max_step=None, max_time=None):
|
||||
step = 0
|
||||
start_time = time.time()
|
||||
@ -175,23 +197,25 @@ class MCTS(object):
|
||||
if max_step is None and max_time is None:
|
||||
raise ValueError("Need a stop criteria!")
|
||||
|
||||
selection_time = 0
|
||||
expansion_time = 0
|
||||
backprop_time = 0
|
||||
self.root.clear_elapse_time()
|
||||
while step < max_step and time.time() - start_time < max_step:
|
||||
sel_time, exp_time, back_time = self._expand()
|
||||
selection_time += sel_time
|
||||
expansion_time += exp_time
|
||||
backprop_time += back_time
|
||||
self.selection_time += sel_time
|
||||
self.expansion_time += exp_time
|
||||
self.backpropagation_time += back_time
|
||||
step += 1
|
||||
if (self.debug):
|
||||
file = open("debug.txt", "a")
|
||||
file = open("mcts_profiling.txt", "a")
|
||||
file.write("[" + str(self.role) + "]"
|
||||
+ " selection : " + str(selection_time) + "\t"
|
||||
+ " validmask : " + str(self.root.elapse_time) + "\t"
|
||||
+ " expansion : " + str(expansion_time) + "\t"
|
||||
+ " backprop : " + str(backprop_time) + "\t"
|
||||
+ " sel " + '%.3f' % self.selection_time + " "
|
||||
+ " sel_sta " + '%.3f' % self.state_selection_time + " "
|
||||
+ " valid " + '%.3f' % self.valid_mask_time + " "
|
||||
+ " sel_act " + '%.3f' % self.action_selection_time + " "
|
||||
+ " array2list " + '%.4f' % self.ndarray2list_time + " "
|
||||
+ " check " + str(self.check) + " "
|
||||
+ " list2tuple " + '%.4f' % self.list2tuple_time + " \t"
|
||||
+ " forward " + '%.3f' % self.simulate_sf_time + " "
|
||||
+ " exp " + '%.3f' % self.expansion_time + " "
|
||||
+ " bak " + '%.3f' % self.backpropagation_time + " "
|
||||
+ "\n")
|
||||
file.close()
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user