Merge branch 'master' of https://github.com/sproblvem/tianshou
This commit is contained in:
commit
88648f0c4b
4
AlphaGo/.gitignore
vendored
4
AlphaGo/.gitignore
vendored
@ -1,3 +1,5 @@
|
||||
data
|
||||
checkpoints
|
||||
checkpoints_origin
|
||||
random
|
||||
*.log
|
||||
*.txt
|
||||
|
||||
@ -1,84 +0,0 @@
|
||||
import os
|
||||
import threading
|
||||
import numpy as np
|
||||
|
||||
size = 9
|
||||
path = "/home/yama/leela-zero/data/npz-files/"
|
||||
name = os.listdir(path)
|
||||
print(len(name))
|
||||
thread_num = 17
|
||||
batch_num = len(name) // thread_num
|
||||
|
||||
def integrate(name, index):
|
||||
boards = np.zeros([0, size, size, 17])
|
||||
wins = np.zeros([0, 1])
|
||||
ps = np.zeros([0, size**2 + 1])
|
||||
for n in name:
|
||||
data = np.load(path + n)
|
||||
board = data["state"]
|
||||
win = data["winner"]
|
||||
p = data["prob"]
|
||||
# board = np.zeros([0, size, size, 17])
|
||||
# win = np.zeros([0, 1])
|
||||
# p = np.zeros([0, size**2 + 1])
|
||||
# for i in range(data["boards"].shape[3]):
|
||||
# board = np.concatenate([board, data["boards"][:,:,:,i].reshape(-1, size, size, 17)], axis=0)
|
||||
# win = np.concatenate([win, data["win"][:,i].reshape(-1, 1)], axis=0)
|
||||
# p = np.concatenate([p, data["p"][:,i].reshape(-1, size**2 + 1)], axis=0)
|
||||
boards = np.concatenate([boards, board], axis=0)
|
||||
wins = np.concatenate([wins, win], axis=0)
|
||||
ps = np.concatenate([ps, p], axis=0)
|
||||
# print("Finish " + n)
|
||||
print ("Integration {} Finished!".format(index))
|
||||
board_ori = boards
|
||||
win_ori = wins
|
||||
p_ori = ps
|
||||
for i in range(1, 3):
|
||||
board = np.rot90(board_ori, i, (1, 2))
|
||||
p = np.concatenate(
|
||||
[np.rot90(p_ori[:, :-1].reshape(-1, size, size), i, (1, 2)).reshape(-1, size**2), p_ori[:, -1].reshape(-1, 1)],
|
||||
axis=1)
|
||||
boards = np.concatenate([boards, board], axis=0)
|
||||
wins = np.concatenate([wins, win_ori], axis=0)
|
||||
ps = np.concatenate([ps, p], axis=0)
|
||||
|
||||
board = board_ori[:, ::-1]
|
||||
p = np.concatenate([p_ori[:, :-1].reshape(-1, size, size)[:, ::-1].reshape(-1, size**2), p_ori[:, -1].reshape(-1, 1)],
|
||||
axis=1)
|
||||
boards = np.concatenate([boards, board], axis=0)
|
||||
wins = np.concatenate([wins, win_ori], axis=0)
|
||||
ps = np.concatenate([ps, p], axis=0)
|
||||
|
||||
board = board_ori[:, :, ::-1]
|
||||
p = np.concatenate([p_ori[:, :-1].reshape(-1, size, size)[:, :, ::-1].reshape(-1, size**2), p_ori[:, -1].reshape(-1, 1)],
|
||||
axis=1)
|
||||
boards = np.concatenate([boards, board], axis=0)
|
||||
wins = np.concatenate([wins, win_ori], axis=0)
|
||||
ps = np.concatenate([ps, p], axis=0)
|
||||
|
||||
board = board_ori[:, ::-1]
|
||||
p = np.concatenate(
|
||||
[np.rot90(p_ori[:, :-1].reshape(-1, size, size)[:, ::-1], 1, (1, 2)).reshape(-1, size**2), p_ori[:, -1].reshape(-1, 1)],
|
||||
axis=1)
|
||||
boards = np.concatenate([boards, np.rot90(board, 1, (1, 2))], axis=0)
|
||||
wins = np.concatenate([wins, win_ori], axis=0)
|
||||
ps = np.concatenate([ps, p], axis=0)
|
||||
|
||||
board = board_ori[:, :, ::-1]
|
||||
p = np.concatenate(
|
||||
[np.rot90(p_ori[:, :-1].reshape(-1, size, size)[:, :, ::-1], 1, (1, 2)).reshape(-1, size**2),
|
||||
p_ori[:, -1].reshape(-1, 1)],
|
||||
axis=1)
|
||||
boards = np.concatenate([boards, np.rot90(board, 1, (1, 2))], axis=0)
|
||||
wins = np.concatenate([wins, win_ori], axis=0)
|
||||
ps = np.concatenate([ps, p], axis=0)
|
||||
|
||||
np.savez("/home/tongzheng/data/data-" + str(index), state=boards, winner=wins, prob=ps)
|
||||
print ("Thread {} has finished.".format(index))
|
||||
thread_list = list()
|
||||
for i in range(thread_num):
|
||||
thread_list.append(threading.Thread(target=integrate, args=(name[batch_num * i:batch_num * (i + 1)], i,)))
|
||||
for thread in thread_list:
|
||||
thread.start()
|
||||
for thread in thread_list:
|
||||
thread.join()
|
||||
29
AlphaGo/data_statistic.py
Normal file
29
AlphaGo/data_statistic.py
Normal file
@ -0,0 +1,29 @@
|
||||
import os
|
||||
import cPickle
|
||||
|
||||
class Data(object):
|
||||
def __init__(self):
|
||||
self.boards = []
|
||||
self.probs = []
|
||||
self.winner = 0
|
||||
|
||||
def file_to_training_data(file_name):
|
||||
with open(file_name, 'rb') as file:
|
||||
try:
|
||||
file.seek(0)
|
||||
data = cPickle.load(file)
|
||||
return data.winner
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return 0
|
||||
|
||||
if __name__ == "__main__":
|
||||
win_count = [0, 0, 0]
|
||||
file_list = os.listdir("./data")
|
||||
#print file_list
|
||||
for file in file_list:
|
||||
win_count[file_to_training_data("./data/" + file)] += 1
|
||||
print "Total play : " + str(len(file_list))
|
||||
print "Black wins : " + str(win_count[1])
|
||||
print "White wins : " + str(win_count[-1])
|
||||
|
||||
@ -6,13 +6,13 @@
|
||||
#
|
||||
|
||||
from game import Game
|
||||
import copy
|
||||
import numpy as np
|
||||
import utils
|
||||
|
||||
|
||||
class GTPEngine():
|
||||
def __init__(self, **kwargs):
|
||||
self.size = 9
|
||||
self.komi = 6.5
|
||||
try:
|
||||
self._game = kwargs['game_obj']
|
||||
self._game.clear()
|
||||
@ -141,11 +141,9 @@ class GTPEngine():
|
||||
self.disconnect = True
|
||||
return None, True
|
||||
|
||||
def cmd_boardsize(self, args, **kwargs):
|
||||
if args.isdigit():
|
||||
size = int(args)
|
||||
self.size = size
|
||||
self._game.set_size(size)
|
||||
def cmd_boardsize(self, board_size, **kwargs):
|
||||
if board_size.isdigit():
|
||||
self._game.set_size(int(board_size))
|
||||
return None, True
|
||||
else:
|
||||
return 'non digit size', False
|
||||
@ -154,11 +152,9 @@ class GTPEngine():
|
||||
self._game.clear()
|
||||
return None, True
|
||||
|
||||
def cmd_komi(self, args, **kwargs):
|
||||
def cmd_komi(self, komi, **kwargs):
|
||||
try:
|
||||
komi = float(args)
|
||||
self.komi = komi
|
||||
self._game.set_komi(komi)
|
||||
self._game.set_komi(float(komi))
|
||||
return None, True
|
||||
except ValueError:
|
||||
raise ValueError("syntax error")
|
||||
@ -186,12 +182,14 @@ class GTPEngine():
|
||||
return self._game.game_engine.executor_get_score(self._game.board), True
|
||||
|
||||
def cmd_show_board(self, args, **kwargs):
|
||||
return self._game.board, True
|
||||
board = copy.deepcopy(self._game.board)
|
||||
if isinstance(board, np.ndarray):
|
||||
board = board.flatten().tolist()
|
||||
return board, True
|
||||
|
||||
def cmd_get_prob(self, args, **kwargs):
|
||||
return self._game.prob, True
|
||||
|
||||
|
||||
if __name__ == "main":
|
||||
game = Game()
|
||||
engine = GTPEngine(game_obj=game)
|
||||
print ("test engine.py")
|
||||
|
||||
@ -17,39 +17,50 @@ from tianshou.core.mcts.mcts import MCTS
|
||||
|
||||
import go
|
||||
import reversi
|
||||
import time
|
||||
|
||||
class Game:
|
||||
'''
|
||||
Load the real game and trained weights.
|
||||
|
||||
TODO : Maybe merge with the engine class in future,
|
||||
|
||||
TODO : Maybe merge with the engine class in future,
|
||||
currently leave it untouched for interacting with Go UI.
|
||||
'''
|
||||
def __init__(self, name="go", checkpoint_path=None):
|
||||
def __init__(self, name=None, role=None, debug=False, checkpoint_path=None):
|
||||
self.name = name
|
||||
if role is None:
|
||||
raise ValueError("Need a role!")
|
||||
self.role = role
|
||||
self.debug = debug
|
||||
if self.name == "go":
|
||||
self.size = 9
|
||||
self.komi = 3.75
|
||||
self.board = [utils.EMPTY] * (self.size ** 2)
|
||||
self.history = []
|
||||
self.history_length = 8
|
||||
self.latest_boards = deque(maxlen=8)
|
||||
for _ in range(8):
|
||||
self.latest_boards.append(self.board)
|
||||
self.game_engine = go.Go(size=self.size, komi=self.komi)
|
||||
self.history = []
|
||||
self.history_set = set()
|
||||
self.game_engine = go.Go(size=self.size, komi=self.komi, role=self.role)
|
||||
self.board = [utils.EMPTY] * (self.size ** 2)
|
||||
elif self.name == "reversi":
|
||||
self.size = 8
|
||||
self.history_length = 1
|
||||
self.game_engine = reversi.Reversi()
|
||||
self.history = []
|
||||
self.game_engine = reversi.Reversi(size=self.size)
|
||||
self.board = self.game_engine.get_board()
|
||||
else:
|
||||
raise ValueError(name + " is an unknown game...")
|
||||
|
||||
self.evaluator = model.ResNet(self.size, self.size ** 2 + 1, history_length=self.history_length)
|
||||
self.evaluator = model.ResNet(self.size, self.size ** 2 + 1, history_length=self.history_length,
|
||||
checkpoint_path=checkpoint_path)
|
||||
self.latest_boards = deque(maxlen=self.history_length)
|
||||
for _ in range(self.history_length):
|
||||
self.latest_boards.append(self.board)
|
||||
|
||||
def clear(self):
|
||||
self.board = [utils.EMPTY] * (self.size ** 2)
|
||||
self.history = []
|
||||
if self.name == "go":
|
||||
self.board = [utils.EMPTY] * (self.size ** 2)
|
||||
self.history = []
|
||||
if self.name == "reversi":
|
||||
self.board = self.game_engine.get_board()
|
||||
for _ in range(self.history_length):
|
||||
self.latest_boards.append(self.board)
|
||||
|
||||
@ -61,8 +72,16 @@ class Game:
|
||||
self.komi = k
|
||||
|
||||
def think(self, latest_boards, color):
|
||||
mcts = MCTS(self.game_engine, self.evaluator, [latest_boards, color], self.size ** 2 + 1, inverse=True)
|
||||
mcts.search(max_step=20)
|
||||
mcts = MCTS(self.game_engine, self.evaluator, [latest_boards, color],
|
||||
self.size ** 2 + 1, role=self.role, debug=self.debug, inverse=True)
|
||||
mcts.search(max_step=100)
|
||||
if self.debug:
|
||||
file = open("mcts_debug.log", 'ab')
|
||||
np.savetxt(file, mcts.root.Q, header="\n" + self.role + " Q value : ", fmt='%.4f', newline=", ")
|
||||
np.savetxt(file, mcts.root.W, header="\n" + self.role + " W value : ", fmt='%.4f', newline=", ")
|
||||
np.savetxt(file, mcts.root.N, header="\n" + self.role + " N value : ", fmt="%d", newline=", ")
|
||||
np.savetxt(file, mcts.root.prior, header="\n" + self.role + " prior : ", fmt='%.4f', newline=", ")
|
||||
file.close()
|
||||
temp = 1
|
||||
prob = mcts.root.N ** temp / np.sum(mcts.root.N ** temp)
|
||||
choice = np.random.choice(self.size ** 2 + 1, 1, p=prob).tolist()[0]
|
||||
@ -76,11 +95,10 @@ class Game:
|
||||
# this function can be called directly to play the opponent's move
|
||||
if vertex == utils.PASS:
|
||||
return True
|
||||
# TODO this implementation is not very elegant
|
||||
if self.name == "go":
|
||||
if self.name == "reversi":
|
||||
res = self.game_engine.executor_do_move(self.history, self.latest_boards, self.board, color, vertex)
|
||||
elif self.name == "reversi":
|
||||
res = self.game_engine.executor_do_move(self.board, color, vertex)
|
||||
if self.name == "go":
|
||||
res = self.game_engine.executor_do_move(self.history, self.history_set, self.latest_boards, self.board, color, vertex)
|
||||
return res
|
||||
|
||||
def think_play_move(self, color):
|
||||
@ -106,14 +124,12 @@ class Game:
|
||||
if row[i] < 10:
|
||||
print(' ', end='')
|
||||
for j in range(self.size):
|
||||
print(self.status2symbol(self.board[self._flatten((j + 1, i + 1))]), end=' ')
|
||||
print(self.status2symbol(self.board[self.game_engine._flatten((j + 1, i + 1))]), end=' ')
|
||||
print('')
|
||||
sys.stdout.flush()
|
||||
|
||||
if __name__ == "__main__":
|
||||
g = Game()
|
||||
g.show_board()
|
||||
g.think_play_move(1)
|
||||
#file = open("debug.txt", "a")
|
||||
#file.write("mcts check\n")
|
||||
#file.close()
|
||||
game = Game(name="reversi", role="black", checkpoint_path=None)
|
||||
game.debug = True
|
||||
game.think_play_move(utils.BLACK)
|
||||
|
||||
|
||||
@ -3,7 +3,7 @@ import utils
|
||||
import copy
|
||||
import numpy as np
|
||||
from collections import deque
|
||||
|
||||
import time
|
||||
'''
|
||||
Settings of the Go game.
|
||||
|
||||
@ -18,6 +18,7 @@ class Go:
|
||||
def __init__(self, **kwargs):
|
||||
self.size = kwargs['size']
|
||||
self.komi = kwargs['komi']
|
||||
self.role = kwargs['role']
|
||||
|
||||
def _flatten(self, vertex):
|
||||
x, y = vertex
|
||||
@ -96,12 +97,12 @@ class Go:
|
||||
for b in group:
|
||||
current_board[self._flatten(b)] = utils.EMPTY
|
||||
|
||||
def _check_global_isomorphous(self, history_boards, current_board, color, vertex):
|
||||
def _check_global_isomorphous(self, history_boards_set, current_board, color, vertex):
|
||||
repeat = False
|
||||
next_board = copy.copy(current_board)
|
||||
next_board = copy.deepcopy(current_board)
|
||||
next_board[self._flatten(vertex)] = color
|
||||
self._process_board(next_board, color, vertex)
|
||||
if next_board in history_boards:
|
||||
if hash(tuple(next_board)) in history_boards_set:
|
||||
repeat = True
|
||||
return repeat
|
||||
|
||||
@ -157,7 +158,7 @@ class Go:
|
||||
vertex = self._deflatten(action)
|
||||
return vertex
|
||||
|
||||
def _rule_check(self, history_boards, current_board, color, vertex):
|
||||
def _rule_check(self, history_boards_set, current_board, color, vertex):
|
||||
### in board
|
||||
if not self._in_board(vertex):
|
||||
return False
|
||||
@ -171,7 +172,7 @@ class Go:
|
||||
return False
|
||||
|
||||
### forbid global isomorphous
|
||||
if self._check_global_isomorphous(history_boards, current_board, color, vertex):
|
||||
if self._check_global_isomorphous(history_boards_set, current_board, color, vertex):
|
||||
return False
|
||||
|
||||
return True
|
||||
@ -211,23 +212,28 @@ class Go:
|
||||
|
||||
def simulate_step_forward(self, state, action):
|
||||
# initialize the simulate_board from state
|
||||
history_boards, color = state
|
||||
history_boards, color = copy.deepcopy(state)
|
||||
if history_boards[-1] == history_boards[-2] and action is utils.PASS:
|
||||
return None, 2 * (float(self.executor_get_score(history_boards[-1]) > 0)-0.5) * color
|
||||
return None, 2 * (float(self.simple_executor_get_score(history_boards[-1]) > 0)-0.5) * color
|
||||
else:
|
||||
vertex = self._action2vertex(action)
|
||||
new_board = self._do_move(copy.copy(history_boards[-1]), color, vertex)
|
||||
new_board = self._do_move(copy.deepcopy(history_boards[-1]), color, vertex)
|
||||
history_boards.append(new_board)
|
||||
new_color = -color
|
||||
return [history_boards, new_color], 0
|
||||
|
||||
def executor_do_move(self, history, latest_boards, current_board, color, vertex):
|
||||
if not self._rule_check(history, current_board, color, vertex):
|
||||
def simulate_hashable_conversion(self, state):
|
||||
# since go is MDP, we only need the last board for hashing
|
||||
return tuple(state[0][-1])
|
||||
|
||||
def executor_do_move(self, history, history_set, latest_boards, current_board, color, vertex):
|
||||
if not self._rule_check(history_set, current_board, color, vertex):
|
||||
return False
|
||||
current_board[self._flatten(vertex)] = color
|
||||
self._process_board(current_board, color, vertex)
|
||||
history.append(copy.copy(current_board))
|
||||
latest_boards.append(copy.copy(current_board))
|
||||
history.append(copy.deepcopy(current_board))
|
||||
latest_boards.append(copy.deepcopy(current_board))
|
||||
history_set.add(hash(tuple(current_board)))
|
||||
return True
|
||||
|
||||
def _find_empty(self, current_board):
|
||||
@ -284,10 +290,7 @@ class Go:
|
||||
return utils.WHITE
|
||||
|
||||
def executor_get_score(self, current_board):
|
||||
'''
|
||||
is_unknown_estimation: whether use nearby stone to predict the unknown
|
||||
return score from BLACK perspective.
|
||||
'''
|
||||
#return score from BLACK perspective.
|
||||
_board = copy.deepcopy(current_board)
|
||||
while utils.EMPTY in _board:
|
||||
vertex = self._find_empty(_board)
|
||||
@ -309,7 +312,46 @@ class Go:
|
||||
|
||||
return score
|
||||
|
||||
|
||||
def simple_executor_get_score(self, current_board):
|
||||
'''
|
||||
can only be used for the empty group only have one single stone
|
||||
return score from BLACK perspective.
|
||||
'''
|
||||
score = 0
|
||||
for idx, color in enumerate(current_board):
|
||||
if color == utils.EMPTY:
|
||||
neighbors = self._neighbor(self._deflatten(idx))
|
||||
color = current_board[self._flatten(neighbors[0])]
|
||||
if color == utils.BLACK:
|
||||
score += 1
|
||||
elif color == utils.WHITE:
|
||||
score -= 1
|
||||
score -= self.komi
|
||||
return score
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
go = Go(size=9, komi=3.75, role = utils.BLACK)
|
||||
endgame = [
|
||||
1, 0, 1, 0, 1, 1, -1, 0, -1,
|
||||
1, 1, 1, 1, 1, 1, -1, -1, -1,
|
||||
0, 1, 1, 1, 1, -1, 0, -1, 0,
|
||||
1, 1, 1, 1, 1, -1, -1, -1, -1,
|
||||
1, -1, 1, -1, 1, 1, -1, -1, -1,
|
||||
-1, -1, -1, -1, -1, 1, -1, 0, -1,
|
||||
1, 1, 1, -1, -1, -1, -1, -1, -1,
|
||||
1, 0, 1, 1, 1, 1, 1, -1, 0,
|
||||
1, 1, 0, 1, -1, -1, -1, -1, -1
|
||||
]
|
||||
time0 = time.time()
|
||||
score = go.executor_get_score(endgame)
|
||||
time1 = time.time()
|
||||
print(score, time1 - time0)
|
||||
score = go.new_executor_get_score(endgame)
|
||||
time2 = time.time()
|
||||
print(score, time2 - time1)
|
||||
'''
|
||||
### do unit test for Go class
|
||||
pure_test = [
|
||||
0, 1, 0, 1, 0, 1, 0, 0, 0,
|
||||
@ -348,3 +390,4 @@ if __name__ == "__main__":
|
||||
for i in range(7):
|
||||
print (go._is_eye(opponent_test, utils.BLACK, ot_qry[i]))
|
||||
print("Test of eye surrend by opponents\n")
|
||||
'''
|
||||
|
||||
@ -80,7 +80,7 @@ class Data(object):
|
||||
|
||||
|
||||
class ResNet(object):
|
||||
def __init__(self, board_size, action_num, history_length=1, residual_block_num=20, checkpoint_path=None):
|
||||
def __init__(self, board_size, action_num, history_length=1, residual_block_num=10, checkpoint_path=None):
|
||||
"""
|
||||
the resnet model
|
||||
|
||||
@ -101,7 +101,7 @@ class ResNet(object):
|
||||
self._build_network(residual_block_num, self.checkpoint_path)
|
||||
|
||||
# training hyper-parameters:
|
||||
self.window_length = 7000
|
||||
self.window_length = 3
|
||||
self.save_freq = 5000
|
||||
self.training_data = {'states': deque(maxlen=self.window_length), 'probs': deque(maxlen=self.window_length),
|
||||
'winner': deque(maxlen=self.window_length), 'length': deque(maxlen=self.window_length)}
|
||||
@ -124,6 +124,7 @@ class ResNet(object):
|
||||
h = residual_block(h, self.is_training)
|
||||
self.v = value_head(h, self.is_training)
|
||||
self.p = policy_head(h, self.is_training, self.action_num)
|
||||
self.prob = tf.nn.softmax(self.p)
|
||||
self.value_loss = tf.reduce_mean(tf.square(self.z - self.v))
|
||||
self.policy_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=self.pi, logits=self.p))
|
||||
|
||||
@ -152,13 +153,16 @@ class ResNet(object):
|
||||
:param color: a string, indicate which one to play
|
||||
:return: a list of tensor, the predicted value and policy given the history and color
|
||||
"""
|
||||
# Note : maybe we can use it for isolating test of MCTS
|
||||
#prob = [1.0 / self.action_num] * self.action_num
|
||||
#return [prob, np.random.uniform(-1, 1)]
|
||||
history, color = state
|
||||
if len(history) != self.history_length:
|
||||
raise ValueError(
|
||||
'The length of history cannot meet the need of the model, given {}, need {}'.format(len(history),
|
||||
self.history_length))
|
||||
state = self._history2state(history, color)
|
||||
return self.sess.run([self.p, self.v], feed_dict={self.x: state, self.is_training: False})
|
||||
eval_state = self._history2state(history, color)
|
||||
return self.sess.run([self.prob, self.v], feed_dict={self.x: eval_state, self.is_training: False})
|
||||
|
||||
def _history2state(self, history, color):
|
||||
"""
|
||||
@ -170,10 +174,10 @@ class ResNet(object):
|
||||
"""
|
||||
state = np.zeros([1, self.board_size, self.board_size, 2 * self.history_length + 1])
|
||||
for i in range(self.history_length):
|
||||
state[0, :, :, i] = np.array(np.array(history[i]) == np.ones(self.board_size ** 2)).reshape(self.board_size,
|
||||
state[0, :, :, i] = np.array(np.array(history[i]).flatten() == np.ones(self.board_size ** 2)).reshape(self.board_size,
|
||||
self.board_size)
|
||||
state[0, :, :, i + self.history_length] = np.array(
|
||||
np.array(history[i]) == -np.ones(self.board_size ** 2)).reshape(self.board_size, self.board_size)
|
||||
np.array(history[i]).flatten() == -np.ones(self.board_size ** 2)).reshape(self.board_size, self.board_size)
|
||||
# TODO: need a config to specify the BLACK and WHITE
|
||||
if color == +1:
|
||||
state[0, :, :, 2 * self.history_length] = np.ones([self.board_size, self.board_size])
|
||||
@ -223,8 +227,8 @@ class ResNet(object):
|
||||
else:
|
||||
start_time = time.time()
|
||||
for i in range(batch_size):
|
||||
priority = self.training_data['length'] / sum(self.training_data['length'])
|
||||
game_num = np.random.choice(self.window_length, 1, p=priority)
|
||||
priority = np.array(self.training_data['length']) / (0.0 + np.sum(np.array(self.training_data['length'])))
|
||||
game_num = np.random.choice(self.window_length, 1, p=priority)[0]
|
||||
state_num = np.random.randint(self.training_data['length'][game_num])
|
||||
rotate_times = np.random.randint(4)
|
||||
reflect_times = np.random.randint(2)
|
||||
@ -232,12 +236,10 @@ class ResNet(object):
|
||||
training_data['states'].append(
|
||||
self._preprocession(self.training_data['states'][game_num][state_num], reflect_times,
|
||||
reflect_orientation, rotate_times))
|
||||
training_data['probs'].append(
|
||||
self._preprocession(self.training_data['probs'][game_num][state_num], reflect_times,
|
||||
reflect_orientation, rotate_times))
|
||||
training_data['winner'].append(
|
||||
self._preprocession(self.training_data['winner'][game_num][state_num], reflect_times,
|
||||
reflect_orientation, rotate_times))
|
||||
training_data['probs'].append(np.concatenate(
|
||||
[self._preprocession(self.training_data['probs'][game_num][state_num][:-1].reshape(self.board_size, self.board_size, 1), reflect_times,
|
||||
reflect_orientation, rotate_times).reshape(1, self.board_size**2), self.training_data['probs'][game_num][state_num][-1].reshape(1,1)], axis=1))
|
||||
training_data['winner'].append(self.training_data['winner'][game_num][state_num].reshape(1, 1))
|
||||
value_loss, policy_loss, reg, _ = self.sess.run(
|
||||
[self.value_loss, self.policy_loss, self.reg, self.train_op],
|
||||
feed_dict={self.x: np.concatenate(training_data['states'], axis=0),
|
||||
@ -300,9 +302,9 @@ class ResNet(object):
|
||||
:return:
|
||||
"""
|
||||
|
||||
new_board = copy.copy(board)
|
||||
new_board = copy.deepcopy(board)
|
||||
if new_board.ndim == 3:
|
||||
np.expand_dims(new_board, axis=0)
|
||||
new_board = np.expand_dims(new_board, axis=0)
|
||||
|
||||
new_board = self._board_reflection(new_board, reflect_times, reflect_orientation)
|
||||
new_board = self._board_rotation(new_board, rotate_times)
|
||||
@ -330,7 +332,7 @@ class ResNet(object):
|
||||
:param orientation: an integer, which orientation to reflect
|
||||
:return:
|
||||
"""
|
||||
new_board = copy.copy(board)
|
||||
new_board = copy.deepcopy(board)
|
||||
for _ in range(times):
|
||||
if orientation == 0:
|
||||
new_board = new_board[:, ::-1]
|
||||
|
||||
@ -1,225 +0,0 @@
|
||||
import os
|
||||
import time
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import time
|
||||
import tensorflow as tf
|
||||
import tensorflow.contrib.layers as layers
|
||||
|
||||
import multi_gpu
|
||||
import time
|
||||
import copy
|
||||
|
||||
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
|
||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
||||
|
||||
|
||||
def residual_block(input, is_training):
|
||||
normalizer_params = {'is_training': is_training,
|
||||
'updates_collections': tf.GraphKeys.UPDATE_OPS}
|
||||
h = layers.conv2d(input, 256, kernel_size=3, stride=1, activation_fn=tf.nn.relu,
|
||||
normalizer_fn=layers.batch_norm, normalizer_params=normalizer_params,
|
||||
weights_regularizer=layers.l2_regularizer(1e-4))
|
||||
h = layers.conv2d(h, 256, kernel_size=3, stride=1, activation_fn=tf.identity,
|
||||
normalizer_fn=layers.batch_norm, normalizer_params=normalizer_params,
|
||||
weights_regularizer=layers.l2_regularizer(1e-4))
|
||||
h = h + input
|
||||
return tf.nn.relu(h)
|
||||
|
||||
|
||||
def policy_heads(input, is_training):
|
||||
normalizer_params = {'is_training': is_training,
|
||||
'updates_collections': tf.GraphKeys.UPDATE_OPS}
|
||||
h = layers.conv2d(input, 2, kernel_size=1, stride=1, activation_fn=tf.nn.relu,
|
||||
normalizer_fn=layers.batch_norm, normalizer_params=normalizer_params,
|
||||
weights_regularizer=layers.l2_regularizer(1e-4))
|
||||
h = layers.flatten(h)
|
||||
h = layers.fully_connected(h, 82, activation_fn=tf.identity, weights_regularizer=layers.l2_regularizer(1e-4))
|
||||
return h
|
||||
|
||||
|
||||
def value_heads(input, is_training):
|
||||
normalizer_params = {'is_training': is_training,
|
||||
'updates_collections': tf.GraphKeys.UPDATE_OPS}
|
||||
h = layers.conv2d(input, 2, kernel_size=1, stride=1, activation_fn=tf.nn.relu,
|
||||
normalizer_fn=layers.batch_norm, normalizer_params=normalizer_params,
|
||||
weights_regularizer=layers.l2_regularizer(1e-4))
|
||||
h = layers.flatten(h)
|
||||
h = layers.fully_connected(h, 256, activation_fn=tf.nn.relu, weights_regularizer=layers.l2_regularizer(1e-4))
|
||||
h = layers.fully_connected(h, 1, activation_fn=tf.nn.tanh, weights_regularizer=layers.l2_regularizer(1e-4))
|
||||
return h
|
||||
|
||||
|
||||
class Network(object):
|
||||
def __init__(self):
|
||||
self.x = tf.placeholder(tf.float32, shape=[None, 9, 9, 17])
|
||||
self.is_training = tf.placeholder(tf.bool, shape=[])
|
||||
self.z = tf.placeholder(tf.float32, shape=[None, 1])
|
||||
self.pi = tf.placeholder(tf.float32, shape=[None, 82])
|
||||
self.build_network()
|
||||
|
||||
def build_network(self):
|
||||
h = layers.conv2d(self.x, 256, kernel_size=3, stride=1, activation_fn=tf.nn.relu,
|
||||
normalizer_fn=layers.batch_norm,
|
||||
normalizer_params={'is_training': self.is_training,
|
||||
'updates_collections': tf.GraphKeys.UPDATE_OPS},
|
||||
weights_regularizer=layers.l2_regularizer(1e-4))
|
||||
for i in range(4):
|
||||
h = residual_block(h, self.is_training)
|
||||
self.v = value_heads(h, self.is_training)
|
||||
self.p = policy_heads(h, self.is_training)
|
||||
# loss = tf.reduce_mean(tf.square(z-v)) - tf.multiply(pi, tf.log(tf.clip_by_value(tf.nn.softmax(p), 1e-8, tf.reduce_max(tf.nn.softmax(p)))))
|
||||
self.value_loss = tf.reduce_mean(tf.square(self.z - self.v))
|
||||
self.policy_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=self.pi, logits=self.p))
|
||||
|
||||
self.reg = tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
|
||||
self.total_loss = self.value_loss + self.policy_loss + self.reg
|
||||
# train_op = tf.train.MomentumOptimizer(1e-4, momentum=0.9, use_nesterov=True).minimize(total_loss)
|
||||
self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
|
||||
with tf.control_dependencies(self.update_ops):
|
||||
self.train_op = tf.train.RMSPropOptimizer(1e-4).minimize(self.total_loss)
|
||||
self.var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
|
||||
self.saver = tf.train.Saver(max_to_keep=10, var_list=self.var_list)
|
||||
self.sess = multi_gpu.create_session()
|
||||
|
||||
def train(self):
|
||||
data_path = "./training_data/"
|
||||
data_name = os.listdir(data_path)
|
||||
epochs = 100
|
||||
batch_size = 128
|
||||
|
||||
result_path = "./checkpoints_origin/"
|
||||
with multi_gpu.create_session() as sess:
|
||||
sess.run(tf.global_variables_initializer())
|
||||
ckpt_file = tf.train.latest_checkpoint(result_path)
|
||||
if ckpt_file is not None:
|
||||
print('Restoring model from {}...'.format(ckpt_file))
|
||||
self.saver.restore(sess, ckpt_file)
|
||||
for epoch in range(epochs):
|
||||
for name in data_name:
|
||||
data = np.load(data_path + name)
|
||||
boards = data["boards"]
|
||||
wins = data["wins"]
|
||||
ps = data["ps"]
|
||||
print (boards.shape)
|
||||
print (wins.shape)
|
||||
print (ps.shape)
|
||||
batch_num = boards.shape[0] // batch_size
|
||||
index = np.arange(boards.shape[0])
|
||||
np.random.shuffle(index)
|
||||
value_losses = []
|
||||
policy_losses = []
|
||||
regs = []
|
||||
time_train = -time.time()
|
||||
for iter in range(batch_num):
|
||||
lv, lp, r, value, prob, _ = sess.run(
|
||||
[self.value_loss, self.policy_loss, self.reg, self.v, tf.nn.softmax(self.p), self.train_op],
|
||||
feed_dict={self.x: boards[
|
||||
index[iter * batch_size:(iter + 1) * batch_size]],
|
||||
self.z: wins[index[
|
||||
iter * batch_size:(iter + 1) * batch_size]],
|
||||
self.pi: ps[index[
|
||||
iter * batch_size:(iter + 1) * batch_size]],
|
||||
self.is_training: True})
|
||||
value_losses.append(lv)
|
||||
policy_losses.append(lp)
|
||||
regs.append(r)
|
||||
if iter % 1 == 0:
|
||||
print(
|
||||
"Epoch: {}, Part {}, Iteration: {}, Time: {}, Value Loss: {}, Policy Loss: {}, Reg: {}".format(
|
||||
epoch, name, iter, time.time() + time_train, np.mean(np.array(value_losses)),
|
||||
np.mean(np.array(policy_losses)), np.mean(np.array(regs))))
|
||||
time_train = -time.time()
|
||||
value_losses = []
|
||||
policy_losses = []
|
||||
regs = []
|
||||
if iter % 20 == 0:
|
||||
save_path = "Epoch{}.Part{}.Iteration{}.ckpt".format(epoch, name, iter)
|
||||
self.saver.save(sess, result_path + save_path)
|
||||
del data, boards, wins, ps
|
||||
|
||||
|
||||
# def forward(call_number):
|
||||
# # checkpoint_path = "/home/yama/rl/tianshou/AlphaGo/checkpoints"
|
||||
# checkpoint_path = "/home/jialian/stuGo/tianshou/stuGo/checkpoints/"
|
||||
# board_file = np.genfromtxt("/home/jialian/stuGo/tianshou/leela-zero/src/mcts_nn_files/board_" + call_number,
|
||||
# dtype='str');
|
||||
# human_board = np.zeros((17, 19, 19))
|
||||
#
|
||||
# # TODO : is it ok to ignore the last channel?
|
||||
# for i in range(17):
|
||||
# human_board[i] = np.array(list(board_file[i])).reshape(19, 19)
|
||||
# # print("============================")
|
||||
# # print("human board sum : " + str(np.sum(human_board[-1])))
|
||||
# # print("============================")
|
||||
# # print(human_board)
|
||||
# # print("============================")
|
||||
# # rint(human_board)
|
||||
# feed_board = human_board.transpose(1, 2, 0).reshape(1, 19, 19, 17)
|
||||
# # print(feed_board[:,:,:,-1])
|
||||
# # print(feed_board.shape)
|
||||
#
|
||||
# # npz_board = np.load("/home/yama/rl/tianshou/AlphaGo/data/7f83928932f64a79bc1efdea268698ae.npz")
|
||||
# # print(npz_board["boards"].shape)
|
||||
# # feed_board = npz_board["boards"][10].reshape(-1, 19, 19, 17)
|
||||
# ##print(feed_board)
|
||||
# # show_board = feed_board[0].transpose(2, 0, 1)
|
||||
# # print("board shape : ", show_board.shape)
|
||||
# # print(show_board)
|
||||
#
|
||||
# itflag = False
|
||||
# with multi_gpu.create_session() as sess:
|
||||
# sess.run(tf.global_variables_initializer())
|
||||
# ckpt_file = tf.train.latest_checkpoint(checkpoint_path)
|
||||
# if ckpt_file is not None:
|
||||
# # print('Restoring model from {}...'.format(ckpt_file))
|
||||
# saver.restore(sess, ckpt_file)
|
||||
# else:
|
||||
# raise ValueError("No model loaded")
|
||||
# res = sess.run([tf.nn.softmax(p), v], feed_dict={x: feed_board, is_training: itflag})
|
||||
# # res = sess.run([tf.nn.softmax(p),v], feed_dict={x:fix_board["boards"][300].reshape(-1, 19, 19, 17), is_training:False})
|
||||
# # res = sess.run([tf.nn.softmax(p),v], feed_dict={x:fix_board["boards"][50].reshape(-1, 19, 19, 17), is_training:True})
|
||||
# # print(np.argmax(res[0]))
|
||||
# np.savetxt(sys.stdout, res[0][0], fmt="%.6f", newline=" ")
|
||||
# np.savetxt(sys.stdout, res[1][0], fmt="%.6f", newline=" ")
|
||||
# pv_file = "/home/jialian/stuGotianshou/leela-zero/src/mcts_nn_files/policy_value"
|
||||
# np.savetxt(pv_file, np.concatenate((res[0][0], res[1][0])), fmt="%.6f", newline=" ")
|
||||
# # np.savetxt(pv_file, res[1][0], fmt="%.6f", newline=" ")
|
||||
# return res
|
||||
|
||||
def forward(self, checkpoint_path):
|
||||
# checkpoint_path = "/home/tongzheng/tianshou/AlphaGo/checkpoints/"
|
||||
# sess = multi_gpu.create_session()
|
||||
# sess.run(tf.global_variables_initializer())
|
||||
if checkpoint_path is None:
|
||||
self.sess.run(tf.global_variables_initializer())
|
||||
else:
|
||||
ckpt_file = tf.train.latest_checkpoint(checkpoint_path)
|
||||
if ckpt_file is not None:
|
||||
# print('Restoring model from {}...'.format(ckpt_file))
|
||||
self.saver.restore(self.sess, ckpt_file)
|
||||
# print('Successfully loaded')
|
||||
else:
|
||||
raise ValueError("No model loaded")
|
||||
# prior, value = sess.run([tf.nn.softmax(p), v], feed_dict={x: state, is_training: False})
|
||||
# return prior, value
|
||||
return self.sess
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# state = np.random.randint(0, 1, [256, 9, 9, 17])
|
||||
# net = Network()
|
||||
# net.train()
|
||||
# sess = net.forward()
|
||||
# start_time = time.time()
|
||||
# for i in range(100):
|
||||
# sess.run([tf.nn.softmax(net.p), net.v], feed_dict={net.x: state, net.is_training: False})
|
||||
# print("Step {}, use time {}".format(i, time.time() - start_time))
|
||||
# start_time = time.time()
|
||||
net0 = Network()
|
||||
sess0 = net0.forward("./checkpoints/")
|
||||
print("Loaded")
|
||||
while True:
|
||||
pass
|
||||
|
||||
143
AlphaGo/play.py
143
AlphaGo/play.py
@ -5,7 +5,15 @@ import re
|
||||
import Pyro4
|
||||
import time
|
||||
import os
|
||||
import cPickle
|
||||
import utils
|
||||
from time import gmtime, strftime
|
||||
|
||||
python_version = sys.version_info
|
||||
|
||||
if python_version < (3, 0):
|
||||
import cPickle
|
||||
else:
|
||||
import _pickle as cPickle
|
||||
|
||||
class Data(object):
|
||||
def __init__(self):
|
||||
@ -16,7 +24,6 @@ class Data(object):
|
||||
def reset(self):
|
||||
self.__init__()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
"""
|
||||
Starting two different players which load network weights to evaluate the winning ratio.
|
||||
@ -24,57 +31,90 @@ if __name__ == '__main__':
|
||||
"""
|
||||
# TODO : we should set the network path in a more configurable way.
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--result_path", type=str, default="./data/")
|
||||
parser.add_argument("--data_path", type=str, default="./data/")
|
||||
parser.add_argument("--black_weight_path", type=str, default=None)
|
||||
parser.add_argument("--white_weight_path", type=str, default=None)
|
||||
parser.add_argument("--id", type=int, default=0)
|
||||
parser.add_argument("--id", type=int, default=-1)
|
||||
parser.add_argument("--debug", type=bool, default=False)
|
||||
parser.add_argument("--game", type=str, default="go")
|
||||
args = parser.parse_args()
|
||||
|
||||
if not os.path.exists(args.result_path):
|
||||
os.mkdir(args.result_path)
|
||||
if not os.path.exists(args.data_path):
|
||||
os.mkdir(args.data_path)
|
||||
# black_weight_path = "./checkpoints"
|
||||
# white_weight_path = "./checkpoints_origin"
|
||||
if args.black_weight_path is not None and (not os.path.exists(args.black_weight_path)):
|
||||
raise ValueError("Can't not find the network weights for black player.")
|
||||
raise ValueError("Can't find the network weights for black player.")
|
||||
if args.white_weight_path is not None and (not os.path.exists(args.white_weight_path)):
|
||||
raise ValueError("Can't not find the network weights for white player.")
|
||||
raise ValueError("Can't find the network weights for white player.")
|
||||
|
||||
# kill the old server
|
||||
# kill_old_server = subprocess.Popen(['killall', 'pyro4-ns'])
|
||||
# print "kill the old pyro4 name server, the return code is : " + str(kill_old_server.wait())
|
||||
# time.sleep(1)
|
||||
|
||||
# start a name server to find the remote object
|
||||
# start_new_server = subprocess.Popen(['pyro4-ns', '&'])
|
||||
# print "Start Name Sever : " + str(start_new_server.pid) # + str(start_new_server.wait())
|
||||
# time.sleep(1)
|
||||
|
||||
# start a name server if no name server exists
|
||||
if len(os.popen('ps aux | grep pyro4-ns | grep -v grep').readlines()) == 0:
|
||||
start_new_server = subprocess.Popen(['pyro4-ns', '&'])
|
||||
print "Start Name Sever : " + str(start_new_server.pid) # + str(start_new_server.wait())
|
||||
print("Start Name Sever : " + str(start_new_server.pid)) # + str(start_new_server.wait())
|
||||
time.sleep(1)
|
||||
|
||||
# start two different player with different network weights.
|
||||
server_list = subprocess.check_output(['pyro4-nsc', 'list'])
|
||||
index = []
|
||||
if server_list is not None:
|
||||
server_list = server_list.split("\n")[3:-2]
|
||||
for s in server_list:
|
||||
id = s.split(" ")[0][5:]
|
||||
index.append(eval(id))
|
||||
index.sort()
|
||||
if args.id == -1:
|
||||
if index:
|
||||
args.id = index[-1] + 1
|
||||
else:
|
||||
args.id = 0
|
||||
else:
|
||||
if args.id in index:
|
||||
raise ValueError("Name exists in name server!")
|
||||
|
||||
black_role_name = 'black' + str(args.id)
|
||||
white_role_name = 'white' + str(args.id)
|
||||
|
||||
agent_v0 = subprocess.Popen(
|
||||
['python', '-u', 'player.py', '--role=' + black_role_name, '--checkpoint_path=' + str(args.black_weight_path)],
|
||||
black_player = subprocess.Popen(
|
||||
['python', '-u', 'player.py', '--game=' + args.game, '--role=' + black_role_name,
|
||||
'--checkpoint_path=' + str(args.black_weight_path), '--debug=' + str(args.debug)],
|
||||
stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
|
||||
bp_output = black_player.stdout.readline()
|
||||
bp_message = bp_output
|
||||
# '' means player.py failed to start, "Start requestLoop" means player.py start successfully
|
||||
while bp_output != '' and "Start requestLoop" not in bp_output:
|
||||
bp_output = black_player.stdout.readline()
|
||||
bp_message += bp_output
|
||||
print("============ " + black_role_name + " message ============" + "\n" + bp_message),
|
||||
|
||||
agent_v1 = subprocess.Popen(
|
||||
['python', '-u', 'player.py', '--role=' + white_role_name, '--checkpoint_path=' + str(args.white_weight_path)],
|
||||
white_player = subprocess.Popen(
|
||||
['python', '-u', 'player.py', '--game=' + args.game, '--role=' + white_role_name,
|
||||
'--checkpoint_path=' + str(args.white_weight_path), '--debug=' + str(args.debug)],
|
||||
stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
|
||||
wp_output = white_player.stdout.readline()
|
||||
wp_message = wp_output
|
||||
while wp_output != '' and "Start requestLoop" not in wp_output:
|
||||
wp_output = white_player.stdout.readline()
|
||||
wp_message += wp_output
|
||||
print("============ " + white_role_name + " message ============" + "\n" + wp_message),
|
||||
|
||||
server_list = ""
|
||||
while (black_role_name not in server_list) or (white_role_name not in server_list):
|
||||
server_list = subprocess.check_output(['pyro4-nsc', 'list'])
|
||||
print "Waiting for the server start..."
|
||||
if python_version < (3, 0):
|
||||
# TODO : @renyong what is the difference between those two options?
|
||||
server_list = subprocess.check_output(['pyro4-nsc', 'list'])
|
||||
else:
|
||||
server_list = subprocess.check_output(['pyro4-nsc', 'list'])
|
||||
print("Waiting for the server start...")
|
||||
time.sleep(1)
|
||||
print server_list
|
||||
print "Start black player at : " + str(agent_v0.pid)
|
||||
print "Start white player at : " + str(agent_v1.pid)
|
||||
print(server_list)
|
||||
print("Start black player at : " + str(black_player.pid))
|
||||
print("Start white player at : " + str(white_player.pid))
|
||||
|
||||
data = Data()
|
||||
player = [None] * 2
|
||||
@ -86,29 +126,31 @@ if __name__ == '__main__':
|
||||
|
||||
pattern = "[A-Z]{1}[0-9]{1}"
|
||||
space = re.compile("\s+")
|
||||
size = 9
|
||||
size = {"go":9, "reversi":8}
|
||||
show = ['.', 'X', 'O']
|
||||
|
||||
evaluate_rounds = 1
|
||||
evaluate_rounds = 100
|
||||
game_num = 0
|
||||
try:
|
||||
while True:
|
||||
#while True:
|
||||
while game_num < evaluate_rounds:
|
||||
start_time = time.time()
|
||||
num = 0
|
||||
pass_flag = [False, False]
|
||||
print("Start game {}".format(game_num))
|
||||
# end the game if both palyer chose to pass, or play too much turns
|
||||
while not (pass_flag[0] and pass_flag[1]) and num < size ** 2 * 2:
|
||||
while not (pass_flag[0] and pass_flag[1]) and num < size[args.game] ** 2 * 2:
|
||||
turn = num % 2
|
||||
board = player[turn].run_cmd(str(num) + ' show_board')
|
||||
board = eval(board[board.index('['):board.index(']') + 1])
|
||||
for i in range(size):
|
||||
for j in range(size):
|
||||
print show[board[i * size + j]] + " ",
|
||||
for i in range(size[args.game]):
|
||||
for j in range(size[args.game]):
|
||||
print show[board[i * size[args.game] + j]] + " ",
|
||||
print "\n",
|
||||
data.boards.append(board)
|
||||
move = player[turn].run_cmd(str(num) + ' genmove ' + color[turn] + '\n')
|
||||
print role[turn] + " : " + str(move),
|
||||
start_time = time.time()
|
||||
move = player[turn].run_cmd(str(num) + ' genmove ' + color[turn])[:-1]
|
||||
print("\n" + role[turn] + " : " + str(move)),
|
||||
num += 1
|
||||
match = re.search(pattern, move)
|
||||
if match is not None:
|
||||
@ -126,33 +168,26 @@ if __name__ == '__main__':
|
||||
prob = prob.replace('],', ']')
|
||||
prob = eval(prob)
|
||||
data.probs.append(prob)
|
||||
score = player[turn].run_cmd(str(num) + ' get_score')
|
||||
print "Finished : ", score.split(" ")[1]
|
||||
# TODO: generalize the player
|
||||
score = player[0].run_cmd(str(num) + ' get_score')
|
||||
print("Finished : {}".format(score.split(" ")[1]))
|
||||
if eval(score.split(" ")[1]) > 0:
|
||||
data.winner = 1
|
||||
data.winner = utils.BLACK
|
||||
if eval(score.split(" ")[1]) < 0:
|
||||
data.winner = -1
|
||||
data.winner = utils.WHITE
|
||||
player[0].run_cmd(str(num) + ' clear_board')
|
||||
player[1].run_cmd(str(num) + ' clear_board')
|
||||
file_list = os.listdir(args.result_path)
|
||||
if not file_list:
|
||||
data_num = 0
|
||||
else:
|
||||
file_list.sort(key=lambda file: os.path.getmtime(args.result_path + file) if not os.path.isdir(
|
||||
args.result_path + file) else 0)
|
||||
data_num = eval(file_list[-1][:-4]) + 1
|
||||
with open("./data/" + str(data_num) + ".pkl", "wb") as file:
|
||||
file_list = os.listdir(args.data_path)
|
||||
current_time = strftime("%Y%m%d_%H%M%S", gmtime())
|
||||
if os.path.exists(args.data_path + current_time + ".pkl"):
|
||||
time.sleep(1)
|
||||
current_time = strftime("%Y%m%d_%H%M%S", gmtime())
|
||||
with open(args.data_path + current_time + ".pkl", "wb") as file:
|
||||
picklestring = cPickle.dump(data, file)
|
||||
data.reset()
|
||||
game_num += 1
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
print(e)
|
||||
subprocess.call(["kill", "-9", str(agent_v0.pid)])
|
||||
subprocess.call(["kill", "-9", str(agent_v1.pid)])
|
||||
print "Kill all player, finish all game."
|
||||
|
||||
subprocess.call(["kill", "-9", str(agent_v0.pid)])
|
||||
subprocess.call(["kill", "-9", str(agent_v1.pid)])
|
||||
print "Kill all player, finish all game."
|
||||
subprocess.call(["kill", "-9", str(black_player.pid)])
|
||||
subprocess.call(["kill", "-9", str(white_player.pid)])
|
||||
print("Kill all player, finish all game.")
|
||||
|
||||
@ -1,8 +1,5 @@
|
||||
import argparse
|
||||
import time
|
||||
import sys
|
||||
import Pyro4
|
||||
|
||||
from game import Game
|
||||
from engine import GTPEngine
|
||||
|
||||
@ -17,28 +14,29 @@ class Player(object):
|
||||
self.engine = kwargs['engine']
|
||||
|
||||
def run_cmd(self, command):
|
||||
#return "inside the Player of player.py"
|
||||
return self.engine.run_cmd(command)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--checkpoint_path", type=str, default=None)
|
||||
parser.add_argument("--checkpoint_path", type=str, default="None")
|
||||
parser.add_argument("--role", type=str, default="unknown")
|
||||
parser.add_argument("--debug", type=str, default="False")
|
||||
parser.add_argument("--game", type=str, default="go")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.checkpoint_path == 'None':
|
||||
args.checkpoint_path = None
|
||||
game = Game(checkpoint_path=args.checkpoint_path)
|
||||
game = Game(name=args.game, role=args.role,
|
||||
checkpoint_path=args.checkpoint_path,
|
||||
debug=eval(args.debug))
|
||||
engine = GTPEngine(game_obj=game, name='tianshou', version=0)
|
||||
|
||||
daemon = Pyro4.Daemon() # make a Pyro daemon
|
||||
ns = Pyro4.locateNS() # find the name server
|
||||
player = Player(role=args.role, engine=engine)
|
||||
print "Init " + args.role + " player finished"
|
||||
print("Init " + args.role + " player finished")
|
||||
uri = daemon.register(player) # register the greeting maker as a Pyro object
|
||||
print "Start on name " + args.role
|
||||
ns.register(args.role, uri) # register the object with a name in the name server
|
||||
print "Start Request Loop " + str(uri)
|
||||
print("Start on name " + args.role)
|
||||
ns.register(args.role, uri) # register the object with a name in the name server
|
||||
print("Start requestLoop " + str(uri))
|
||||
daemon.requestLoop() # start the event loop of the server to wait for calls
|
||||
|
||||
|
||||
@ -1,123 +0,0 @@
|
||||
import os
|
||||
import numpy as np
|
||||
import time
|
||||
|
||||
size = 9
|
||||
path = "/raid/tongzheng/tianshou/AlphaGo/data/part1/"
|
||||
save_path = "/raid/tongzheng/tianshou/AlphaGo/data/"
|
||||
name = os.listdir(path)
|
||||
print(len(name))
|
||||
batch_size = 128
|
||||
batch_num = 512
|
||||
|
||||
block_size = batch_size * batch_num
|
||||
slots_num = 16
|
||||
|
||||
|
||||
class block(object):
|
||||
def __init__(self, block_size, block_id):
|
||||
self.boards = []
|
||||
self.wins = []
|
||||
self.ps = []
|
||||
self.block_size = block_size
|
||||
self.block_id = block_id
|
||||
|
||||
def concat(self, board, p, win):
|
||||
board = board.reshape(-1, size, size, 17)
|
||||
win = win.reshape(-1, 1)
|
||||
p = p.reshape(-1, size ** 2 + 1)
|
||||
self.boards.append(board)
|
||||
self.wins.append(win)
|
||||
self.ps.append(p)
|
||||
|
||||
def isfull(self):
|
||||
assert len(self.boards) == len(self.wins)
|
||||
assert len(self.boards) == len(self.ps)
|
||||
return len(self.boards) == self.block_size
|
||||
|
||||
def save_and_reset(self, block_id):
|
||||
self.boards = np.concatenate(self.boards, axis=0)
|
||||
self.wins = np.concatenate(self.wins, axis=0)
|
||||
self.ps = np.concatenate(self.ps, axis=0)
|
||||
print ("Block {}, Boards shape {}, Wins Shape {}, Ps Shape {}".format(self.block_id, self.boards.shape[0],
|
||||
self.wins.shape[0], self.ps.shape[0]))
|
||||
np.savez(save_path + "block" + str(self.block_id), boards=self.boards, wins=self.wins, ps=self.ps)
|
||||
self.boards = []
|
||||
self.wins = []
|
||||
self.ps = []
|
||||
self.block_id = block_id
|
||||
|
||||
def store_num(self):
|
||||
assert len(self.boards) == len(self.wins)
|
||||
assert len(self.boards) == len(self.ps)
|
||||
return len(self.boards)
|
||||
|
||||
|
||||
def concat(block_list, board, win, p):
|
||||
global index
|
||||
seed = np.random.randint(slots_num)
|
||||
block_list[seed].concat(board, win, p)
|
||||
if block_list[seed].isfull():
|
||||
block_list[seed].save_and_reset(index)
|
||||
index = index + 1
|
||||
|
||||
|
||||
block_list = []
|
||||
for index in range(slots_num):
|
||||
block_list.append(block(block_size, index))
|
||||
index = index + 1
|
||||
for n in name:
|
||||
data = np.load(path + n)
|
||||
board = data["boards"]
|
||||
win = data["win"]
|
||||
p = data["p"]
|
||||
print("Start {}".format(n))
|
||||
print("Shape {}".format(board.shape[0]))
|
||||
start = -time.time()
|
||||
for i in range(board.shape[0]):
|
||||
board_ori = board[i].reshape(-1, size, size, 17)
|
||||
win_ori = win[i].reshape(-1, 1)
|
||||
p_ori = p[i].reshape(-1, size ** 2 + 1)
|
||||
concat(block_list, board_ori, p_ori, win_ori)
|
||||
|
||||
for t in range(1, 4):
|
||||
board_aug = np.rot90(board_ori, t, (1, 2))
|
||||
p_aug = np.concatenate(
|
||||
[np.rot90(p_ori[:, :-1].reshape(-1, size, size), t, (1, 2)).reshape(-1, size ** 2), p_ori[:, -1].reshape(-1, 1)],
|
||||
axis=1)
|
||||
concat(block_list, board_aug, p_aug, win_ori)
|
||||
|
||||
board_aug = board_ori[:, ::-1]
|
||||
p_aug = np.concatenate(
|
||||
[p_ori[:, :-1].reshape(-1, size, size)[:, ::-1].reshape(-1, size ** 2), p_ori[:, -1].reshape(-1, 1)],
|
||||
axis=1)
|
||||
concat(block_list, board_aug, p_aug, win_ori)
|
||||
|
||||
board_aug = board_ori[:, :, ::-1]
|
||||
p_aug = np.concatenate(
|
||||
[p_ori[:, :-1].reshape(-1, size, size)[:, :, ::-1].reshape(-1, size ** 2), p_ori[:, -1].reshape(-1, 1)],
|
||||
axis=1)
|
||||
concat(block_list, board_aug, p_aug, win_ori)
|
||||
|
||||
board_aug = np.rot90(board_ori[:, ::-1], 1, (1, 2))
|
||||
p_aug = np.concatenate(
|
||||
[np.rot90(p_ori[:, :-1].reshape(-1, size, size)[:, ::-1], 1, (1, 2)).reshape(-1, size ** 2),
|
||||
p_ori[:, -1].reshape(-1, 1)],
|
||||
axis=1)
|
||||
concat(block_list, board_aug, p_aug, win_ori)
|
||||
|
||||
board_aug = np.rot90(board_ori[:, :, ::-1], 1, (1, 2))
|
||||
p_aug = np.concatenate(
|
||||
[np.rot90(p_ori[:, :-1].reshape(-1, size, size)[:, :, ::-1], 1, (1, 2)).reshape(-1, size ** 2),
|
||||
p_ori[:, -1].reshape(-1, 1)],
|
||||
axis=1)
|
||||
concat(block_list, board_aug, p_aug, win_ori)
|
||||
print ("Finished {} with time {}".format(n, time.time() + start))
|
||||
data_num = 0
|
||||
for i in range(slots_num):
|
||||
print("Block {} ".format(block_list[i].block_id) + "Size {}".format(block_list[i].store_num()))
|
||||
data_num = data_num + block_list[i].store_num()
|
||||
print ("Total data {}".format(data_num))
|
||||
|
||||
for i in range(slots_num):
|
||||
block_list[i].save_and_reset(block_list[i].block_id)
|
||||
@ -1,155 +1,79 @@
|
||||
from __future__ import print_function
|
||||
import numpy as np
|
||||
|
||||
import copy
|
||||
'''
|
||||
Settings of the Go game.
|
||||
Settings of the Reversi game.
|
||||
|
||||
(1, 1) is considered as the upper left corner of the board,
|
||||
(size, 1) is the lower left
|
||||
'''
|
||||
|
||||
|
||||
def find_correct_moves(own, enemy):
|
||||
"""return legal moves"""
|
||||
left_right_mask = 0x7e7e7e7e7e7e7e7e # Both most left-right edge are 0, else 1
|
||||
top_bottom_mask = 0x00ffffffffffff00 # Both most top-bottom edge are 0, else 1
|
||||
mask = left_right_mask & top_bottom_mask
|
||||
mobility = 0
|
||||
mobility |= search_offset_left(own, enemy, left_right_mask, 1) # Left
|
||||
mobility |= search_offset_left(own, enemy, mask, 9) # Left Top
|
||||
mobility |= search_offset_left(own, enemy, top_bottom_mask, 8) # Top
|
||||
mobility |= search_offset_left(own, enemy, mask, 7) # Top Right
|
||||
mobility |= search_offset_right(own, enemy, left_right_mask, 1) # Right
|
||||
mobility |= search_offset_right(own, enemy, mask, 9) # Bottom Right
|
||||
mobility |= search_offset_right(own, enemy, top_bottom_mask, 8) # Bottom
|
||||
mobility |= search_offset_right(own, enemy, mask, 7) # Left bottom
|
||||
return mobility
|
||||
|
||||
|
||||
def calc_flip(pos, own, enemy):
|
||||
"""return flip stones of enemy by bitboard when I place stone at pos.
|
||||
|
||||
:param pos: 0~63
|
||||
:param own: bitboard (0=top left, 63=bottom right)
|
||||
:param enemy: bitboard
|
||||
:return: flip stones of enemy when I place stone at pos.
|
||||
"""
|
||||
f1 = _calc_flip_half(pos, own, enemy)
|
||||
f2 = _calc_flip_half(63 - pos, rotate180(own), rotate180(enemy))
|
||||
return f1 | rotate180(f2)
|
||||
|
||||
|
||||
def _calc_flip_half(pos, own, enemy):
|
||||
el = [enemy, enemy & 0x7e7e7e7e7e7e7e7e, enemy & 0x7e7e7e7e7e7e7e7e, enemy & 0x7e7e7e7e7e7e7e7e]
|
||||
masks = [0x0101010101010100, 0x00000000000000fe, 0x0002040810204080, 0x8040201008040200]
|
||||
masks = [b64(m << pos) for m in masks]
|
||||
flipped = 0
|
||||
for e, mask in zip(el, masks):
|
||||
outflank = mask & ((e | ~mask) + 1) & own
|
||||
flipped |= (outflank - (outflank != 0)) & mask
|
||||
return flipped
|
||||
|
||||
|
||||
def search_offset_left(own, enemy, mask, offset):
|
||||
e = enemy & mask
|
||||
blank = ~(own | enemy)
|
||||
t = e & (own >> offset)
|
||||
t |= e & (t >> offset)
|
||||
t |= e & (t >> offset)
|
||||
t |= e & (t >> offset)
|
||||
t |= e & (t >> offset)
|
||||
t |= e & (t >> offset) # Up to six stones can be turned at once
|
||||
return blank & (t >> offset) # Only the blank squares can be started
|
||||
|
||||
|
||||
def search_offset_right(own, enemy, mask, offset):
|
||||
e = enemy & mask
|
||||
blank = ~(own | enemy)
|
||||
t = e & (own << offset)
|
||||
t |= e & (t << offset)
|
||||
t |= e & (t << offset)
|
||||
t |= e & (t << offset)
|
||||
t |= e & (t << offset)
|
||||
t |= e & (t << offset) # Up to six stones can be turned at once
|
||||
return blank & (t << offset) # Only the blank squares can be started
|
||||
|
||||
|
||||
def flip_vertical(x):
|
||||
k1 = 0x00FF00FF00FF00FF
|
||||
k2 = 0x0000FFFF0000FFFF
|
||||
x = ((x >> 8) & k1) | ((x & k1) << 8)
|
||||
x = ((x >> 16) & k2) | ((x & k2) << 16)
|
||||
x = (x >> 32) | b64(x << 32)
|
||||
return x
|
||||
|
||||
|
||||
def b64(x):
|
||||
return x & 0xFFFFFFFFFFFFFFFF
|
||||
|
||||
|
||||
def bit_count(x):
|
||||
return bin(x).count('1')
|
||||
|
||||
|
||||
def bit_to_array(x, size):
|
||||
"""bit_to_array(0b0010, 4) -> array([0, 1, 0, 0])"""
|
||||
return np.array(list(reversed((("0" * size) + bin(x)[2:])[-size:])), dtype=np.uint8)
|
||||
|
||||
|
||||
def flip_diag_a1h8(x):
|
||||
k1 = 0x5500550055005500
|
||||
k2 = 0x3333000033330000
|
||||
k4 = 0x0f0f0f0f00000000
|
||||
t = k4 & (x ^ b64(x << 28))
|
||||
x ^= t ^ (t >> 28)
|
||||
t = k2 & (x ^ b64(x << 14))
|
||||
x ^= t ^ (t >> 14)
|
||||
t = k1 & (x ^ b64(x << 7))
|
||||
x ^= t ^ (t >> 7)
|
||||
return x
|
||||
|
||||
|
||||
def rotate90(x):
|
||||
return flip_diag_a1h8(flip_vertical(x))
|
||||
|
||||
|
||||
def rotate180(x):
|
||||
return rotate90(rotate90(x))
|
||||
|
||||
|
||||
class Reversi:
|
||||
def __init__(self, black=None, white=None):
|
||||
self.black = black or (0b00001000 << 24 | 0b00010000 << 32)
|
||||
self.white = white or (0b00010000 << 24 | 0b00001000 << 32)
|
||||
self.board = None # 8 * 8 board with 1 for black, -1 for white and 0 for blank
|
||||
self.color = None # 1 for black and -1 for white
|
||||
self.action = None # number in 0~63
|
||||
self.winner = None
|
||||
self.black_win = None
|
||||
self.size = 8
|
||||
def __init__(self, **kwargs):
|
||||
self.size = kwargs['size']
|
||||
|
||||
def get_board(self, black=None, white=None):
|
||||
self.black = black or (0b00001000 << 24 | 0b00010000 << 32)
|
||||
self.white = white or (0b00010000 << 24 | 0b00001000 << 32)
|
||||
self.board = self.bitboard2board()
|
||||
return self.board
|
||||
def _deflatten(self, idx):
|
||||
x = idx // self.size + 1
|
||||
y = idx % self.size + 1
|
||||
return (x, y)
|
||||
|
||||
def is_valid(self, is_next=False):
|
||||
self.board2bitboard()
|
||||
own, enemy = self.get_own_and_enemy(is_next)
|
||||
mobility = find_correct_moves(own, enemy)
|
||||
valid_moves = bit_to_array(mobility, 64)
|
||||
valid_moves = np.argwhere(valid_moves)
|
||||
valid_moves = list(np.reshape(valid_moves, len(valid_moves)))
|
||||
return valid_moves
|
||||
def _flatten(self, vertex):
|
||||
x, y = vertex
|
||||
if (x == 0) and (y == 0):
|
||||
return self.size ** 2
|
||||
return (x - 1) * self.size + (y - 1)
|
||||
|
||||
def get_board(self):
|
||||
board = np.zeros([self.size, self.size], dtype=np.int32)
|
||||
board[self.size / 2 - 1, self.size / 2 - 1] = -1
|
||||
board[self.size / 2, self.size / 2] = -1
|
||||
board[self.size / 2 - 1, self.size / 2] = 1
|
||||
board[self.size / 2, self.size / 2 - 1] = 1
|
||||
return board
|
||||
|
||||
def _find_correct_moves(self, board, color, is_next=False):
|
||||
moves = []
|
||||
if is_next:
|
||||
new_color = 0 - color
|
||||
else:
|
||||
new_color = color
|
||||
for i in range(self.size ** 2):
|
||||
x, y = self._deflatten(i)
|
||||
valid = self._is_valid(board, x - 1, y - 1, new_color)
|
||||
if valid:
|
||||
moves.append(i)
|
||||
return moves
|
||||
|
||||
def _one_direction_valid(self, board, x, y, color):
|
||||
if (x >= 0) and (x < self.size):
|
||||
if (y >= 0) and (y < self.size):
|
||||
if board[x, y] == color:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _is_valid(self, board, x, y, color):
|
||||
if board[x, y]:
|
||||
return False
|
||||
for x_direction in [-1, 0, 1]:
|
||||
for y_direction in [-1, 0, 1]:
|
||||
new_x = x
|
||||
new_y = y
|
||||
flag = 0
|
||||
while True:
|
||||
new_x += x_direction
|
||||
new_y += y_direction
|
||||
if self._one_direction_valid(board, new_x, new_y, 0 - color):
|
||||
flag = 1
|
||||
else:
|
||||
break
|
||||
if self._one_direction_valid(board, new_x, new_y, color) and flag:
|
||||
return True
|
||||
return False
|
||||
|
||||
def simulate_get_mask(self, state, action_set):
|
||||
history_boards, color = state
|
||||
board = history_boards[-1]
|
||||
self.board = board
|
||||
self.color = color
|
||||
valid_moves = self.is_valid()
|
||||
# TODO it seems that the pass move is not considered
|
||||
history_boards, color = copy.deepcopy(state)
|
||||
board = copy.deepcopy(history_boards[-1])
|
||||
valid_moves = self._find_correct_moves(board, color)
|
||||
if not len(valid_moves):
|
||||
invalid_action_mask = action_set[0:-1]
|
||||
else:
|
||||
@ -160,144 +84,115 @@ class Reversi:
|
||||
return invalid_action_mask
|
||||
|
||||
def simulate_step_forward(self, state, action):
|
||||
self.board = state[0]
|
||||
self.color = state[1]
|
||||
self.board2bitboard()
|
||||
self.action = action
|
||||
if self.action == 64:
|
||||
valid_moves = self.is_valid(is_next=True)
|
||||
history_boards, color = copy.deepcopy(state)
|
||||
board = copy.deepcopy(history_boards[-1])
|
||||
if action == self.size ** 2:
|
||||
valid_moves = self._find_correct_moves(board, color, is_next=True)
|
||||
if not len(valid_moves):
|
||||
self._game_over()
|
||||
return None, self.winner * self.color
|
||||
winner = self._get_winner(board)
|
||||
return None, winner * color
|
||||
else:
|
||||
return [self.board, 0 - self.color], 0
|
||||
self.step()
|
||||
new_board = self.bitboard2board()
|
||||
return [new_board, 0 - self.color], 0
|
||||
return [history_boards, 0 - color], 0
|
||||
new_board = self._step(board, color, action)
|
||||
history_boards.append(new_board)
|
||||
return [history_boards, 0 - color], 0
|
||||
|
||||
def executor_do_move(self, board, color, vertex):
|
||||
self.board = board
|
||||
self.color = color
|
||||
self.board2bitboard()
|
||||
self.action = self._flatten(vertex)
|
||||
if self.action == 64:
|
||||
valid_moves = self.is_valid(is_next=True)
|
||||
def simulate_hashable_conversion(self, state):
|
||||
# since go is MDP, we only need the last board for hashing
|
||||
return tuple(state[0][-1].flatten().tolist())
|
||||
|
||||
def _get_winner(self, board):
|
||||
black_num, white_num = self._number_of_black_and_white(board)
|
||||
black_win = black_num - white_num
|
||||
if black_win > 0:
|
||||
winner = 1
|
||||
elif black_win < 0:
|
||||
winner = -1
|
||||
else:
|
||||
winner = 0
|
||||
return winner
|
||||
|
||||
def _number_of_black_and_white(self, board):
|
||||
black_num = 0
|
||||
white_num = 0
|
||||
board_list = np.reshape(board, self.size ** 2)
|
||||
for i in range(len(board_list)):
|
||||
if board_list[i] == 1:
|
||||
black_num += 1
|
||||
elif board_list[i] == -1:
|
||||
white_num += 1
|
||||
return black_num, white_num
|
||||
|
||||
def _step(self, board, color, action):
|
||||
if action < 0 or action > self.size ** 2 - 1:
|
||||
raise ValueError("Action not in the range of [0,63]!")
|
||||
if action is None:
|
||||
raise ValueError("Action is None!")
|
||||
x, y = self._deflatten(action)
|
||||
new_board = self._flip(board, x - 1, y - 1, color)
|
||||
return new_board
|
||||
|
||||
def _flip(self, board, x, y, color):
|
||||
valid = 0
|
||||
board[x, y] = color
|
||||
for x_direction in [-1, 0, 1]:
|
||||
for y_direction in [-1, 0, 1]:
|
||||
new_x = x
|
||||
new_y = y
|
||||
flag = 0
|
||||
while True:
|
||||
new_x += x_direction
|
||||
new_y += y_direction
|
||||
if self._one_direction_valid(board, new_x, new_y, 0 - color):
|
||||
flag = 1
|
||||
else:
|
||||
break
|
||||
if self._one_direction_valid(board, new_x, new_y, color) and flag:
|
||||
valid = 1
|
||||
flip_x = x
|
||||
flip_y = y
|
||||
while True:
|
||||
flip_x += x_direction
|
||||
flip_y += y_direction
|
||||
if self._one_direction_valid(board, flip_x, flip_y, 0 - color):
|
||||
board[flip_x, flip_y] = color
|
||||
else:
|
||||
break
|
||||
if valid:
|
||||
return board
|
||||
else:
|
||||
raise ValueError("Invalid action")
|
||||
|
||||
def executor_do_move(self, history, latest_boards, board, color, vertex):
|
||||
board = np.reshape(board, (self.size, self.size))
|
||||
color = color
|
||||
action = self._flatten(vertex)
|
||||
if action == self.size ** 2:
|
||||
valid_moves = self._find_correct_moves(board, color, is_next=True)
|
||||
if not len(valid_moves):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
else:
|
||||
self.step()
|
||||
new_board = self.bitboard2board()
|
||||
for i in range(64):
|
||||
board[i] = new_board[i]
|
||||
new_board = self._step(board, color, action)
|
||||
history.append(new_board)
|
||||
latest_boards.append(new_board)
|
||||
return True
|
||||
|
||||
def executor_get_score(self, board):
|
||||
self.board = board
|
||||
self._game_over()
|
||||
if self.black_win is not None:
|
||||
return self.black_win
|
||||
else:
|
||||
raise ValueError("Game not finished!")
|
||||
board = board
|
||||
winner = self._get_winner(board)
|
||||
return winner
|
||||
|
||||
def board2bitboard(self):
|
||||
count = 1
|
||||
if self.board is None:
|
||||
raise ValueError("None board!")
|
||||
self.black = 0
|
||||
self.white = 0
|
||||
for i in range(64):
|
||||
if self.board[i] == 1:
|
||||
self.black |= count
|
||||
elif self.board[i] == -1:
|
||||
self.white |= count
|
||||
count *= 2
|
||||
'''
|
||||
def vertex2action(self, vertex):
|
||||
x, y = vertex
|
||||
if x == 0 and y == 0:
|
||||
self.action = None
|
||||
else:
|
||||
self.action = 8 * (x - 1) + y - 1
|
||||
'''
|
||||
|
||||
def bitboard2board(self):
|
||||
board = []
|
||||
black = bit_to_array(self.black, 64)
|
||||
white = bit_to_array(self.white, 64)
|
||||
for i in range(64):
|
||||
if black[i]:
|
||||
board.append(1)
|
||||
elif white[i]:
|
||||
board.append(-1)
|
||||
else:
|
||||
board.append(0)
|
||||
return board
|
||||
if __name__ == "__main__":
|
||||
reversi = Reversi()
|
||||
# board = reversi.get_board()
|
||||
# print(board)
|
||||
# state, value = reversi.simulate_step_forward([board, -1], 20)
|
||||
# print(state[0])
|
||||
# print("board")
|
||||
# print(board)
|
||||
# r = reversi.executor_get_score(board)
|
||||
# print(r)
|
||||
|
||||
def step(self):
|
||||
if self.action < 0 or self.action > 63:
|
||||
raise ValueError("Action not in the range of [0,63]!")
|
||||
if self.action is None:
|
||||
raise ValueError("Action is None!")
|
||||
|
||||
own, enemy = self.get_own_and_enemy()
|
||||
|
||||
flipped = calc_flip(self.action, own, enemy)
|
||||
if bit_count(flipped) == 0:
|
||||
# self.illegal_move_to_lose(self.action)
|
||||
raise ValueError("Illegal action!")
|
||||
own ^= flipped
|
||||
own |= 1 << self.action
|
||||
enemy ^= flipped
|
||||
self.set_own_and_enemy(own, enemy)
|
||||
|
||||
def _game_over(self):
|
||||
# self.done = True
|
||||
|
||||
if self.winner is None:
|
||||
black_num, white_num = self.number_of_black_and_white
|
||||
self.black_win = black_num - white_num
|
||||
if self.black_win > 0:
|
||||
self.winner = 1
|
||||
elif self.black_win < 0:
|
||||
self.winner = -1
|
||||
else:
|
||||
self.winner = 0
|
||||
|
||||
def illegal_move_to_lose(self, action):
|
||||
self._game_over()
|
||||
|
||||
def get_own_and_enemy(self, is_next=False):
|
||||
if is_next:
|
||||
color = 0 - self.color
|
||||
else:
|
||||
color = self.color
|
||||
if color == 1:
|
||||
own, enemy = self.black, self.white
|
||||
elif color == -1:
|
||||
own, enemy = self.white, self.black
|
||||
else:
|
||||
own, enemy = None, None
|
||||
return own, enemy
|
||||
|
||||
def set_own_and_enemy(self, own, enemy):
|
||||
if self.color == 1:
|
||||
self.black, self.white = own, enemy
|
||||
else:
|
||||
self.white, self.black = own, enemy
|
||||
|
||||
def _deflatten(self, idx):
|
||||
x = idx // self.size + 1
|
||||
y = idx % self.size + 1
|
||||
return (x, y)
|
||||
|
||||
def _flatten(self, vertex):
|
||||
x, y = vertex
|
||||
if (x == 0) and (y == 0):
|
||||
return 64
|
||||
return (x - 1) * self.size + (y - 1)
|
||||
|
||||
@property
|
||||
def number_of_black_and_white(self):
|
||||
return bit_count(self.black), bit_count(self.white)
|
||||
|
||||
@ -1,103 +0,0 @@
|
||||
from game import Game
|
||||
from engine import GTPEngine
|
||||
import re
|
||||
import numpy as np
|
||||
import os
|
||||
from collections import deque
|
||||
import utils
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--result_path', type=str, default='./part1')
|
||||
args = parser.parse_args()
|
||||
|
||||
if not os.path.exists(args.result_path):
|
||||
os.makedirs(args.result_path)
|
||||
|
||||
game = Game()
|
||||
engine = GTPEngine(game_obj=game)
|
||||
history = deque(maxlen=8)
|
||||
for i in range(8):
|
||||
history.append(game.board)
|
||||
state = []
|
||||
prob = []
|
||||
winner = []
|
||||
pattern = "[A-Z]{1}[0-9]{1}"
|
||||
game.show_board()
|
||||
|
||||
|
||||
def history2state(history, color):
|
||||
state = np.zeros([1, game.size, game.size, 17])
|
||||
for i in range(8):
|
||||
state[0, :, :, i] = np.array(np.array(history[i]) == np.ones(game.size ** 2)).reshape(game.size, game.size)
|
||||
state[0, :, :, i + 8] = np.array(np.array(history[i]) == -np.ones(game.size ** 2)).reshape(game.size, game.size)
|
||||
if color == utils.BLACK:
|
||||
state[0, :, :, 16] = np.ones([game.size, game.size])
|
||||
if color == utils.WHITE:
|
||||
state[0, :, :, 16] = np.zeros([game.size, game.size])
|
||||
return state
|
||||
|
||||
|
||||
num = 0
|
||||
game_num = 0
|
||||
black_pass = False
|
||||
white_pass = False
|
||||
while True:
|
||||
print("Start game {}".format(game_num))
|
||||
while not (black_pass and white_pass) and num < game.size ** 2 * 2:
|
||||
if num % 2 == 0:
|
||||
color = utils.BLACK
|
||||
new_state = history2state(history, color)
|
||||
state.append(new_state)
|
||||
result = engine.run_cmd(str(num) + " genmove BLACK")
|
||||
num += 1
|
||||
match = re.search(pattern, result)
|
||||
if match is not None:
|
||||
print(match.group())
|
||||
else:
|
||||
print("pass")
|
||||
if re.search("pass", result) is not None:
|
||||
black_pass = True
|
||||
else:
|
||||
black_pass = False
|
||||
else:
|
||||
color = utils.WHITE
|
||||
new_state = history2state(history, color)
|
||||
state.append(new_state)
|
||||
result = engine.run_cmd(str(num) + " genmove WHITE")
|
||||
num += 1
|
||||
match = re.search(pattern, result)
|
||||
if match is not None:
|
||||
print(match.group())
|
||||
else:
|
||||
print("pass")
|
||||
if re.search("pass", result) is not None:
|
||||
white_pass = True
|
||||
else:
|
||||
white_pass = False
|
||||
game.show_board()
|
||||
prob.append(np.array(game.prob).reshape(-1, game.size ** 2 + 1))
|
||||
print("Finished")
|
||||
print("\n")
|
||||
score = game.game_engine.executor_get_score(game.board)
|
||||
if score > 0:
|
||||
winner = utils.BLACK
|
||||
else:
|
||||
winner = utils.WHITE
|
||||
state = np.concatenate(state, axis=0)
|
||||
prob = np.concatenate(prob, axis=0)
|
||||
winner = np.ones([num, 1]) * winner
|
||||
assert state.shape[0] == prob.shape[0]
|
||||
assert state.shape[0] == winner.shape[0]
|
||||
np.savez(args.result_path + "/game" + str(game_num), state=state, prob=prob, winner=winner)
|
||||
state = []
|
||||
prob = []
|
||||
winner = []
|
||||
num = 0
|
||||
black_pass = False
|
||||
white_pass = False
|
||||
engine.run_cmd(str(num) + " clear_board")
|
||||
history.clear()
|
||||
for _ in range(8):
|
||||
history.append(game.board)
|
||||
game_num += 1
|
||||
@ -4,21 +4,6 @@ import time
|
||||
|
||||
c_puct = 5
|
||||
|
||||
|
||||
def list2tuple(list):
|
||||
try:
|
||||
return tuple(list2tuple(sub) for sub in list)
|
||||
except TypeError:
|
||||
return list
|
||||
|
||||
|
||||
def tuple2list(tuple):
|
||||
try:
|
||||
return list(tuple2list(sub) for sub in tuple)
|
||||
except TypeError:
|
||||
return tuple
|
||||
|
||||
|
||||
class MCTSNode(object):
|
||||
def __init__(self, parent, action, state, action_num, prior, inverse=False):
|
||||
self.parent = parent
|
||||
@ -38,23 +23,29 @@ class MCTSNode(object):
|
||||
def valid_mask(self, simulator):
|
||||
pass
|
||||
|
||||
|
||||
class UCTNode(MCTSNode):
|
||||
def __init__(self, parent, action, state, action_num, prior, inverse=False):
|
||||
def __init__(self, parent, action, state, action_num, prior, mcts, inverse=False):
|
||||
super(UCTNode, self).__init__(parent, action, state, action_num, prior, inverse)
|
||||
self.Q = np.zeros([action_num])
|
||||
self.W = np.zeros([action_num])
|
||||
self.N = np.zeros([action_num])
|
||||
self.ucb = self.Q + c_puct * self.prior * math.sqrt(np.sum(self.N)) / (self.N + 1)
|
||||
self.c_puct = c_puct
|
||||
self.ucb = self.Q + self.c_puct * self.prior * math.sqrt(np.sum(self.N)) / (self.N + 1)
|
||||
self.mask = None
|
||||
self.elapse_time = 0
|
||||
self.mcts = mcts
|
||||
|
||||
def selection(self, simulator):
|
||||
head = time.time()
|
||||
self.valid_mask(simulator)
|
||||
self.mcts.valid_mask_time += time.time() - head
|
||||
action = np.argmax(self.ucb)
|
||||
if action in self.children.keys():
|
||||
self.mcts.state_selection_time += time.time() - head
|
||||
return self.children[action].selection(simulator)
|
||||
else:
|
||||
self.children[action] = ActionNode(self, action)
|
||||
self.children[action] = ActionNode(self, action, mcts=self.mcts)
|
||||
self.mcts.state_selection_time += time.time() - head
|
||||
return self.children[action].selection(simulator)
|
||||
|
||||
def backpropagation(self, action):
|
||||
@ -80,7 +71,7 @@ class UCTNode(MCTSNode):
|
||||
self.mask = simulator.simulate_get_mask(self.state, range(self.action_num))
|
||||
self.ucb[self.mask] = -float("Inf")
|
||||
|
||||
|
||||
# Code reserved for Thompson Sampling
|
||||
class TSNode(MCTSNode):
|
||||
def __init__(self, parent, action, state, action_num, prior, method="Gaussian", inverse=False):
|
||||
super(TSNode, self).__init__(parent, action, state, action_num, prior, inverse)
|
||||
@ -93,68 +84,68 @@ class TSNode(MCTSNode):
|
||||
|
||||
|
||||
class ActionNode(object):
|
||||
def __init__(self, parent, action):
|
||||
def __init__(self, parent, action, mcts):
|
||||
self.parent = parent
|
||||
self.action = action
|
||||
self.children = {}
|
||||
self.next_state = None
|
||||
self.origin_state = None
|
||||
self.next_state_hashable = None
|
||||
self.state_type = None
|
||||
self.reward = 0
|
||||
|
||||
def type_conversion_to_tuple(self):
|
||||
if type(self.next_state) is np.ndarray:
|
||||
self.next_state = self.next_state.tolist()
|
||||
if type(self.next_state) is list:
|
||||
self.next_state = list2tuple(self.next_state)
|
||||
|
||||
def type_conversion_to_origin(self):
|
||||
if self.state_type is np.ndarray:
|
||||
self.next_state = np.array(self.next_state)
|
||||
if self.state_type is list:
|
||||
self.next_state = tuple2list(self.next_state)
|
||||
self.mcts = mcts
|
||||
|
||||
def selection(self, simulator):
|
||||
head = time.time()
|
||||
self.next_state, self.reward = simulator.simulate_step_forward(self.parent.state, self.action)
|
||||
self.origin_state = self.next_state
|
||||
self.state_type = type(self.next_state)
|
||||
self.type_conversion_to_tuple()
|
||||
if self.next_state is not None:
|
||||
if self.next_state in self.children.keys():
|
||||
return self.children[self.next_state].selection(simulator)
|
||||
else:
|
||||
return self.parent, self.action
|
||||
else:
|
||||
return self.parent, self.action
|
||||
self.mcts.simulate_sf_time += time.time() - head
|
||||
if self.next_state is None: # next_state is None means that self.parent.state is the terminate state
|
||||
self.mcts.action_selection_time += time.time() - head
|
||||
return self
|
||||
head = time.time()
|
||||
self.next_state_hashable = simulator.simulate_hashable_conversion(self.next_state)
|
||||
self.mcts.hash_time += time.time() - head
|
||||
if self.next_state_hashable in self.children.keys(): # next state has already visited before
|
||||
self.mcts.action_selection_time += time.time() - head
|
||||
return self.children[self.next_state_hashable].selection(simulator)
|
||||
else: # next state is a new state never seen before
|
||||
self.mcts.action_selection_time += time.time() - head
|
||||
return self
|
||||
|
||||
def expansion(self, evaluator, action_num):
|
||||
if self.next_state is not None:
|
||||
prior, value = evaluator(self.next_state)
|
||||
self.children[self.next_state] = UCTNode(self, self.action, self.origin_state, action_num, prior,
|
||||
self.parent.inverse)
|
||||
return value
|
||||
else:
|
||||
return 0.
|
||||
def expansion(self, prior, action_num):
|
||||
self.children[self.next_state_hashable] = UCTNode(self, self.action, self.next_state, action_num, prior,
|
||||
mcts=self.mcts, inverse=self.parent.inverse)
|
||||
|
||||
def backpropagation(self, value):
|
||||
self.reward += value
|
||||
self.parent.backpropagation(self.action)
|
||||
|
||||
|
||||
class MCTS(object):
|
||||
def __init__(self, simulator, evaluator, root, action_num, method="UCT", inverse=False):
|
||||
def __init__(self, simulator, evaluator, start_state, action_num, method="UCT",
|
||||
role="unknown", debug=False, inverse=False):
|
||||
self.simulator = simulator
|
||||
self.evaluator = evaluator
|
||||
prior, _ = self.evaluator(root)
|
||||
self.role = role
|
||||
self.debug = debug
|
||||
prior, _ = self.evaluator(start_state)
|
||||
self.action_num = action_num
|
||||
if method == "":
|
||||
self.root = root
|
||||
self.root = start_state
|
||||
if method == "UCT":
|
||||
self.root = UCTNode(None, None, root, action_num, prior, inverse=inverse)
|
||||
self.root = UCTNode(None, None, start_state, action_num, prior, mcts=self, inverse=inverse)
|
||||
if method == "TS":
|
||||
self.root = TSNode(None, None, root, action_num, prior, inverse=inverse)
|
||||
self.root = TSNode(None, None, start_state, action_num, prior, inverse=inverse)
|
||||
self.inverse = inverse
|
||||
|
||||
# time spend on each step
|
||||
self.selection_time = 0
|
||||
self.expansion_time = 0
|
||||
self.backpropagation_time = 0
|
||||
self.action_selection_time = 0
|
||||
self.state_selection_time = 0
|
||||
self.simulate_sf_time = 0
|
||||
self.valid_mask_time = 0
|
||||
self.hash_time = 0
|
||||
|
||||
def search(self, max_step=None, max_time=None):
|
||||
step = 0
|
||||
start_time = time.time()
|
||||
@ -166,13 +157,42 @@ class MCTS(object):
|
||||
raise ValueError("Need a stop criteria!")
|
||||
|
||||
while step < max_step and time.time() - start_time < max_step:
|
||||
self._expand()
|
||||
sel_time, exp_time, back_time = self._expand()
|
||||
self.selection_time += sel_time
|
||||
self.expansion_time += exp_time
|
||||
self.backpropagation_time += back_time
|
||||
step += 1
|
||||
if self.debug:
|
||||
file = open("mcts_profiling.log", "a")
|
||||
file.write("[" + str(self.role) + "]"
|
||||
+ " sel " + '%.3f' % self.selection_time + " "
|
||||
+ " sel_sta " + '%.3f' % self.state_selection_time + " "
|
||||
+ " valid " + '%.3f' % self.valid_mask_time + " "
|
||||
+ " sel_act " + '%.3f' % self.action_selection_time + " "
|
||||
+ " hash " + '%.3f' % self.hash_time + " "
|
||||
+ " step forward " + '%.3f' % self.simulate_sf_time + " "
|
||||
+ " expansion " + '%.3f' % self.expansion_time + " "
|
||||
+ " backprop " + '%.3f' % self.backpropagation_time + " "
|
||||
+ "\n")
|
||||
file.close()
|
||||
|
||||
def _expand(self):
|
||||
node, new_action = self.root.selection(self.simulator)
|
||||
value = node.children[new_action].expansion(self.evaluator, self.action_num)
|
||||
node.children[new_action].backpropagation(value + 0.)
|
||||
t0 = time.time()
|
||||
next_action = self.root.selection(self.simulator)
|
||||
t1 = time.time()
|
||||
# next_action.next_state is None means the parent state node of next_action is a terminate node
|
||||
if next_action.next_state is not None:
|
||||
prior, value = self.evaluator(next_action.next_state)
|
||||
next_action.expansion(prior, self.action_num)
|
||||
else:
|
||||
value = 0
|
||||
t2 = time.time()
|
||||
if self.inverse:
|
||||
next_action.backpropagation(-value + 0.)
|
||||
else:
|
||||
next_action.backpropagation(value + 0.)
|
||||
t3 = time.time()
|
||||
return t1 - t0, t2 - t1, t3 - t2
|
||||
|
||||
if __name__ == "__main__":
|
||||
pass
|
||||
|
||||
@ -12,6 +12,9 @@ class TestEnv:
|
||||
print(self.reward)
|
||||
# print("The best arm is {} with expected reward {}".format(self.best[0],self.best[1]))
|
||||
|
||||
def simulate_is_valid(self, state, act):
|
||||
return True
|
||||
|
||||
def step_forward(self, state, action):
|
||||
if action != 0 and action != 1:
|
||||
raise ValueError("Action must be 0 or 1! Your action is {}".format(action))
|
||||
|
||||
301
tianshou/core/mcts/mcts_virtual_loss.py
Normal file
301
tianshou/core/mcts/mcts_virtual_loss.py
Normal file
@ -0,0 +1,301 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# vim:fenc=utf-8
|
||||
# $File: mcts_virtual_loss.py
|
||||
# $Date: Sun Dec 24 16:4740 2017 +0800
|
||||
# Original file: mcts.py
|
||||
# $Author: renyong15 © <mails.tsinghua.edu.cn>
|
||||
#
|
||||
|
||||
"""
|
||||
This is an implementation of the MCTS with virtual loss.
|
||||
Due to the limitation of Python design mechanism, we implements the virtual loss in a mini-batch
|
||||
manner.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
|
||||
import numpy as np
|
||||
import math
|
||||
import time
|
||||
import sys,os
|
||||
from .utils import list2tuple, tuple2list
|
||||
|
||||
|
||||
class MCTSNodeVirtualLoss(object):
|
||||
"""
|
||||
MCTS abstract class with virtual loss. Currently we only support UCT node.
|
||||
Role of the Parameters can be found in Readme.md.
|
||||
"""
|
||||
def __init__(self,
|
||||
parent,
|
||||
action,
|
||||
state,
|
||||
action_num,
|
||||
prior,
|
||||
inverse = False):
|
||||
self.parent = parent
|
||||
self.action = action
|
||||
self.children = {}
|
||||
self.state = state
|
||||
self.action_num = action_num
|
||||
self.prior = np.array(prior).reshape(-1)
|
||||
self.inverse = inverse
|
||||
|
||||
def selection(self, simulator):
|
||||
raise NotImplementedError("Need to implement function selection")
|
||||
|
||||
def backpropagation(self, action):
|
||||
raise NotImplementedError("Need to implement function backpropagation")
|
||||
|
||||
def valid_mask(self, simulator):
|
||||
pass
|
||||
|
||||
class UCTNodeVirtualLoss(MCTSNodeVirtualLoss):
|
||||
"""
|
||||
UCT node (state node) with virtual loss.
|
||||
Role of the Parameters can be found in Readme.md.
|
||||
:param c_puct balance between exploration and exploition,
|
||||
"""
|
||||
def __init__(self,
|
||||
parent,
|
||||
action,
|
||||
state,
|
||||
action_num,
|
||||
prior,
|
||||
inverse=False,
|
||||
c_puct = 5):
|
||||
super(UCTNodeVirtualLoss, self).__init__(parent, action, state, action_num, prior, inverse)
|
||||
self.Q = np.zeros([action_num])
|
||||
self.W = np.zeros([action_num])
|
||||
self.N = np.zeros([action_num])
|
||||
self.virtual_loss = np.zeros([action_num])
|
||||
self.c_puct = c_puct
|
||||
#### modified by adding virtual loss
|
||||
#self.ucb = self.Q + c_puct * self.prior * math.sqrt(np.sum(self.N)) / (self.N + 1)
|
||||
|
||||
self.mask = None
|
||||
|
||||
def selection(self,
|
||||
simulator):
|
||||
self.valid_mask(simulator)
|
||||
self.Q = np.zeros([self.action_num])
|
||||
N_not_zero = (self.N + self.virtual_loss) > 0
|
||||
self.Q[N_not_zero] = (self.W[N_not_zero] + 0.)/ (self.virtual_loss[N_not_zero] + self.N[N_not_zero])
|
||||
self.ucb = self.Q + self.c_puct * self.prior * math.sqrt(np.sum(self.N + self.virtual_loss)) /\
|
||||
(self.N + self.virtual_loss + 1)
|
||||
action = np.argmax(self.ucb)
|
||||
self.virtual_loss[action] += 1
|
||||
|
||||
if action in self.children.keys():
|
||||
return self.children[action].selection(simulator)
|
||||
else:
|
||||
self.children[action] = ActionNodeVirtualLoss(self, action)
|
||||
return self.children[action].selection(simulator)
|
||||
|
||||
def remove_virtual_loss(self):
|
||||
### if not virtual_loss for every action is zero
|
||||
if np.sum(self.virtual_loss > 0) > 0:
|
||||
self.virtual_loss = np.zeros([self.action_num])
|
||||
if self.parent:
|
||||
self.parent.remove_virtual_loss()
|
||||
|
||||
def backpropagation(self, action):
|
||||
action = int(action)
|
||||
self.N[action] += 1
|
||||
self.W[action] += self.children[action].reward
|
||||
|
||||
## do not need to compute Q and ucb immediately since it will be modified by virtual loss
|
||||
## just comment out and leaving for comparision
|
||||
#for i in range(self.action_num):
|
||||
# if self.N[i] != 0:
|
||||
# self.Q[i] = (self.W[i] + 0.) / self.N[i]
|
||||
#self.ucb = self.Q + c_puct * self.prior * math.sqrt(np.sum(self.N)) / (self.N + 1.)
|
||||
|
||||
if self.parent is not None:
|
||||
if self.inverse:
|
||||
self.parent.backpropagation(-self.children[action].reward)
|
||||
else:
|
||||
self.parent.backpropagation(self.children[action].reward)
|
||||
|
||||
def valid_mask(self, simulator):
|
||||
if self.mask is None:
|
||||
start_time = time.time()
|
||||
self.mask = []
|
||||
for act in range(self.action_num - 1):
|
||||
if not simulator.simulate_is_valid(self.state, act):
|
||||
self.mask.append(act)
|
||||
self.ucb[act] = -float("Inf")
|
||||
else:
|
||||
self.ucb[self.mask] = -float("Inf")
|
||||
|
||||
|
||||
|
||||
class ActionNodeVirtualLoss(object):
|
||||
"""
|
||||
Action node with virtual loss.
|
||||
"""
|
||||
def __init__(self, parent, action):
|
||||
self.parent = parent
|
||||
self.action = action
|
||||
self.children = {}
|
||||
self.next_state = None
|
||||
self.origin_state = None
|
||||
self.state_type = None
|
||||
self.reward = 0
|
||||
|
||||
def remove_virtual_loss(self):
|
||||
self.parent.remove_virtual_loss()
|
||||
|
||||
def type_conversion_to_tuple(self):
|
||||
if type(self.next_state) is np.ndarray:
|
||||
self.next_state = self.next_state.tolist()
|
||||
if type(self.next_state) is list:
|
||||
self.next_state = list2tuple(self.next_state)
|
||||
|
||||
def type_conversion_to_origin(self):
|
||||
if self.state_type is np.ndarray:
|
||||
self.next_state = np.array(self.next_state)
|
||||
if self.state_type is list:
|
||||
self.next_state = tuple2list(self.next_state)
|
||||
|
||||
def selection(self, simulator):
|
||||
self.next_state, self.reward = simulator.step_forward(self.parent.state, self.action)
|
||||
self.origin_state = self.next_state
|
||||
self.state_type = type(self.next_state)
|
||||
self.type_conversion_to_tuple()
|
||||
if self.next_state is not None:
|
||||
if self.next_state in self.children.keys():
|
||||
return self.children[self.next_state].selection(simulator)
|
||||
else:
|
||||
return self.parent, self.action
|
||||
else:
|
||||
return self.parent, self.action
|
||||
|
||||
def expansion(self, action, state, action_num, prior, inverse ):
|
||||
if state is not None:
|
||||
self.children[state] = UCTNodeVirtualLoss(self, action, state, action_num, prior, inverse)
|
||||
|
||||
|
||||
def backpropagation(self, value):
|
||||
self.reward += value
|
||||
self.parent.backpropagation(self.action)
|
||||
|
||||
|
||||
class MCTSVirtualLoss(object):
|
||||
"""
|
||||
MCTS class with virtual loss
|
||||
"""
|
||||
def __init__(self, simulator, evaluator, root, action_num, batch_size = 1, method = "UCT", inverse = False):
|
||||
self.simulator = simulator
|
||||
self.evaluator = evaluator
|
||||
prior, _ = self.evaluator(root)
|
||||
self.action_num = action_num
|
||||
self.batch_size = batch_size
|
||||
|
||||
if method == "":
|
||||
self.root = root
|
||||
elif method == "UCT":
|
||||
self.root = UCTNodeVirtualLoss(None, None, root, action_num, prior, inverse)
|
||||
elif method == "TS":
|
||||
self.root = TSNodeVirtualLoss(None, None, root, action_num, prior, inverse=inverse)
|
||||
else:
|
||||
raise ValueError("Need a root type")
|
||||
|
||||
self.inverse = inverse
|
||||
|
||||
|
||||
def do_search(self, max_step=None, max_time=None):
|
||||
"""
|
||||
Expand the MCTS tree with stop crierion either by max_step or max_time
|
||||
|
||||
:param max_step search maximum minibath rounds. ONE step is ONE minibatch
|
||||
:param max_time search maximum seconds
|
||||
"""
|
||||
if max_step is not None:
|
||||
self.step = 0
|
||||
self.max_step = max_step
|
||||
if max_time is not None:
|
||||
self.start_time = time.time()
|
||||
self.max_time = max_time
|
||||
if max_step is None and max_time is None:
|
||||
raise ValueError("Need a stop criteria!")
|
||||
|
||||
self.select_time = []
|
||||
self.evaluate_time = []
|
||||
self.bp_time = []
|
||||
while (max_step is not None and self.step < self.max_step or max_step is None) \
|
||||
and (max_time is not None and time.time() - self.start_time < self.max_time or max_time is None):
|
||||
self._expand()
|
||||
if max_step is not None:
|
||||
self.step += 1
|
||||
|
||||
def _expand(self):
|
||||
"""
|
||||
Core logic method for MCTS tree to expand nodes.
|
||||
Steps to expand node:
|
||||
1. Select final action node with virtual loss and collect them in to a minibatch.
|
||||
(i.e. root->action->state->action...->action)
|
||||
2. Remove the virtual loss
|
||||
3. Evaluate the whole minibatch using evaluator
|
||||
4. Expand new nodes and perform back propogation.
|
||||
"""
|
||||
## minibatch with virtual loss
|
||||
nodes = []
|
||||
new_actions = []
|
||||
next_states = []
|
||||
|
||||
for i in range(self.batch_size):
|
||||
node, new_action = self.root.selection(self.simulator)
|
||||
nodes.append(node)
|
||||
new_actions.append(new_action)
|
||||
next_states.append(node.children[new_action].next_state)
|
||||
|
||||
for node in nodes:
|
||||
node.remove_virtual_loss()
|
||||
|
||||
assert(np.sum(self.root.virtual_loss > 0) == 0)
|
||||
#### compute value in batch manner unless the evaluator do not support it
|
||||
try:
|
||||
priors, values = self.evaluator(next_states)
|
||||
except:
|
||||
priors = []
|
||||
values = []
|
||||
for i in range(self.batch_size):
|
||||
if next_states[i] is not None:
|
||||
prior, value = self.evaluator(next_states[i])
|
||||
priors.append(prior)
|
||||
values.append(value)
|
||||
else:
|
||||
priors.append(0.)
|
||||
values.append(0.)
|
||||
|
||||
#### for now next_state == origin_state
|
||||
#### may have problem here. What if we reached the same next_state with same parent and action pair
|
||||
for i in range(self.batch_size):
|
||||
nodes[i].children[new_actions[i]].expansion(new_actions[i],
|
||||
next_states[i],
|
||||
self.action_num,
|
||||
priors[i],
|
||||
nodes[i].inverse)
|
||||
|
||||
if self.inverse:
|
||||
for i in range(self.batch_size):
|
||||
nodes[i].children[new_actions[i]].backpropagation(-values[i] + 0.)
|
||||
else:
|
||||
for i in range(self.batch_size):
|
||||
nodes[i].children[new_actions[i]].backpropagation(values[i] + 0.)
|
||||
|
||||
|
||||
##### TODO
|
||||
class TSNodeVirtualLoss(MCTSNodeVirtualLoss):
|
||||
def __init__(self, parent, action, state, action_num, prior, method="Gaussian", inverse=False):
|
||||
super(TSNodeVirtualLoss, self).__init__(parent, action, state, action_num, prior, inverse)
|
||||
if method == "Beta":
|
||||
self.alpha = np.ones([action_num])
|
||||
self.beta = np.ones([action_num])
|
||||
if method == "Gaussian":
|
||||
self.mu = np.zeros([action_num])
|
||||
self.sigma = np.zeros([action_num])
|
||||
|
||||
if __name__ == "__main__":
|
||||
mcts_virtual_loss = MCTSNodeVirtualLoss(None, None, 10, 1, 'UCT')
|
||||
55
tianshou/core/mcts/mcts_virtual_loss_test.py
Normal file
55
tianshou/core/mcts/mcts_virtual_loss_test.py
Normal file
@ -0,0 +1,55 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# vim:fenc=utf-8
|
||||
# $File: mcts_virtual_loss_test.py
|
||||
# $Date: Sat Dec 23 02:2139 2017 +0800
|
||||
# Original file: mcts_test.py
|
||||
# $Author: renyong15 © <mails.tsinghua.edu.cn>
|
||||
#
|
||||
|
||||
|
||||
|
||||
import numpy as np
|
||||
from .mcts_virtual_loss import MCTSVirtualLoss
|
||||
from .evaluator import rollout_policy
|
||||
|
||||
|
||||
class TestEnv:
|
||||
def __init__(self, max_step=5):
|
||||
self.max_step = max_step
|
||||
self.reward = {i: np.random.uniform() for i in range(2 ** max_step)}
|
||||
# self.reward = {0:1, 1:0}
|
||||
self.best = max(self.reward.items(), key=lambda x: x[1])
|
||||
print(self.reward)
|
||||
# print("The best arm is {} with expected reward {}".format(self.best[0],self.best[1]))
|
||||
|
||||
def simulate_is_valid(self, state, act):
|
||||
return True
|
||||
|
||||
def step_forward(self, state, action):
|
||||
if action != 0 and action != 1:
|
||||
raise ValueError("Action must be 0 or 1! Your action is {}".format(action))
|
||||
if state[0] >= 2 ** state[1] or state[1] > self.max_step:
|
||||
raise ValueError("Invalid State! Your state is {}".format(state))
|
||||
# print("Operate action {} at state {}, timestep {}".format(action, state[0], state[1]))
|
||||
if state[1] == self.max_step:
|
||||
new_state = None
|
||||
reward = 0
|
||||
else:
|
||||
num = state[0] + 2 ** state[1] * action
|
||||
step = state[1] + 1
|
||||
new_state = [num, step]
|
||||
if step == self.max_step:
|
||||
reward = int(np.random.uniform() < self.reward[num])
|
||||
else:
|
||||
reward = 0.
|
||||
return new_state, reward
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
env = TestEnv(2)
|
||||
rollout = rollout_policy(env, 2)
|
||||
evaluator = lambda state: rollout(state)
|
||||
mcts_virtual_loss = MCTSVirtualLoss(env, evaluator, [0, 0], 2, batch_size = 10)
|
||||
for i in range(10):
|
||||
mcts_virtual_loss.do_search(max_step = 100)
|
||||
|
||||
29
tianshou/core/mcts/unit_test/Evaluator.py
Normal file
29
tianshou/core/mcts/unit_test/Evaluator.py
Normal file
@ -0,0 +1,29 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
class evaluator(object):
|
||||
def __init__(self, env, action_num):
|
||||
self.env = env
|
||||
self.action_num = action_num
|
||||
|
||||
def __call__(self, state):
|
||||
raise NotImplementedError("Need to implement the evaluator")
|
||||
|
||||
|
||||
class rollout_policy(evaluator):
|
||||
def __init__(self, env, action_num):
|
||||
super(rollout_policy, self).__init__(env, action_num)
|
||||
self.is_terminated = False
|
||||
|
||||
def __call__(self, state):
|
||||
# TODO: prior for rollout policy
|
||||
total_reward = 0.
|
||||
color = state[1]
|
||||
action = np.random.randint(0, self.action_num)
|
||||
state, reward = self.env.simulate_step_forward(state, action)
|
||||
total_reward += reward
|
||||
while state is not None:
|
||||
action = np.random.randint(0, self.action_num)
|
||||
state, reward = self.env.simulate_step_forward(state, action)
|
||||
total_reward += reward
|
||||
return np.ones([self.action_num])/self.action_num, total_reward * color
|
||||
21
tianshou/core/mcts/unit_test/README.md
Normal file
21
tianshou/core/mcts/unit_test/README.md
Normal file
@ -0,0 +1,21 @@
|
||||
# Unit Test
|
||||
|
||||
This is a two-player zero-sum perfect information extensive game. Player 1 and player 2 iteratively choose actions. At every iteration, player 1 players first and player 2 follows. Both players have choices 0 or 1.
|
||||
|
||||
The number of iterations is given as a fixed number. After one game finished, the game counts the number of 0s and 1s that are choosen. If the number of 1 is more than that of 0, player 1 gets 1 and player 2 gets -1. If the number of 1 is less than that of 0, player 1 gets -1 and player 2 gets 1. Otherwise, they both get 0.
|
||||
|
||||
## Files
|
||||
|
||||
+ game.py: run this file to play the game.
|
||||
+ agent.py: a class for players. MCTS is used here.
|
||||
+ ZOgame.py: the game environment.
|
||||
+ mcts.py: MCTS method.
|
||||
+ Evaluator: evaluator for MCTS. Rollout policy is also here.
|
||||
|
||||
## Parameters
|
||||
|
||||
Three paramters are given in game.py.
|
||||
|
||||
+ size: the number of iterations
|
||||
+ searching_step: the number of searching times of MCTS for one step
|
||||
+ temp: the temporature paramter used to tradeoff exploitation and exploration
|
||||
97
tianshou/core/mcts/unit_test/ZOGame.py
Normal file
97
tianshou/core/mcts/unit_test/ZOGame.py
Normal file
@ -0,0 +1,97 @@
|
||||
#!/usr/bin/env python
|
||||
import numpy as np
|
||||
import copy
|
||||
|
||||
|
||||
class ZOTree:
|
||||
def __init__(self, size):
|
||||
self.size = size
|
||||
self.depth = self.size * 2
|
||||
|
||||
def simulate_step_forward(self, state, action):
|
||||
self._check_state(state)
|
||||
seq, color = copy.deepcopy(state)
|
||||
if len(seq) == self.depth:
|
||||
winner = self.executor_get_reward(state)
|
||||
return None, color * winner
|
||||
else:
|
||||
seq.append(int(action))
|
||||
return [seq, 0 - color], 0
|
||||
|
||||
def simulate_hashable_conversion(self, state):
|
||||
self._check_state(state)
|
||||
# since go is MDP, we only need the last board for hashing
|
||||
return tuple(state[0])
|
||||
|
||||
def executor_get_reward(self, state):
|
||||
self._check_state(state)
|
||||
seq = np.array(state[0], dtype='int16')
|
||||
length = len(seq)
|
||||
if length != self.depth:
|
||||
raise ValueError("The game is not terminated!")
|
||||
result = np.sum(seq)
|
||||
if result > self.size:
|
||||
winner = 1
|
||||
elif result < self.size:
|
||||
winner = -1
|
||||
else:
|
||||
winner = 0
|
||||
return winner
|
||||
|
||||
def executor_do_move(self, state, action):
|
||||
self._check_state(state)
|
||||
seq, color = state
|
||||
if len(seq) == self.depth:
|
||||
return False
|
||||
else:
|
||||
seq.append(int(action))
|
||||
if len(seq) == self.depth:
|
||||
return False
|
||||
return True
|
||||
|
||||
def v_value(self, state):
|
||||
self._check_state(state)
|
||||
seq, color = state
|
||||
ones = 0
|
||||
zeros = 0
|
||||
for i in range(len(seq)):
|
||||
if seq[i] == 0:
|
||||
zeros += 1
|
||||
if seq[i] == 1:
|
||||
ones += 1
|
||||
choosen_result = ones - zeros
|
||||
if color == 1:
|
||||
if choosen_result > 0:
|
||||
return 1
|
||||
elif choosen_result < 0:
|
||||
return -1
|
||||
else:
|
||||
return 0
|
||||
elif color == -1:
|
||||
if choosen_result > 1:
|
||||
return 1
|
||||
elif choosen_result < 1:
|
||||
return -1
|
||||
else:
|
||||
return 0
|
||||
else:
|
||||
raise ValueError("Wrong color")
|
||||
|
||||
def _check_state(self, state):
|
||||
seq, color = state
|
||||
if color == 1:
|
||||
if len(seq) % 2:
|
||||
raise ValueError("Color is 1 but the length of seq is odd!")
|
||||
elif color == -1:
|
||||
if not len(seq) % 2:
|
||||
raise ValueError("Color is -1 but the length of seq is even!")
|
||||
else:
|
||||
raise ValueError("Wrong color!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
size = 2
|
||||
game = ZOTree(size)
|
||||
seq = [1, 0, 1, 1]
|
||||
result = game.executor_do_move([seq, 1], 1)
|
||||
print(result)
|
||||
print(seq)
|
||||
28
tianshou/core/mcts/unit_test/agent.py
Normal file
28
tianshou/core/mcts/unit_test/agent.py
Normal file
@ -0,0 +1,28 @@
|
||||
#!/usr/bin/env python
|
||||
import numpy as np
|
||||
import ZOGame
|
||||
import Evaluator
|
||||
from mcts import MCTS
|
||||
|
||||
|
||||
|
||||
|
||||
class Agent:
    """MCTS-backed player for the 0/1 sequence game.

    Owns one ZOTree simulator and a rollout evaluator; each call to
    gen_move runs a fresh tree search from the current sequence.
    """

    def __init__(self, size, color, searching_step, temp):
        self.size = size
        self.color = color
        self.searching_step = searching_step
        self.temp = temp
        # Shared simulator/evaluator, reused across every search.
        self.simulator = ZOGame.ZOTree(self.size)
        self.evaluator = Evaluator.rollout_policy(self.simulator, 2)

    def gen_move(self, seq):
        """Sample the next action (0 or 1) for the current sequence *seq*."""
        if len(seq) >= 2 * self.size:
            raise ValueError("Game is terminated.")
        search_tree = MCTS(self.simulator, self.evaluator,
                           [seq, self.color], 2, inverse=True)
        search_tree.search(max_step=self.searching_step)
        # Temperature-scaled root visit counts define the sampling
        # distribution over the two actions.
        visits = np.power(search_tree.root.N, 1.0 / self.temp)
        distribution = visits / np.sum(visits)
        return int(np.random.binomial(1, distribution[1]))
|
||||
39
tianshou/core/mcts/unit_test/game.py
Normal file
39
tianshou/core/mcts/unit_test/game.py
Normal file
@ -0,0 +1,39 @@
|
||||
import ZOGame
|
||||
import agent
|
||||
|
||||
|
||||
if __name__ == '__main__':
    size = 10
    searching_step = 100  # MCTS simulations per move (fixed typo: "seaching")
    temp = 1              # sampling temperature for the visit counts
    print("Our game has 2 players.")
    print("Player 1 has color 1 and plays first. Player 2 has color -1 and plays following player 1.")
    # Fixed message typos: "Both player" -> "Both players", "that" -> "than".
    print("Both players choose 1 or 0 for an action.")
    print("This game has {} iterations".format(size))
    print("If the final sequence has more 1 than 0, player 1 wins.")
    print("If the final sequence has less 1 than 0, player 2 wins.")
    print("Otherwise, both players get 0.\n")
    game = ZOGame.ZOTree(size)
    player1 = agent.Agent(size, 1, searching_step, temp)
    player2 = agent.Agent(size, -1, searching_step, temp)

    seq = []
    print("Sequence is {}\n".format(seq))
    while True:
        # Player 1 moves; executor_do_move mutates seq in place and
        # returns False once the sequence is full.
        action1 = player1.gen_move(seq)
        print("action1 is {}".format(action1))
        result = game.executor_do_move([seq, 1], action1)
        print("Sequence is {}\n".format(seq))
        if not result:
            winner = game.executor_get_reward([seq, 1])
            break
        # Player 2 responds.
        action2 = player2.gen_move(seq)
        print("action2 is {}".format(action2))
        result = game.executor_do_move([seq, -1], action2)
        print("Sequence is {}\n".format(seq))
        if not result:
            winner = game.executor_get_reward([seq, 1])
            break

    print("The choice sequence is {}".format(seq))
    print("The game result is {}".format(winner))
|
||||
198
tianshou/core/mcts/unit_test/mcts.py
Normal file
198
tianshou/core/mcts/unit_test/mcts.py
Normal file
@ -0,0 +1,198 @@
|
||||
import numpy as np
|
||||
import math
|
||||
import time
|
||||
|
||||
c_puct = 5
|
||||
|
||||
class MCTSNode(object):
    """Abstract state node in the search tree.

    Concrete subclasses (e.g. UCTNode) must implement selection and
    backpropagation; valid_mask is an optional hook for masking illegal
    actions.
    """

    def __init__(self, parent, action, state, action_num, prior, inverse=False):
        self.parent = parent
        self.action = action
        self.state = state
        self.action_num = action_num
        self.inverse = inverse
        # Children are created lazily during selection.
        self.children = {}
        # Flatten the prior into a 1-D numpy vector of action weights.
        self.prior = np.array(prior).reshape(-1)

    def selection(self, simulator):
        raise NotImplementedError("Need to implement function selection")

    def backpropagation(self, action):
        raise NotImplementedError("Need to implement function backpropagation")

    def valid_mask(self, simulator):
        # Optional hook; the base class accepts every action.
        pass
|
||||
|
||||
class UCTNode(MCTSNode):
    """State node using the UCT/PUCT selection rule.

    Q/W/N are per-action mean value, total value and visit count. ucb
    caches Q + c_puct * prior * sqrt(sum(N)) / (N + 1).
    """

    def __init__(self, parent, action, state, action_num, prior, mcts, inverse=False):
        super(UCTNode, self).__init__(parent, action, state, action_num, prior, inverse)
        self.Q = np.zeros([action_num])
        self.W = np.zeros([action_num])
        self.N = np.zeros([action_num])
        self.c_puct = c_puct
        self._update_ucb()
        self.mask = None
        self.elapse_time = 0
        self.mcts = mcts

    def _update_ucb(self):
        # Single source of truth for the PUCT formula. Previously the
        # constructor used self.c_puct while backpropagation used the
        # module-level c_puct — consistent only by accident.
        self.ucb = self.Q + self.c_puct * self.prior * math.sqrt(np.sum(self.N)) / (self.N + 1.)

    def selection(self, simulator):
        """Descend along the UCB-maximal action, creating ActionNodes lazily."""
        head = time.time()
        self.valid_mask(simulator)
        self.mcts.valid_mask_time += time.time() - head
        action = np.argmax(self.ucb)
        if action not in self.children:
            self.children[action] = ActionNode(self, action, mcts=self.mcts)
        self.mcts.state_selection_time += time.time() - head
        return self.children[action].selection(simulator)

    def backpropagation(self, action):
        """Fold the child's reward for *action* into W/N/Q and refresh ucb."""
        action = int(action)
        self.N[action] += 1
        self.W[action] += self.children[action].reward
        for i in range(self.action_num):
            if self.N[i] != 0:
                self.Q[i] = (self.W[i] + 0.) / self.N[i]
        self._update_ucb()
        if self.parent is not None:
            # inverse: alternate-player game, so the value flips sign
            # at each level on the way up.
            if self.inverse:
                self.parent.backpropagation(-self.children[action].reward)
            else:
                self.parent.backpropagation(self.children[action].reward)

    def valid_mask(self, simulator):
        """Mark invalid actions with -inf ucb if the simulator provides a mask."""
        if hasattr(simulator, 'simulate_get_mask'):
            if self.mask is None:
                self.mask = simulator.simulate_get_mask(self.state, range(self.action_num))
            self.ucb[self.mask] = -float("Inf")
|
||||
|
||||
# Code reserved for Thompson Sampling
|
||||
# Code reserved for Thompson Sampling
class TSNode(MCTSNode):
    """Thompson-sampling node (reserved; not yet wired into MCTS)."""

    def __init__(self, parent, action, state, action_num, prior, method="Gaussian", inverse=False):
        super(TSNode, self).__init__(parent, action, state, action_num, prior, inverse)
        if method == "Beta":
            # Beta(1, 1) is the uniform prior over [0, 1].
            self.alpha = np.ones([action_num])
            self.beta = np.ones([action_num])
        if method == "Gaussian":
            # NOTE(review): sigma starts at zero, which gives a degenerate
            # Gaussian — presumably placeholder values; confirm before use.
            self.mu = np.zeros([action_num])
            self.sigma = np.zeros([action_num])
|
||||
|
||||
|
||||
class ActionNode(object):
    """Edge node: the (state, action) pair between two state nodes."""

    def __init__(self, parent, action, mcts):
        self.parent = parent
        self.action = action
        self.children = {}  # hashable next-state -> UCTNode
        self.next_state = None
        self.next_state_hashable = None
        self.state_type = None
        self.reward = 0
        self.mcts = mcts

    def selection(self, simulator):
        """Step the simulator forward and continue selection in the child."""
        head = time.time()
        self.next_state, self.reward = simulator.simulate_step_forward(
            self.parent.state, self.action)
        self.mcts.simulate_sf_time += time.time() - head
        # A None next_state marks the parent state as terminal.
        if self.next_state is None:
            self.mcts.action_selection_time += time.time() - head
            return self
        head = time.time()
        self.next_state_hashable = simulator.simulate_hashable_conversion(self.next_state)
        self.mcts.hash_time += time.time() - head
        seen_before = self.next_state_hashable in self.children
        self.mcts.action_selection_time += time.time() - head
        if seen_before:
            # Transposition: this state was reached before, keep descending.
            return self.children[self.next_state_hashable].selection(simulator)
        # Fresh state: stop here so the caller can expand it.
        return self

    def expansion(self, prior, action_num):
        """Create the UCTNode child for the freshly visited next state."""
        self.children[self.next_state_hashable] = UCTNode(
            self, self.action, self.next_state, action_num, prior,
            mcts=self.mcts, inverse=self.parent.inverse)

    def backpropagation(self, value):
        """Accumulate the leaf value, then hand control back to the parent."""
        self.reward += value
        self.parent.backpropagation(self.action)
|
||||
|
||||
class MCTS(object):
    """Monte-Carlo tree search driver.

    Builds a root node for start_state and repeatedly runs
    selection -> expansion -> backpropagation, keeping per-phase timing
    statistics for profiling.
    """

    def __init__(self, simulator, evaluator, start_state, action_num, method="UCT",
                 role="unknown", debug=False, inverse=False):
        self.simulator = simulator
        self.evaluator = evaluator
        self.role = role
        self.debug = debug
        prior, _ = self.evaluator(start_state)
        self.action_num = action_num
        # NOTE(review): an unrecognized method string leaves self.root
        # unset; callers always pass "", "UCT" or "TS".
        if method == "":
            self.root = start_state
        if method == "UCT":
            self.root = UCTNode(None, None, start_state, action_num, prior, mcts=self, inverse=inverse)
        if method == "TS":
            self.root = TSNode(None, None, start_state, action_num, prior, inverse=inverse)
        self.inverse = inverse

        # time spent on each phase (seconds, accumulated across searches)
        self.selection_time = 0
        self.expansion_time = 0
        self.backpropagation_time = 0
        self.action_selection_time = 0
        self.state_selection_time = 0
        self.simulate_sf_time = 0
        self.valid_mask_time = 0
        self.hash_time = 0

    def search(self, max_step=None, max_time=None):
        """Run simulations until either max_step steps or max_time seconds.

        Raises:
            ValueError: if neither stop criterion is given.
        """
        # Check BEFORE substituting defaults; the original tested after
        # the substitution, so this error was unreachable.
        if max_step is None and max_time is None:
            raise ValueError("Need a stop criteria!")
        if max_step is None:
            # BUG FIX: was int("Inf"), which raises ValueError.
            max_step = float("Inf")
        if max_time is None:
            max_time = float("Inf")

        step = 0
        start_time = time.time()
        # BUG FIX: the time bound previously compared against max_step.
        while step < max_step and time.time() - start_time < max_time:
            sel_time, exp_time, back_time = self._expand()
            self.selection_time += sel_time
            self.expansion_time += exp_time
            self.backpropagation_time += back_time
            step += 1
        if self.debug:
            # Context manager guarantees the log file is closed even if
            # a write fails (original leaked the handle on error).
            with open("mcts_profiling.log", "a") as log_file:
                log_file.write("[" + str(self.role) + "]"
                               + " sel " + '%.3f' % self.selection_time + " "
                               + " sel_sta " + '%.3f' % self.state_selection_time + " "
                               + " valid " + '%.3f' % self.valid_mask_time + " "
                               + " sel_act " + '%.3f' % self.action_selection_time + " "
                               + " hash " + '%.3f' % self.hash_time + " "
                               + " step forward " + '%.3f' % self.simulate_sf_time + " "
                               + " expansion " + '%.3f' % self.expansion_time + " "
                               + " backprop " + '%.3f' % self.backpropagation_time + " "
                               + "\n")

    def _expand(self):
        """One simulation: select a leaf edge, expand/evaluate it, backpropagate.

        Returns the (selection, expansion, backpropagation) wall-clock times.
        """
        t0 = time.time()
        next_action = self.root.selection(self.simulator)
        t1 = time.time()
        # next_action.next_state is None means the parent state node of
        # next_action is a terminal node.
        if next_action.next_state is not None:
            prior, value = self.evaluator(next_action.next_state)
            next_action.expansion(prior, self.action_num)
        else:
            value = 0.
        t2 = time.time()
        if self.inverse:
            next_action.backpropagation(-value + 0.)
        else:
            next_action.backpropagation(value + 0.)
        t3 = time.time()
        return t1 - t0, t2 - t1, t3 - t2
|
||||
|
||||
if __name__ == "__main__":
    # Intentionally empty: this module is exercised via game.py/agent.py,
    # not as a standalone script.
    pass
|
||||
21
tianshou/core/mcts/utils.py
Normal file
21
tianshou/core/mcts/utils.py
Normal file
@ -0,0 +1,21 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# vim:fenc=utf-8
|
||||
# $File: utils.py
|
||||
# $Date: Sat Dec 23 02:08:54 2017 +0800
|
||||
# $Author: renyong15 <renyong15@mails.tsinghua.edu.cn>
|
||||
#
|
||||
|
||||
def list2tuple(list):
    """Recursively convert nested iterables into nested tuples.

    Non-iterable leaves are returned unchanged. Strings and bytes are
    treated as atomic leaves: iterating a 1-character string yields
    itself, so the original recursed forever on any string input.
    (Parameter keeps its historical name `list` for compatibility,
    although it shadows the builtin.)
    """
    if isinstance(list, (str, bytes)):
        return list
    try:
        return tuple(list2tuple(sub) for sub in list)
    except TypeError:
        return list
|
||||
|
||||
|
||||
def tuple2list(tuple):
    """Recursively convert nested iterables into nested lists.

    Non-iterable leaves are returned unchanged. Strings and bytes are
    treated as atomic leaves: iterating a 1-character string yields
    itself, so the original recursed forever on any string input.
    (Parameter keeps its historical name `tuple` for compatibility,
    although it shadows the builtin.)
    """
    if isinstance(tuple, (str, bytes)):
        return tuple
    try:
        return list(tuple2list(sub) for sub in tuple)
    except TypeError:
        return tuple
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user