AlphaGo update

This commit is contained in:
rtz19970824 2017-11-26 13:36:52 +08:00
parent e727ce4d9b
commit ca0021083f
21 changed files with 625 additions and 422 deletions

View File

@ -12,6 +12,7 @@ import time
# os.environ["CUDA_VISIBLE_DEVICES"] = "1" # os.environ["CUDA_VISIBLE_DEVICES"] = "1"
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
def residual_block(input, is_training): def residual_block(input, is_training):
normalizer_params = {'is_training': is_training, normalizer_params = {'is_training': is_training,
'updates_collections': tf.GraphKeys.UPDATE_OPS} 'updates_collections': tf.GraphKeys.UPDATE_OPS}
@ -129,52 +130,68 @@ def train():
saver.save(sess, result_path + save_path) saver.save(sess, result_path + save_path)
del data, boards, wins, ps del data, boards, wins, ps
def forward(call_number):
#checkpoint_path = "/home/yama/rl/tianshou/AlphaGo/checkpoints"
checkpoint_path = "/home/jialian/stuGo/tianshou/stuGo/checkpoints/"
board_file = np.genfromtxt("/home/jialian/stuGo/tianshou/leela-zero/src/mcts_nn_files/board_" + call_number, dtype='str');
human_board = np.zeros((17, 19, 19))
#TODO : is it ok to ignore the last channel? # def forward(call_number):
for i in range(17): # # checkpoint_path = "/home/yama/rl/tianshou/AlphaGo/checkpoints"
human_board[i] = np.array(list(board_file[i])).reshape(19, 19) # checkpoint_path = "/home/jialian/stuGo/tianshou/stuGo/checkpoints/"
#print("============================") # board_file = np.genfromtxt("/home/jialian/stuGo/tianshou/leela-zero/src/mcts_nn_files/board_" + call_number,
#print("human board sum : " + str(np.sum(human_board[-1]))) # dtype='str');
#print("============================") # human_board = np.zeros((17, 19, 19))
#print(human_board) #
#print("============================") # # TODO : is it ok to ignore the last channel?
#rint(human_board) # for i in range(17):
feed_board = human_board.transpose(1, 2, 0).reshape(1, 19, 19, 17) # human_board[i] = np.array(list(board_file[i])).reshape(19, 19)
#print(feed_board[:,:,:,-1]) # # print("============================")
#print(feed_board.shape) # # print("human board sum : " + str(np.sum(human_board[-1])))
# # print("============================")
# # print(human_board)
# # print("============================")
# # rint(human_board)
# feed_board = human_board.transpose(1, 2, 0).reshape(1, 19, 19, 17)
# # print(feed_board[:,:,:,-1])
# # print(feed_board.shape)
#
# # npz_board = np.load("/home/yama/rl/tianshou/AlphaGo/data/7f83928932f64a79bc1efdea268698ae.npz")
# # print(npz_board["boards"].shape)
# # feed_board = npz_board["boards"][10].reshape(-1, 19, 19, 17)
# ##print(feed_board)
# # show_board = feed_board[0].transpose(2, 0, 1)
# # print("board shape : ", show_board.shape)
# # print(show_board)
#
# itflag = False
# with multi_gpu.create_session() as sess:
# sess.run(tf.global_variables_initializer())
# ckpt_file = tf.train.latest_checkpoint(checkpoint_path)
# if ckpt_file is not None:
# # print('Restoring model from {}...'.format(ckpt_file))
# saver.restore(sess, ckpt_file)
# else:
# raise ValueError("No model loaded")
# res = sess.run([tf.nn.softmax(p), v], feed_dict={x: feed_board, is_training: itflag})
# # res = sess.run([tf.nn.softmax(p),v], feed_dict={x:fix_board["boards"][300].reshape(-1, 19, 19, 17), is_training:False})
# # res = sess.run([tf.nn.softmax(p),v], feed_dict={x:fix_board["boards"][50].reshape(-1, 19, 19, 17), is_training:True})
# # print(np.argmax(res[0]))
# np.savetxt(sys.stdout, res[0][0], fmt="%.6f", newline=" ")
# np.savetxt(sys.stdout, res[1][0], fmt="%.6f", newline=" ")
# pv_file = "/home/jialian/stuGotianshou/leela-zero/src/mcts_nn_files/policy_value"
# np.savetxt(pv_file, np.concatenate((res[0][0], res[1][0])), fmt="%.6f", newline=" ")
# # np.savetxt(pv_file, res[1][0], fmt="%.6f", newline=" ")
# return res
#npz_board = np.load("/home/yama/rl/tianshou/AlphaGo/data/7f83928932f64a79bc1efdea268698ae.npz") def forward(state):
#print(npz_board["boards"].shape) checkpoint_path = "/home/tongzheng/tianshou/AlphaGo/checkpoints/"
#feed_board = npz_board["boards"][10].reshape(-1, 19, 19, 17)
##print(feed_board)
#show_board = feed_board[0].transpose(2, 0, 1)
#print("board shape : ", show_board.shape)
#print(show_board)
itflag = False
with multi_gpu.create_session() as sess: with multi_gpu.create_session() as sess:
sess.run(tf.global_variables_initializer()) sess.run(tf.global_variables_initializer())
ckpt_file = tf.train.latest_checkpoint(checkpoint_path) ckpt_file = tf.train.latest_checkpoint(checkpoint_path)
if ckpt_file is not None: if ckpt_file is not None:
#print('Restoring model from {}...'.format(ckpt_file)) print('Restoring model from {}...'.format(ckpt_file))
saver.restore(sess, ckpt_file) saver.restore(sess, ckpt_file)
else: else:
raise ValueError("No model loaded") raise ValueError("No model loaded")
res = sess.run([tf.nn.softmax(p),v], feed_dict={x:feed_board, is_training:itflag}) prior, value = sess.run([tf.nn.softmax(p), v], feed_dict={x: state, is_training: False})
#res = sess.run([tf.nn.softmax(p),v], feed_dict={x:fix_board["boards"][300].reshape(-1, 19, 19, 17), is_training:False}) return prior, value
#res = sess.run([tf.nn.softmax(p),v], feed_dict={x:fix_board["boards"][50].reshape(-1, 19, 19, 17), is_training:True})
#print(np.argmax(res[0]))
np.savetxt(sys.stdout, res[0][0], fmt="%.6f", newline=" ")
np.savetxt(sys.stdout, res[1][0], fmt="%.6f", newline=" ")
pv_file = "/home/jialian/stuGotianshou/leela-zero/src/mcts_nn_files/policy_value"
np.savetxt(pv_file, np.concatenate((res[0][0], res[1][0])), fmt="%.6f", newline=" ")
#np.savetxt(pv_file, res[1][0], fmt="%.6f", newline=" ")
return res
if __name__ == '__main__': if __name__ == '__main__':
np.set_printoptions(threshold='nan') np.set_printoptions(threshold='nan')

View File

@ -132,6 +132,8 @@ def train():
del save_path del save_path
del data, boards, wins, ps, batch_num, index del data, boards, wins, ps, batch_num, index
gc.collect() gc.collect()
def forward(board): def forward(board):
result_path = "./checkpoints" result_path = "./checkpoints"
itflag = False itflag = False

View File

@ -8,6 +8,7 @@
from game import Game from game import Game
import utils import utils
class GTPEngine(): class GTPEngine():
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.size = 19 self.size = 19
@ -27,7 +28,6 @@ class GTPEngine():
except: except:
self._version = 2 self._version = 2
self.disconnect = False self.disconnect = False
self.known_commands = [ self.known_commands = [
@ -42,9 +42,6 @@ class GTPEngine():
x, y = vertex x, y = vertex
return "{}{}".format("ABCDEFGHJKLMNOPQRSTYVWYZ"[x - 1], y) return "{}{}".format("ABCDEFGHJKLMNOPQRSTYVWYZ"[x - 1], y)
def _vertex_string2point(self, s): def _vertex_string2point(self, s):
if s is None: if s is None:
return False return False
@ -62,7 +59,6 @@ class GTPEngine():
return False return False
return (x, y) return (x, y)
def _parse_color(self, color): def _parse_color(self, color):
if color.lower() in ["b", "black"]: if color.lower() in ["b", "black"]:
color = utils.BLACK color = utils.BLACK
@ -72,7 +68,6 @@ class GTPEngine():
color = None color = None
return color return color
def _parse_move(self, move_string): def _parse_move(self, move_string):
color, move = move_string.split(" ", 1) color, move = move_string.split(" ", 1)
color = self._parse_color(color) color = self._parse_color(color)
@ -84,8 +79,6 @@ class GTPEngine():
else: else:
return False return False
def _parse_res(self, res, id_=None, success=True): def _parse_res(self, res, id_=None, success=True):
if success: if success:
if id_: if id_:
@ -98,7 +91,6 @@ class GTPEngine():
else: else:
return '? {}\n\n'.format(res) return '? {}\n\n'.format(res)
def _parse_cmd(self, message): def _parse_cmd(self, message):
try: try:
m = message.strip().split(" ", 1) m = message.strip().split(" ", 1)
@ -130,8 +122,6 @@ class GTPEngine():
else: else:
return self._parse_res("unknown command", id_, False) return self._parse_res("unknown command", id_, False)
def cmd_protocol_version(self, args, **kwargs): def cmd_protocol_version(self, args, **kwargs):
return 2, True return 2, True
@ -172,7 +162,6 @@ class GTPEngine():
except ValueError: except ValueError:
raise ValueError("syntax error") raise ValueError("syntax error")
def cmd_play(self, args, **kwargs): def cmd_play(self, args, **kwargs):
move = self._parse_move(args) move = self._parse_move(args)
if move: if move:
@ -191,7 +180,3 @@ class GTPEngine():
return self._vertex_point2string(move), True return self._vertex_point2string(move), True
else: else:
return 'unknown player', False return 'unknown player', False

69
AlphaGo/game.py Normal file
View File

@ -0,0 +1,69 @@
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
# $File: game.py
# $Date: Fri Nov 17 15:07:45 2017 +0800
# $Author: renyong15 © <mails.tsinghua.edu.cn>
#
import numpy as np
import utils
import Network
from strategy import strategy
from collections import deque
class Game:
    """Go game state: flat board list, last-8 board history, and an MCTS strategy.

    The board is a flat list of size*size cells holding utils.EMPTY / BLACK /
    WHITE.  The history deque keeps the last 8 board snapshots, which become
    the stone feature planes fed to the network.
    """

    def __init__(self, size=19, komi=6.5):
        self.size = size
        # BUG FIX: the komi argument was previously ignored (hard-coded 6.5).
        self.komi = komi
        self.board = [utils.EMPTY] * (self.size * self.size)
        self.strategy = strategy(Network.forward)
        self.history = deque(maxlen=8)
        for i in range(8):
            # Append copies: appending self.board itself stored 8 aliases of
            # one list, so every later in-place move corrupted all history.
            self.history.append(list(self.board))

    def _flatten(self, vertex):
        """Map a 1-based (x, y) vertex to a flat board index."""
        x, y = vertex
        return (x - 1) * self.size + (y - 1)

    def clear(self):
        """Reset the board to all-empty."""
        self.board = [utils.EMPTY] * (self.size * self.size)

    def set_size(self, n):
        self.size = n
        self.clear()

    def set_komi(self, k):
        self.komi = k

    def do_move(self, color, vertex):
        """Play `color` at `vertex`; return True on success, False if occupied."""
        if vertex == utils.PASS:
            return True
        id_ = self._flatten(vertex)
        if self.board[id_] == utils.EMPTY:
            self.board[id_] = color
            # Snapshot a copy so later moves don't mutate stored history entries.
            self.history.append(list(self.board))
            return True
        else:
            return False

    def step_forward(self, state, action):
        """Advance a (1, 19, 19, 17) feature state by `action`; returns (new_state, 0)."""
        # Last feature plane encodes the side to move (1 -> black).
        if state[0, 0, 0, -1] == 1:
            color = 1
        else:
            color = -1
        if action == 361:
            vertex = (0, 0)  # pass
        else:
            # BUG FIX: integer division (`/` yields floats on Python 3) and a
            # 1-based column to match _flatten / gen_move (was `action % 19`).
            vertex = (action // 19 + 1, action % 19 + 1)
        self.do_move(color, vertex)
        # BUG FIX: the old concatenation compared a Python list to an int
        # (always False) and omitted the white-stone plane (16 planes instead
        # of 17).  Mirror the working GoEnv.step_forward construction.
        new_state = np.concatenate(
            [state[:, :, :, 1:8], (np.array(self.board) == 1).reshape(1, 19, 19, 1),
             state[:, :, :, 9:16], (np.array(self.board) == -1).reshape(1, 19, 19, 1),
             np.array(1 - state[:, :, :, -1]).reshape(1, 19, 19, 1)],
            axis=3)
        return new_state, 0

    def gen_move(self, color):
        """Delegate move generation for `color` to the configured strategy."""
        move = self.strategy.gen_move(self.history, color)
        return move
        # return utils.PASS

View File

@ -3,7 +3,6 @@ import go
import utils import utils
def translate_gtp_colors(gtp_color): def translate_gtp_colors(gtp_color):
if gtp_color == gtp.BLACK: if gtp_color == gtp.BLACK:
return go.BLACK return go.BLACK
@ -12,6 +11,7 @@ def translate_gtp_colors(gtp_color):
else: else:
return go.EMPTY return go.EMPTY
class GtpInterface(object): class GtpInterface(object):
def __init__(self): def __init__(self):
self.size = 9 self.size = 9
@ -68,19 +68,3 @@ class GtpInterface(object):
def suggest_move(self, position): def suggest_move(self, position):
raise NotImplementedError raise NotImplementedError
def make_gtp_instance(strategy_name, read_file):
n = PolicyNetwork(use_cpu=True)
n.initialize_variables(read_file)
if strategy_name == 'random':
instance = RandomPlayer()
elif strategy_name == 'policy':
instance = GreedyPolicyPlayer(n)
elif strategy_name == 'randompolicy':
instance = RandomPolicyPlayer(n)
elif strategy_name == 'mcts':
instance = MCTSPlayer(n)
else:
return None
gtp_engine = gtp.Engine(instance)
return gtp_engine

78
AlphaGo/strategy.py Normal file
View File

@ -0,0 +1,78 @@
import numpy as np
import utils
from collections import deque
from tianshou.core.mcts.mcts import MCTS
class GoEnv:
    """Lightweight Go board simulator used by MCTS for stepping states forward."""

    def __init__(self, size=19, komi=6.5):
        self.size = size
        # BUG FIX: honour the komi argument (was hard-coded 6.5).
        self.komi = komi
        self.board = [utils.EMPTY] * (self.size * self.size)
        self.history = deque(maxlen=8)

    def _flatten(self, vertex):
        """Map a 1-based (x, y) vertex to a flat board index."""
        x, y = vertex
        return (x - 1) * self.size + (y - 1)

    def do_move(self, color, vertex):
        """Play `color` at `vertex`; return True on success, False if occupied."""
        if vertex == utils.PASS:
            return True
        id_ = self._flatten(vertex)
        if self.board[id_] == utils.EMPTY:
            self.board[id_] = color
            # Snapshot a copy: appending self.board itself stored aliases of
            # one list, so later in-place moves rewrote stored history.
            self.history.append(list(self.board))
            return True
        else:
            return False

    def step_forward(self, state, action):
        """Advance a (1, 19, 19, 17) feature state by `action`; returns (new_state, 0)."""
        # Last feature plane encodes the side to move (1 -> black).
        if state[0, 0, 0, -1] == 1:
            color = 1
        else:
            color = -1
        if action == 361:
            vertex = (0, 0)  # pass
        else:
            # BUG FIX: integer division (`/` yields floats on Python 3) and a
            # 1-based column to match _flatten (was `action % 19`).
            vertex = (action // 19 + 1, action % 19 + 1)
        self.do_move(color, vertex)
        new_state = np.concatenate(
            [state[:, :, :, 1:8], (np.array(self.board) == 1).reshape(1, 19, 19, 1),
             state[:, :, :, 9:16], (np.array(self.board) == -1).reshape(1, 19, 19, 1),
             np.array(1 - state[:, :, :, -1]).reshape(1, 19, 19, 1)],
            axis=3)
        return new_state, 0
class strategy(object):
    """MCTS-based move selection driven by a network evaluator returning (prior, value)."""

    def __init__(self, evaluator):
        self.simulator = GoEnv()
        self.evaluator = evaluator

    def data_process(self, history, color):
        """Build the (1, 19, 19, 17) network input from the last 8 board snapshots.

        Planes 0-7: black stones; planes 8-15: white stones; plane 16: all
        ones when black (color == 1) is to move, all zeros for white.
        """
        state = np.zeros([1, 19, 19, 17])
        for i in range(8):
            # BUG FIX: history entries are flat Python lists, so the old
            # `history[i] == 1` compared list-to-int (always False) and left
            # every stone plane zero.  Convert to an array first.
            board = np.array(history[i]).reshape(19, 19)
            state[0, :, :, i] = board == 1
            state[0, :, :, i + 8] = board == -1
        if color == 1:
            state[0, :, :, 16] = np.ones([19, 19])
        if color == -1:
            state[0, :, :, 16] = np.zeros([19, 19])
        return state

    def gen_move(self, history, color):
        """Run MCTS from the current position and sample a move from visit counts."""
        self.simulator.history = history
        self.simulator.board = history[-1]
        state = self.data_process(history, color)
        prior = self.evaluator(state)[0]
        mcts = MCTS(self.simulator, self.evaluator, state, 362, prior, inverse=True, max_step=20)
        temp = 1
        p = mcts.root.N ** temp / np.sum(mcts.root.N ** temp)
        choice = np.random.choice(362, 1, p=p).tolist()[0]
        if choice == 361:
            move = (0, 0)  # action 361 is a pass
        else:
            # BUG FIX: integer division -- `/` returns floats on Python 3.
            move = (choice // 19 + 1, choice % 19 + 1)
        return move

View File

@ -8,8 +8,6 @@
from game import Game from game import Game
from engine import GTPEngine from engine import GTPEngine
g = Game() g = Game()
e = GTPEngine(game_obj=g) e = GTPEngine(game_obj=g)
res = e.run_cmd('1 protocol_version') res = e.run_cmd('1 protocol_version')
@ -37,4 +35,5 @@ print(res)
res = e.run_cmd('8 genmove BLACK') res = e.run_cmd('8 genmove BLACK')
print(res) print(res)
res = e.run_cmd('9 genmove WHITE')
print(res)

119
AlphaGo/utils.py Normal file
View File

@ -0,0 +1,119 @@
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
# $File: utils.py
# $Date: Fri Nov 17 10:24:07 2017 +0800
# $Author: renyong15 © <mails.tsinghua.edu.cn>
#
WHITE = -1
BLACK = +1
EMPTY = 0
PASS = (0, 0)
RESIGN = "resign"
from collections import defaultdict
import functools
import itertools
import operator
import random
import re
import time
import gtp
import go
KGS_COLUMNS = 'ABCDEFGHJKLMNOPQRST'
SGF_COLUMNS = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
def parse_sgf_to_flat(sgf):
    """Parse an SGF coordinate string and return its flattened board index."""
    return flatten_coords(parse_sgf_coords(sgf))
def flatten_coords(c):
    """Flatten a (row, col) coordinate into a single index on the N x N board."""
    return go.N * c[0] + c[1]
def unflatten_coords(f):
    """Inverse of flatten_coords: return the (row, col) for flat index f."""
    return divmod(f, go.N)
def parse_sgf_coords(s):
    """Interpret SGF coords: 'aa' is the top-left corner; 'sa' is the top-right.

    Returns (row, col), or None for an empty/absent coordinate.
    """
    if not s:
        return None
    row = SGF_COLUMNS.index(s[1])
    col = SGF_COLUMNS.index(s[0])
    return row, col
def unparse_sgf_coords(c):
    """Inverse of parse_sgf_coords: render (row, col) as an SGF string ('' for None)."""
    if c is None:
        return ''
    col_letter = SGF_COLUMNS[c[1]]
    row_letter = SGF_COLUMNS[c[0]]
    return col_letter + row_letter
def parse_kgs_coords(s):
    """Interpret KGS coords: 'A1' is bottom left; 'A9' is top left.  Returns (row, col) or None for a pass."""
    if s == 'pass':
        return None
    s = s.upper()
    col = KGS_COLUMNS.index(s[0])
    row_from_bottom = int(s[1:]) - 1
    # Board rows are stored top-down, so flip the bottom-based rank.
    return go.N - row_from_bottom - 1, col
def parse_pygtp_coords(vertex):
    """Interpret pygtp coords: (1, 1) is bottom left; (1, 9) is top left.  Returns (row, col) or None."""
    if vertex in (gtp.PASS, gtp.RESIGN):
        return None
    return go.N - vertex[1], vertex[0] - 1
def unparse_pygtp_coords(c):
    """Inverse of parse_pygtp_coords: map (row, col) back to a 1-based GTP vertex (PASS for None)."""
    if c is None:
        return gtp.PASS
    return c[1] + 1, go.N - c[0]
def parse_game_result(result):
    """Map an SGF result string ('B+...' / 'W+...') to go.BLACK / go.WHITE, else None."""
    prefix = result[:2].lower()
    if prefix == 'b+':
        return go.BLACK
    if prefix == 'w+':
        return go.WHITE
    return None
def product(numbers):
    """Return the product of the items of *numbers* (TypeError on an empty input, like reduce)."""
    return functools.reduce(operator.mul, numbers)
def take_n(n, iterable):
    """Return, as a list, up to the first *n* items of *iterable*."""
    sliced = itertools.islice(iterable, n)
    return [item for item in sliced]
def iter_chunks(chunk_size, iterator):
    """Yield successive lists of up to *chunk_size* items from *iterator*.

    Stops at exhaustion without yielding a trailing empty chunk.
    """
    while True:
        chunk = list(itertools.islice(iterator, chunk_size))
        if not chunk:
            return
        yield chunk
def shuffler(iterator, pool_size=10**5, refill_threshold=0.9):
    """Approximately shuffle *iterator* using a bounded in-memory pool.

    Holds at most *pool_size* items; after yielding a fraction of the pool it
    refills from the source, so the result is a local shuffle rather than a
    true permutation of the whole stream.
    """
    yields_between_refills = round(pool_size * (1 - refill_threshold))
    # initialize pool; this step may or may not exhaust the iterator.
    pool = take_n(pool_size, iterator)
    while True:
        random.shuffle(pool)
        for i in range(yields_between_refills):
            yield pool.pop()
        next_batch = take_n(yields_between_refills, iterator)
        if not next_batch:
            break
        pool.extend(next_batch)
    # finish consuming whatever's left - no need for further randomization.
    yield from pool
class timer(object):
    """Accumulate wall-clock time spent inside ``with timer(label):`` blocks, keyed by label."""

    # Shared across all instances: label -> total elapsed seconds.
    all_times = defaultdict(float)

    def __init__(self, label):
        self.label = label

    def __enter__(self):
        self.tick = time.time()

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self.tock = time.time()
        elapsed = self.tock - self.tick
        timer.all_times[self.label] += elapsed

    @classmethod
    def print_times(cls):
        for label, total in cls.all_times.items():
            print("%s: %.3f" % (label, total))

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -1,7 +0,0 @@
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
# $File: __init__.py
# $Date: Thu Nov 16 14:10:06 2017 +0800
# $Author: renyong15 © <mails.tsinghua.edu.cn>
#

View File

@ -1,50 +0,0 @@
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
# $File: game.py
# $Date: Fri Nov 17 15:0745 2017 +0800
# $Author: renyong15 © <mails.tsinghua.edu.cn>
#
import utils
class Game:
    """Minimal Go game state: flat board list, komi, and a pluggable move strategy."""

    def __init__(self, size=19, komi=6.5):
        self.size = size
        # BUG FIX: the komi argument was previously ignored (hard-coded 6.5).
        self.komi = komi
        self.board = [utils.EMPTY] * (self.size * self.size)
        self.strategy = None

    def _flatten(self, vertex):
        """Map a 1-based (x, y) vertex to a flat board index."""
        x, y = vertex
        return (x - 1) * self.size + (y - 1)

    def clear(self):
        """Reset the board to all-empty."""
        self.board = [utils.EMPTY] * (self.size * self.size)

    def set_size(self, n):
        self.size = n
        self.clear()

    def set_komi(self, k):
        self.komi = k

    def do_move(self, color, vertex):
        """Play `color` at `vertex`; return True on success, False if occupied."""
        if vertex == utils.PASS:
            return True
        id_ = self._flatten(vertex)
        if self.board[id_] == utils.EMPTY:
            self.board[id_] = color
            return True
        else:
            return False

    def gen_move(self, color):
        """Delegate move generation to the configured strategy (must be set first)."""
        move = self.strategy.gen_move(color)
        return move
        # return utils.PASS

View File

@ -1,16 +0,0 @@
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
# $File: utils.py
# $Date: Fri Nov 17 10:2407 2017 +0800
# $Author: renyong15 © <mails.tsinghua.edu.cn>
#
WHITE = -1
BLACK = +1
EMPTY = 0
PASS = (0,0)
RESIGN = "resign"

View File

@ -25,4 +25,4 @@ class rollout_policy(evaluator):
action = np.random.randint(0, self.action_num) action = np.random.randint(0, self.action_num)
state, reward = self.env.step_forward(state, action) state, reward = self.env.step_forward(state, action)
total_reward += reward total_reward += reward
return total_reward return np.ones([self.action_num])/self.action_num, total_reward

View File

@ -5,14 +5,29 @@ import time
c_puct = 5 c_puct = 5
def list2tuple(list):
    """Recursively convert nested lists into nested tuples; non-iterable values pass through unchanged."""
    try:
        converted = [list2tuple(element) for element in list]
    except TypeError:
        # Atoms (ints, floats, ...) are not iterable -- return as-is.
        return list
    return tuple(converted)
def tuple2list(tuple):
    """Recursively convert nested tuples into nested lists; non-iterable values pass through unchanged."""
    try:
        converted = [tuple2list(element) for element in tuple]
    except TypeError:
        # Atoms (ints, floats, ...) are not iterable -- return as-is.
        return tuple
    return list(converted)
class MCTSNode(object): class MCTSNode(object):
def __init__(self, parent, action, state, action_num, prior): def __init__(self, parent, action, state, action_num, prior, inverse=False):
self.parent = parent self.parent = parent
self.action = action self.action = action
self.children = {} self.children = {}
self.state = state self.state = state
self.action_num = action_num self.action_num = action_num
self.prior = prior self.prior = prior
self.inverse = inverse
def selection(self, simulator): def selection(self, simulator):
raise NotImplementedError("Need to implement function selection") raise NotImplementedError("Need to implement function selection")
@ -20,13 +35,10 @@ class MCTSNode(object):
def backpropagation(self, action): def backpropagation(self, action):
raise NotImplementedError("Need to implement function backpropagation") raise NotImplementedError("Need to implement function backpropagation")
def simulation(self, state, evaluator):
raise NotImplementedError("Need to implement function simulation")
class UCTNode(MCTSNode): class UCTNode(MCTSNode):
def __init__(self, parent, action, state, action_num, prior): def __init__(self, parent, action, state, action_num, prior, inverse=False):
super(UCTNode, self).__init__(parent, action, state, action_num, prior) super(UCTNode, self).__init__(parent, action, state, action_num, prior, inverse)
self.Q = np.zeros([action_num]) self.Q = np.zeros([action_num])
self.W = np.zeros([action_num]) self.W = np.zeros([action_num])
self.N = np.zeros([action_num]) self.N = np.zeros([action_num])
@ -49,16 +61,15 @@ class UCTNode(MCTSNode):
self.Q[i] = (self.W[i] + 0.) / self.N[i] self.Q[i] = (self.W[i] + 0.) / self.N[i]
self.ucb = self.Q + c_puct * self.prior * math.sqrt(np.sum(self.N)) / (self.N + 1.) self.ucb = self.Q + c_puct * self.prior * math.sqrt(np.sum(self.N)) / (self.N + 1.)
if self.parent is not None: if self.parent is not None:
if self.inverse:
self.parent.backpropagation(-self.children[action].reward)
else:
self.parent.backpropagation(self.children[action].reward) self.parent.backpropagation(self.children[action].reward)
def simulation(self, evaluator, state):
value = evaluator(state)
return value
class TSNode(MCTSNode): class TSNode(MCTSNode):
def __init__(self, parent, action, state, action_num, prior, method="Gaussian"): def __init__(self, parent, action, state, action_num, prior, method="Gaussian", inverse=False):
super(TSNode, self).__init__(parent, action, state, action_num, prior) super(TSNode, self).__init__(parent, action, state, action_num, prior, inverse)
if method == "Beta": if method == "Beta":
self.alpha = np.ones([action_num]) self.alpha = np.ones([action_num])
self.beta = np.ones([action_num]) self.beta = np.ones([action_num])
@ -73,10 +84,27 @@ class ActionNode:
self.action = action self.action = action
self.children = {} self.children = {}
self.next_state = None self.next_state = None
self.origin_state = None
self.state_type = None
self.reward = 0 self.reward = 0
def type_conversion_to_tuple(self):
if type(self.next_state) is np.ndarray:
self.next_state = self.next_state.tolist()
if type(self.next_state) is list:
self.next_state = list2tuple(self.next_state)
def type_conversion_to_origin(self):
if self.state_type is np.ndarray:
self.next_state = np.array(self.next_state)
if self.state_type is list:
self.next_state = tuple2list(self.next_state)
def selection(self, simulator): def selection(self, simulator):
self.next_state, self.reward = simulator.step_forward(self.parent.state, self.action) self.next_state, self.reward = simulator.step_forward(self.parent.state, self.action)
self.origin_state = self.next_state
self.state_type = type(self.next_state)
self.type_conversion_to_tuple()
if self.next_state is not None: if self.next_state is not None:
if self.next_state in self.children.keys(): if self.next_state in self.children.keys():
return self.children[self.next_state].selection(simulator) return self.children[self.next_state].selection(simulator)
@ -85,14 +113,15 @@ class ActionNode:
else: else:
return self.parent, self.action return self.parent, self.action
def expansion(self, action_num): def expansion(self, evaluator, action_num):
# TODO: Let users/evaluator give the prior # TODO: Let users/evaluator give the prior
if self.next_state is not None: if self.next_state is not None:
prior = np.ones([action_num]) / action_num prior, value = evaluator(self.next_state)
self.children[self.next_state] = UCTNode(self, self.action, self.next_state, action_num, prior) self.children[self.next_state] = UCTNode(self, self.action, self.origin_state, action_num, prior,
return True self.parent.inverse)
return value
else: else:
return False return 0
def backpropagation(self, value): def backpropagation(self, value):
self.reward += value self.reward += value
@ -100,14 +129,16 @@ class ActionNode:
class MCTS: class MCTS:
def __init__(self, simulator, evaluator, root, action_num, prior, method="UCT", max_step=None, max_time=None): def __init__(self, simulator, evaluator, root, action_num, prior, method="UCT", inverse=False, max_step=None,
max_time=None):
self.simulator = simulator self.simulator = simulator
self.evaluator = evaluator self.evaluator = evaluator
self.action_num = action_num self.action_num = action_num
if method == "UCT": if method == "UCT":
self.root = UCTNode(None, None, root, action_num, prior) self.root = UCTNode(None, None, root, action_num, prior, inverse)
if method == "TS": if method == "TS":
self.root = TSNode(None, None, root, action_num, prior) self.root = TSNode(None, None, root, action_num, prior, inverse=inverse)
self.inverse = inverse
if max_step is not None: if max_step is not None:
self.step = 0 self.step = 0
self.max_step = max_step self.max_step = max_step
@ -118,23 +149,15 @@ class MCTS:
raise ValueError("Need a stop criteria!") raise ValueError("Need a stop criteria!")
while (max_step is not None and self.step < self.max_step or max_step is None) \ while (max_step is not None and self.step < self.max_step or max_step is None) \
and (max_time is not None and time.time() - self.start_time < self.max_time or max_time is None): and (max_time is not None and time.time() - self.start_time < self.max_time or max_time is None):
print("Q={}".format(self.root.Q))
print("N={}".format(self.root.N))
print("W={}".format(self.root.W))
print("UCB={}".format(self.root.ucb))
print("\n")
self.expand() self.expand()
if max_step is not None: if max_step is not None:
self.step += 1 self.step += 1
def expand(self): def expand(self):
node, new_action = self.root.selection(self.simulator) node, new_action = self.root.selection(self.simulator)
success = node.children[new_action].expansion(self.action_num) value = node.children[new_action].expansion(self.evaluator, self.action_num)
if success: print("Value:{}".format(value))
value = node.simulation(self.evaluator, node.children[new_action].next_state)
node.children[new_action].backpropagation(value + 0.) node.children[new_action].backpropagation(value + 0.)
else:
node.children[new_action].backpropagation(0.)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -24,7 +24,7 @@ class TestEnv:
else: else:
num = state[0] + 2 ** state[1] * action num = state[0] + 2 ** state[1] * action
step = state[1] + 1 step = state[1] + 1
new_state = (num, step) new_state = [num, step]
if step == self.max_step: if step == self.max_step:
reward = int(np.random.uniform() < self.reward[num]) reward = int(np.random.uniform() < self.reward[num])
else: else:

Binary file not shown.

Binary file not shown.