AlphaGo update
This commit is contained in:
parent
e727ce4d9b
commit
ca0021083f
@ -9,9 +9,10 @@ import tensorflow.contrib.layers as layers
|
||||
import multi_gpu
|
||||
import time
|
||||
|
||||
#os.environ["CUDA_VISIBLE_DEVICES"] = "1"
|
||||
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
|
||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
||||
|
||||
|
||||
def residual_block(input, is_training):
|
||||
normalizer_params = {'is_training': is_training,
|
||||
'updates_collections': tf.GraphKeys.UPDATE_OPS}
|
||||
@ -129,54 +130,70 @@ def train():
|
||||
saver.save(sess, result_path + save_path)
|
||||
del data, boards, wins, ps
|
||||
|
||||
def forward(call_number):
|
||||
#checkpoint_path = "/home/yama/rl/tianshou/AlphaGo/checkpoints"
|
||||
checkpoint_path = "/home/jialian/stuGo/tianshou/stuGo/checkpoints/"
|
||||
board_file = np.genfromtxt("/home/jialian/stuGo/tianshou/leela-zero/src/mcts_nn_files/board_" + call_number, dtype='str');
|
||||
human_board = np.zeros((17, 19, 19))
|
||||
|
||||
#TODO : is it ok to ignore the last channel?
|
||||
for i in range(17):
|
||||
human_board[i] = np.array(list(board_file[i])).reshape(19, 19)
|
||||
#print("============================")
|
||||
#print("human board sum : " + str(np.sum(human_board[-1])))
|
||||
#print("============================")
|
||||
#print(human_board)
|
||||
#print("============================")
|
||||
#rint(human_board)
|
||||
feed_board = human_board.transpose(1, 2, 0).reshape(1, 19, 19, 17)
|
||||
#print(feed_board[:,:,:,-1])
|
||||
#print(feed_board.shape)
|
||||
# def forward(call_number):
|
||||
# # checkpoint_path = "/home/yama/rl/tianshou/AlphaGo/checkpoints"
|
||||
# checkpoint_path = "/home/jialian/stuGo/tianshou/stuGo/checkpoints/"
|
||||
# board_file = np.genfromtxt("/home/jialian/stuGo/tianshou/leela-zero/src/mcts_nn_files/board_" + call_number,
|
||||
# dtype='str');
|
||||
# human_board = np.zeros((17, 19, 19))
|
||||
#
|
||||
# # TODO : is it ok to ignore the last channel?
|
||||
# for i in range(17):
|
||||
# human_board[i] = np.array(list(board_file[i])).reshape(19, 19)
|
||||
# # print("============================")
|
||||
# # print("human board sum : " + str(np.sum(human_board[-1])))
|
||||
# # print("============================")
|
||||
# # print(human_board)
|
||||
# # print("============================")
|
||||
# # rint(human_board)
|
||||
# feed_board = human_board.transpose(1, 2, 0).reshape(1, 19, 19, 17)
|
||||
# # print(feed_board[:,:,:,-1])
|
||||
# # print(feed_board.shape)
|
||||
#
|
||||
# # npz_board = np.load("/home/yama/rl/tianshou/AlphaGo/data/7f83928932f64a79bc1efdea268698ae.npz")
|
||||
# # print(npz_board["boards"].shape)
|
||||
# # feed_board = npz_board["boards"][10].reshape(-1, 19, 19, 17)
|
||||
# ##print(feed_board)
|
||||
# # show_board = feed_board[0].transpose(2, 0, 1)
|
||||
# # print("board shape : ", show_board.shape)
|
||||
# # print(show_board)
|
||||
#
|
||||
# itflag = False
|
||||
# with multi_gpu.create_session() as sess:
|
||||
# sess.run(tf.global_variables_initializer())
|
||||
# ckpt_file = tf.train.latest_checkpoint(checkpoint_path)
|
||||
# if ckpt_file is not None:
|
||||
# # print('Restoring model from {}...'.format(ckpt_file))
|
||||
# saver.restore(sess, ckpt_file)
|
||||
# else:
|
||||
# raise ValueError("No model loaded")
|
||||
# res = sess.run([tf.nn.softmax(p), v], feed_dict={x: feed_board, is_training: itflag})
|
||||
# # res = sess.run([tf.nn.softmax(p),v], feed_dict={x:fix_board["boards"][300].reshape(-1, 19, 19, 17), is_training:False})
|
||||
# # res = sess.run([tf.nn.softmax(p),v], feed_dict={x:fix_board["boards"][50].reshape(-1, 19, 19, 17), is_training:True})
|
||||
# # print(np.argmax(res[0]))
|
||||
# np.savetxt(sys.stdout, res[0][0], fmt="%.6f", newline=" ")
|
||||
# np.savetxt(sys.stdout, res[1][0], fmt="%.6f", newline=" ")
|
||||
# pv_file = "/home/jialian/stuGotianshou/leela-zero/src/mcts_nn_files/policy_value"
|
||||
# np.savetxt(pv_file, np.concatenate((res[0][0], res[1][0])), fmt="%.6f", newline=" ")
|
||||
# # np.savetxt(pv_file, res[1][0], fmt="%.6f", newline=" ")
|
||||
# return res
|
||||
|
||||
#npz_board = np.load("/home/yama/rl/tianshou/AlphaGo/data/7f83928932f64a79bc1efdea268698ae.npz")
|
||||
#print(npz_board["boards"].shape)
|
||||
#feed_board = npz_board["boards"][10].reshape(-1, 19, 19, 17)
|
||||
##print(feed_board)
|
||||
#show_board = feed_board[0].transpose(2, 0, 1)
|
||||
#print("board shape : ", show_board.shape)
|
||||
#print(show_board)
|
||||
|
||||
itflag = False
|
||||
def forward(state):
|
||||
checkpoint_path = "/home/tongzheng/tianshou/AlphaGo/checkpoints/"
|
||||
with multi_gpu.create_session() as sess:
|
||||
sess.run(tf.global_variables_initializer())
|
||||
ckpt_file = tf.train.latest_checkpoint(checkpoint_path)
|
||||
if ckpt_file is not None:
|
||||
#print('Restoring model from {}...'.format(ckpt_file))
|
||||
print('Restoring model from {}...'.format(ckpt_file))
|
||||
saver.restore(sess, ckpt_file)
|
||||
else:
|
||||
raise ValueError("No model loaded")
|
||||
res = sess.run([tf.nn.softmax(p),v], feed_dict={x:feed_board, is_training:itflag})
|
||||
#res = sess.run([tf.nn.softmax(p),v], feed_dict={x:fix_board["boards"][300].reshape(-1, 19, 19, 17), is_training:False})
|
||||
#res = sess.run([tf.nn.softmax(p),v], feed_dict={x:fix_board["boards"][50].reshape(-1, 19, 19, 17), is_training:True})
|
||||
#print(np.argmax(res[0]))
|
||||
np.savetxt(sys.stdout, res[0][0], fmt="%.6f", newline=" ")
|
||||
np.savetxt(sys.stdout, res[1][0], fmt="%.6f", newline=" ")
|
||||
pv_file = "/home/jialian/stuGotianshou/leela-zero/src/mcts_nn_files/policy_value"
|
||||
np.savetxt(pv_file, np.concatenate((res[0][0], res[1][0])), fmt="%.6f", newline=" ")
|
||||
#np.savetxt(pv_file, res[1][0], fmt="%.6f", newline=" ")
|
||||
return res
|
||||
prior, value = sess.run([tf.nn.softmax(p), v], feed_dict={x: state, is_training: False})
|
||||
return prior, value
|
||||
|
||||
if __name__=='__main__':
|
||||
|
||||
if __name__ == '__main__':
|
||||
np.set_printoptions(threshold='nan')
|
||||
#time.sleep(2)
|
||||
# time.sleep(2)
|
||||
forward(sys.argv[1])
|
||||
|
||||
@ -132,6 +132,8 @@ def train():
|
||||
del save_path
|
||||
del data, boards, wins, ps, batch_num, index
|
||||
gc.collect()
|
||||
|
||||
|
||||
def forward(board):
|
||||
result_path = "./checkpoints"
|
||||
itflag = False
|
||||
@ -144,7 +146,7 @@ def forward(board):
|
||||
print("============================")
|
||||
print("human board sum : " + str(np.sum(human_board)))
|
||||
print("============================")
|
||||
print(board[:,:,:,-1])
|
||||
print(board[:, :, :, -1])
|
||||
itflag = False
|
||||
with multi_gpu.create_session() as sess:
|
||||
sess.run(tf.global_variables_initializer())
|
||||
|
||||
@ -8,6 +8,7 @@
|
||||
from game import Game
|
||||
import utils
|
||||
|
||||
|
||||
class GTPEngine():
|
||||
def __init__(self, **kwargs):
|
||||
self.size = 19
|
||||
@ -27,7 +28,6 @@ class GTPEngine():
|
||||
except:
|
||||
self._version = 2
|
||||
|
||||
|
||||
self.disconnect = False
|
||||
|
||||
self.known_commands = [
|
||||
@ -42,9 +42,6 @@ class GTPEngine():
|
||||
x, y = vertex
|
||||
return "{}{}".format("ABCDEFGHJKLMNOPQRSTYVWYZ"[x - 1], y)
|
||||
|
||||
|
||||
|
||||
|
||||
def _vertex_string2point(self, s):
|
||||
if s is None:
|
||||
return False
|
||||
@ -62,7 +59,6 @@ class GTPEngine():
|
||||
return False
|
||||
return (x, y)
|
||||
|
||||
|
||||
def _parse_color(self, color):
|
||||
if color.lower() in ["b", "black"]:
|
||||
color = utils.BLACK
|
||||
@ -72,21 +68,18 @@ class GTPEngine():
|
||||
color = None
|
||||
return color
|
||||
|
||||
|
||||
def _parse_move(self, move_string):
|
||||
color, move = move_string.split(" ",1)
|
||||
color, move = move_string.split(" ", 1)
|
||||
color = self._parse_color(color)
|
||||
|
||||
point = self._vertex_string2point(move)
|
||||
|
||||
if point and color:
|
||||
return color,point
|
||||
return color, point
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
|
||||
def _parse_res(self, res, id_ = None, success = True):
|
||||
def _parse_res(self, res, id_=None, success=True):
|
||||
if success:
|
||||
if id_:
|
||||
return '={} {}\n\n'.format(id_, res)
|
||||
@ -98,7 +91,6 @@ class GTPEngine():
|
||||
else:
|
||||
return '? {}\n\n'.format(res)
|
||||
|
||||
|
||||
def _parse_cmd(self, message):
|
||||
try:
|
||||
m = message.strip().split(" ", 1)
|
||||
@ -119,19 +111,17 @@ class GTPEngine():
|
||||
return self._parse_res("invaild message", id_, False)
|
||||
|
||||
if cmd in self.known_commands:
|
||||
#dispatch
|
||||
#try:
|
||||
# dispatch
|
||||
# try:
|
||||
if True:
|
||||
res, flag = getattr(self, "cmd_" + cmd)(args)
|
||||
return self._parse_res(res, id_, flag)
|
||||
#except Exception as e:
|
||||
# except Exception as e:
|
||||
# print(e)
|
||||
# return self._parse_res("command excution failed", id_, False)
|
||||
else:
|
||||
return self._parse_res("unknown command", id_, False)
|
||||
|
||||
|
||||
|
||||
def cmd_protocol_version(self, args, **kwargs):
|
||||
return 2, True
|
||||
|
||||
@ -148,50 +138,45 @@ class GTPEngine():
|
||||
return self.known_commands, True
|
||||
|
||||
def cmd_quit(self, args, **kwargs):
|
||||
return None,True
|
||||
return None, True
|
||||
|
||||
def cmd_boardsize(self, args, **kwargs):
|
||||
if args.isdigit():
|
||||
size = int(args)
|
||||
self.size = size
|
||||
self._game.set_size(size)
|
||||
return None,True
|
||||
return None, True
|
||||
else:
|
||||
return 'non digit size',False
|
||||
return 'non digit size', False
|
||||
|
||||
def cmd_clear_board(self, args, **kwargs):
|
||||
self._game.clear()
|
||||
return None,True
|
||||
return None, True
|
||||
|
||||
def cmd_komi(self, args, **kwargs):
|
||||
try:
|
||||
komi = float(args)
|
||||
self.komi = komi
|
||||
self._game.set_komi(komi)
|
||||
return None,True
|
||||
return None, True
|
||||
except ValueError:
|
||||
raise ValueError("syntax error")
|
||||
|
||||
|
||||
def cmd_play(self, args, **kwargs):
|
||||
move = self._parse_move(args)
|
||||
if move:
|
||||
color, vertex = move
|
||||
res = self._game.do_move(color, vertex)
|
||||
if res:
|
||||
return None,True
|
||||
return None, True
|
||||
else:
|
||||
return None,False
|
||||
return None,True
|
||||
return None, False
|
||||
return None, True
|
||||
|
||||
def cmd_genmove(self, args, **kwargs):
|
||||
color = self._parse_color(args)
|
||||
if color:
|
||||
move = self._game.gen_move(color)
|
||||
return self._vertex_point2string(move),True
|
||||
return self._vertex_point2string(move), True
|
||||
else:
|
||||
return 'unknown player',False
|
||||
|
||||
|
||||
|
||||
|
||||
return 'unknown player', False
|
||||
69
AlphaGo/game.py
Normal file
69
AlphaGo/game.py
Normal file
@ -0,0 +1,69 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# vim:fenc=utf-8
|
||||
# $File: game.py
|
||||
# $Date: Fri Nov 17 15:0745 2017 +0800
|
||||
# $Author: renyong15 © <mails.tsinghua.edu.cn>
|
||||
#
|
||||
|
||||
import numpy as np
|
||||
import utils
|
||||
import Network
|
||||
from strategy import strategy
|
||||
from collections import deque
|
||||
|
||||
|
||||
class Game:
|
||||
def __init__(self, size=19, komi=6.5):
|
||||
self.size = size
|
||||
self.komi = 6.5
|
||||
self.board = [utils.EMPTY] * (self.size * self.size)
|
||||
self.strategy = strategy(Network.forward)
|
||||
self.history = deque(maxlen=8)
|
||||
for i in range(8):
|
||||
self.history.append(self.board)
|
||||
|
||||
def _flatten(self, vertex):
|
||||
x, y = vertex
|
||||
return (x - 1) * self.size + (y - 1)
|
||||
|
||||
def clear(self):
|
||||
self.board = [utils.EMPTY] * (self.size * self.size)
|
||||
|
||||
def set_size(self, n):
|
||||
self.size = n
|
||||
self.clear()
|
||||
|
||||
def set_komi(self, k):
|
||||
self.komi = k
|
||||
|
||||
def do_move(self, color, vertex):
|
||||
if vertex == utils.PASS:
|
||||
return True
|
||||
|
||||
id_ = self._flatten(vertex)
|
||||
if self.board[id_] == utils.EMPTY:
|
||||
self.board[id_] = color
|
||||
self.history.append(self.board)
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def step_forward(self, state, action):
|
||||
if state[0, 0, 0, -1] == 1:
|
||||
color = 1
|
||||
else:
|
||||
color = -1
|
||||
if action == 361:
|
||||
vertex = (0, 0)
|
||||
else:
|
||||
vertex = (action / 19 + 1, action % 19)
|
||||
self.do_move(color, vertex)
|
||||
new_state = np.concatenate([state[:, :, :, 1:8], self.board == 1, state[:, :, :, 9:16], 1 - state[:, :, :, -1]],
|
||||
axis=3)
|
||||
return new_state, 0
|
||||
|
||||
def gen_move(self, color):
|
||||
move = self.strategy.gen_move(self.history, color)
|
||||
return move
|
||||
# return utils.PASS
|
||||
|
||||
@ -3,7 +3,6 @@ import go
|
||||
import utils
|
||||
|
||||
|
||||
|
||||
def translate_gtp_colors(gtp_color):
|
||||
if gtp_color == gtp.BLACK:
|
||||
return go.BLACK
|
||||
@ -12,6 +11,7 @@ def translate_gtp_colors(gtp_color):
|
||||
else:
|
||||
return go.EMPTY
|
||||
|
||||
|
||||
class GtpInterface(object):
|
||||
def __init__(self):
|
||||
self.size = 9
|
||||
@ -68,19 +68,3 @@ class GtpInterface(object):
|
||||
|
||||
def suggest_move(self, position):
|
||||
raise NotImplementedError
|
||||
|
||||
def make_gtp_instance(strategy_name, read_file):
|
||||
n = PolicyNetwork(use_cpu=True)
|
||||
n.initialize_variables(read_file)
|
||||
if strategy_name == 'random':
|
||||
instance = RandomPlayer()
|
||||
elif strategy_name == 'policy':
|
||||
instance = GreedyPolicyPlayer(n)
|
||||
elif strategy_name == 'randompolicy':
|
||||
instance = RandomPolicyPlayer(n)
|
||||
elif strategy_name == 'mcts':
|
||||
instance = MCTSPlayer(n)
|
||||
else:
|
||||
return None
|
||||
gtp_engine = gtp.Engine(instance)
|
||||
return gtp_engine
|
||||
@ -111,7 +111,7 @@ for n in name:
|
||||
p_ori[:, -1].reshape(-1, 1)],
|
||||
axis=1)
|
||||
concat(block_list, board_aug, p_aug, win_ori)
|
||||
print ("Finished {} with time {}".format(n, time.time()+start))
|
||||
print ("Finished {} with time {}".format(n, time.time() + start))
|
||||
data_num = 0
|
||||
for i in range(slots_num):
|
||||
print("Block {} ".format(block_list[i].block_id) + "Size {}".format(block_list[i].store_num()))
|
||||
|
||||
78
AlphaGo/strategy.py
Normal file
78
AlphaGo/strategy.py
Normal file
@ -0,0 +1,78 @@
|
||||
import numpy as np
|
||||
import utils
|
||||
from collections import deque
|
||||
from tianshou.core.mcts.mcts import MCTS
|
||||
|
||||
|
||||
class GoEnv:
|
||||
def __init__(self, size=19, komi=6.5):
|
||||
self.size = size
|
||||
self.komi = 6.5
|
||||
self.board = [utils.EMPTY] * (self.size * self.size)
|
||||
self.history = deque(maxlen=8)
|
||||
|
||||
def _flatten(self, vertex):
|
||||
x, y = vertex
|
||||
return (x - 1) * self.size + (y - 1)
|
||||
|
||||
def do_move(self, color, vertex):
|
||||
if vertex == utils.PASS:
|
||||
return True
|
||||
|
||||
id_ = self._flatten(vertex)
|
||||
if self.board[id_] == utils.EMPTY:
|
||||
self.board[id_] = color
|
||||
self.history.append(self.board)
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def step_forward(self, state, action):
|
||||
# print(state)
|
||||
if state[0, 0, 0, -1] == 1:
|
||||
color = 1
|
||||
else:
|
||||
color = -1
|
||||
if action == 361:
|
||||
vertex = (0, 0)
|
||||
else:
|
||||
vertex = (action / 19 + 1, action % 19)
|
||||
self.do_move(color, vertex)
|
||||
new_state = np.concatenate(
|
||||
[state[:, :, :, 1:8], (np.array(self.board) == 1).reshape(1, 19, 19, 1),
|
||||
state[:, :, :, 9:16], (np.array(self.board) == -1).reshape(1, 19, 19, 1),
|
||||
np.array(1 - state[:, :, :, -1]).reshape(1, 19, 19, 1)],
|
||||
axis=3)
|
||||
return new_state, 0
|
||||
|
||||
|
||||
class strategy(object):
|
||||
def __init__(self, evaluator):
|
||||
self.simulator = GoEnv()
|
||||
self.evaluator = evaluator
|
||||
|
||||
def data_process(self, history, color):
|
||||
state = np.zeros([1, 19, 19, 17])
|
||||
for i in range(8):
|
||||
state[0, :, :, i] = history[i] == 1
|
||||
state[0, :, :, i + 8] = history[i] == -1
|
||||
if color == 1:
|
||||
state[0, :, :, 16] = np.ones([19, 19])
|
||||
if color == -1:
|
||||
state[0, :, :, 16] = np.zeros([19, 19])
|
||||
return state
|
||||
|
||||
def gen_move(self, history, color):
|
||||
self.simulator.history = history
|
||||
self.simulator.board = history[-1]
|
||||
state = self.data_process(history, color)
|
||||
prior = self.evaluator(state)[0]
|
||||
mcts = MCTS(self.simulator, self.evaluator, state, 362, prior, inverse=True, max_step=20)
|
||||
temp = 1
|
||||
p = mcts.root.N ** temp / np.sum(mcts.root.N ** temp)
|
||||
choice = np.random.choice(362, 1, p=p).tolist()[0]
|
||||
if choice == 361:
|
||||
move = (0, 0)
|
||||
else:
|
||||
move = (choice / 19 + 1, choice % 19 + 1)
|
||||
return move
|
||||
@ -8,10 +8,8 @@
|
||||
from game import Game
|
||||
from engine import GTPEngine
|
||||
|
||||
|
||||
|
||||
g = Game()
|
||||
e = GTPEngine(game_obj = g)
|
||||
e = GTPEngine(game_obj=g)
|
||||
res = e.run_cmd('1 protocol_version')
|
||||
print(e.known_commands)
|
||||
print(res)
|
||||
@ -37,4 +35,5 @@ print(res)
|
||||
res = e.run_cmd('8 genmove BLACK')
|
||||
print(res)
|
||||
|
||||
|
||||
res = e.run_cmd('9 genmove WHITE')
|
||||
print(res)
|
||||
119
AlphaGo/utils.py
Normal file
119
AlphaGo/utils.py
Normal file
@ -0,0 +1,119 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# vim:fenc=utf-8
|
||||
# $File: utils.py
|
||||
# $Date: Fri Nov 17 10:2407 2017 +0800
|
||||
# $Author: renyong15 © <mails.tsinghua.edu.cn>
|
||||
#
|
||||
|
||||
WHITE = -1
|
||||
BLACK = +1
|
||||
EMPTY = 0
|
||||
|
||||
PASS = (0, 0)
|
||||
RESIGN = "resign"
|
||||
|
||||
from collections import defaultdict
|
||||
import functools
|
||||
import itertools
|
||||
import operator
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
|
||||
import gtp
|
||||
import go
|
||||
|
||||
KGS_COLUMNS = 'ABCDEFGHJKLMNOPQRST'
|
||||
SGF_COLUMNS = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
|
||||
def parse_sgf_to_flat(sgf):
|
||||
return flatten_coords(parse_sgf_coords(sgf))
|
||||
|
||||
def flatten_coords(c):
|
||||
return go.N * c[0] + c[1]
|
||||
|
||||
def unflatten_coords(f):
|
||||
return divmod(f, go.N)
|
||||
|
||||
def parse_sgf_coords(s):
|
||||
'Interprets coords. aa is top left corner; sa is top right corner'
|
||||
if s is None or s == '':
|
||||
return None
|
||||
return SGF_COLUMNS.index(s[1]), SGF_COLUMNS.index(s[0])
|
||||
|
||||
def unparse_sgf_coords(c):
|
||||
if c is None:
|
||||
return ''
|
||||
return SGF_COLUMNS[c[1]] + SGF_COLUMNS[c[0]]
|
||||
|
||||
def parse_kgs_coords(s):
|
||||
'Interprets coords. A1 is bottom left; A9 is top left.'
|
||||
if s == 'pass':
|
||||
return None
|
||||
s = s.upper()
|
||||
col = KGS_COLUMNS.index(s[0])
|
||||
row_from_bottom = int(s[1:]) - 1
|
||||
return go.N - row_from_bottom - 1, col
|
||||
|
||||
def parse_pygtp_coords(vertex):
|
||||
'Interprets coords. (1, 1) is bottom left; (1, 9) is top left.'
|
||||
if vertex in (gtp.PASS, gtp.RESIGN):
|
||||
return None
|
||||
return go.N - vertex[1], vertex[0] - 1
|
||||
|
||||
def unparse_pygtp_coords(c):
|
||||
if c is None:
|
||||
return gtp.PASS
|
||||
return c[1] + 1, go.N - c[0]
|
||||
|
||||
def parse_game_result(result):
|
||||
if re.match(r'[bB]\+', result):
|
||||
return go.BLACK
|
||||
elif re.match(r'[wW]\+', result):
|
||||
return go.WHITE
|
||||
else:
|
||||
return None
|
||||
|
||||
def product(numbers):
|
||||
return functools.reduce(operator.mul, numbers)
|
||||
|
||||
def take_n(n, iterable):
|
||||
return list(itertools.islice(iterable, n))
|
||||
|
||||
def iter_chunks(chunk_size, iterator):
|
||||
while True:
|
||||
next_chunk = take_n(chunk_size, iterator)
|
||||
# If len(iterable) % chunk_size == 0, don't return an empty chunk.
|
||||
if next_chunk:
|
||||
yield next_chunk
|
||||
else:
|
||||
break
|
||||
|
||||
def shuffler(iterator, pool_size=10**5, refill_threshold=0.9):
|
||||
yields_between_refills = round(pool_size * (1 - refill_threshold))
|
||||
# initialize pool; this step may or may not exhaust the iterator.
|
||||
pool = take_n(pool_size, iterator)
|
||||
while True:
|
||||
random.shuffle(pool)
|
||||
for i in range(yields_between_refills):
|
||||
yield pool.pop()
|
||||
next_batch = take_n(yields_between_refills, iterator)
|
||||
if not next_batch:
|
||||
break
|
||||
pool.extend(next_batch)
|
||||
# finish consuming whatever's left - no need for further randomization.
|
||||
yield from pool
|
||||
|
||||
class timer(object):
|
||||
all_times = defaultdict(float)
|
||||
def __init__(self, label):
|
||||
self.label = label
|
||||
def __enter__(self):
|
||||
self.tick = time.time()
|
||||
def __exit__(self, type, value, traceback):
|
||||
self.tock = time.time()
|
||||
self.all_times[self.label] += self.tock - self.tick
|
||||
@classmethod
|
||||
def print_times(cls):
|
||||
for k, v in cls.all_times.items():
|
||||
print("%s: %.3f" % (k, v))
|
||||
BIN
GTP/.game.py.swp
BIN
GTP/.game.py.swp
Binary file not shown.
BIN
GTP/.test.py.swp
BIN
GTP/.test.py.swp
Binary file not shown.
Binary file not shown.
@ -1,7 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# vim:fenc=utf-8
|
||||
# $File: __init__.py
|
||||
# $Date: Thu Nov 16 14:1006 2017 +0800
|
||||
# $Author: renyong15 © <mails.tsinghua.edu.cn>
|
||||
#
|
||||
|
||||
50
GTP/game.py
50
GTP/game.py
@ -1,50 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# vim:fenc=utf-8
|
||||
# $File: game.py
|
||||
# $Date: Fri Nov 17 15:0745 2017 +0800
|
||||
# $Author: renyong15 © <mails.tsinghua.edu.cn>
|
||||
#
|
||||
|
||||
import utils
|
||||
|
||||
|
||||
class Game:
|
||||
def __init__(self, size=19, komi=6.5):
|
||||
self.size = size
|
||||
self.komi = 6.5
|
||||
self.board = [utils.EMPTY] * (self.size * self.size)
|
||||
self.strategy = None
|
||||
|
||||
def _flatten(self, vertex):
|
||||
x,y = vertex
|
||||
return (x-1) * self.size + (y-1)
|
||||
|
||||
|
||||
def clear(self):
|
||||
self.board = [utils.EMPTY] * (self.size * self.size)
|
||||
|
||||
def set_size(self, n):
|
||||
self.size = n
|
||||
self.clear()
|
||||
|
||||
def set_komi(self, k):
|
||||
self.komi = k
|
||||
|
||||
def do_move(self, color, vertex):
|
||||
if vertex == utils.PASS:
|
||||
return True
|
||||
|
||||
id_ = self._flatten(vertex)
|
||||
if self.board[id_] == utils.EMPTY:
|
||||
self.board[id_] = color
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def gen_move(self, color):
|
||||
move = self.strategy.gen_move(color)
|
||||
return move
|
||||
#return utils.PASS
|
||||
|
||||
|
||||
|
||||
16
GTP/utils.py
16
GTP/utils.py
@ -1,16 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# vim:fenc=utf-8
|
||||
# $File: utils.py
|
||||
# $Date: Fri Nov 17 10:2407 2017 +0800
|
||||
# $Author: renyong15 © <mails.tsinghua.edu.cn>
|
||||
#
|
||||
|
||||
WHITE = -1
|
||||
BLACK = +1
|
||||
EMPTY = 0
|
||||
|
||||
PASS = (0,0)
|
||||
RESIGN = "resign"
|
||||
|
||||
|
||||
|
||||
@ -25,4 +25,4 @@ class rollout_policy(evaluator):
|
||||
action = np.random.randint(0, self.action_num)
|
||||
state, reward = self.env.step_forward(state, action)
|
||||
total_reward += reward
|
||||
return total_reward
|
||||
return np.ones([self.action_num])/self.action_num, total_reward
|
||||
|
||||
@ -5,14 +5,29 @@ import time
|
||||
c_puct = 5
|
||||
|
||||
|
||||
def list2tuple(list):
|
||||
try:
|
||||
return tuple(list2tuple(sub) for sub in list)
|
||||
except TypeError:
|
||||
return list
|
||||
|
||||
|
||||
def tuple2list(tuple):
|
||||
try:
|
||||
return list(tuple2list(sub) for sub in tuple)
|
||||
except TypeError:
|
||||
return tuple
|
||||
|
||||
|
||||
class MCTSNode(object):
|
||||
def __init__(self, parent, action, state, action_num, prior):
|
||||
def __init__(self, parent, action, state, action_num, prior, inverse=False):
|
||||
self.parent = parent
|
||||
self.action = action
|
||||
self.children = {}
|
||||
self.state = state
|
||||
self.action_num = action_num
|
||||
self.prior = prior
|
||||
self.inverse = inverse
|
||||
|
||||
def selection(self, simulator):
|
||||
raise NotImplementedError("Need to implement function selection")
|
||||
@ -20,13 +35,10 @@ class MCTSNode(object):
|
||||
def backpropagation(self, action):
|
||||
raise NotImplementedError("Need to implement function backpropagation")
|
||||
|
||||
def simulation(self, state, evaluator):
|
||||
raise NotImplementedError("Need to implement function simulation")
|
||||
|
||||
|
||||
class UCTNode(MCTSNode):
|
||||
def __init__(self, parent, action, state, action_num, prior):
|
||||
super(UCTNode, self).__init__(parent, action, state, action_num, prior)
|
||||
def __init__(self, parent, action, state, action_num, prior, inverse=False):
|
||||
super(UCTNode, self).__init__(parent, action, state, action_num, prior, inverse)
|
||||
self.Q = np.zeros([action_num])
|
||||
self.W = np.zeros([action_num])
|
||||
self.N = np.zeros([action_num])
|
||||
@ -49,16 +61,15 @@ class UCTNode(MCTSNode):
|
||||
self.Q[i] = (self.W[i] + 0.) / self.N[i]
|
||||
self.ucb = self.Q + c_puct * self.prior * math.sqrt(np.sum(self.N)) / (self.N + 1.)
|
||||
if self.parent is not None:
|
||||
if self.inverse:
|
||||
self.parent.backpropagation(-self.children[action].reward)
|
||||
else:
|
||||
self.parent.backpropagation(self.children[action].reward)
|
||||
|
||||
def simulation(self, evaluator, state):
|
||||
value = evaluator(state)
|
||||
return value
|
||||
|
||||
|
||||
class TSNode(MCTSNode):
|
||||
def __init__(self, parent, action, state, action_num, prior, method="Gaussian"):
|
||||
super(TSNode, self).__init__(parent, action, state, action_num, prior)
|
||||
def __init__(self, parent, action, state, action_num, prior, method="Gaussian", inverse=False):
|
||||
super(TSNode, self).__init__(parent, action, state, action_num, prior, inverse)
|
||||
if method == "Beta":
|
||||
self.alpha = np.ones([action_num])
|
||||
self.beta = np.ones([action_num])
|
||||
@ -73,10 +84,27 @@ class ActionNode:
|
||||
self.action = action
|
||||
self.children = {}
|
||||
self.next_state = None
|
||||
self.origin_state = None
|
||||
self.state_type = None
|
||||
self.reward = 0
|
||||
|
||||
def type_conversion_to_tuple(self):
|
||||
if type(self.next_state) is np.ndarray:
|
||||
self.next_state = self.next_state.tolist()
|
||||
if type(self.next_state) is list:
|
||||
self.next_state = list2tuple(self.next_state)
|
||||
|
||||
def type_conversion_to_origin(self):
|
||||
if self.state_type is np.ndarray:
|
||||
self.next_state = np.array(self.next_state)
|
||||
if self.state_type is list:
|
||||
self.next_state = tuple2list(self.next_state)
|
||||
|
||||
def selection(self, simulator):
|
||||
self.next_state, self.reward = simulator.step_forward(self.parent.state, self.action)
|
||||
self.origin_state = self.next_state
|
||||
self.state_type = type(self.next_state)
|
||||
self.type_conversion_to_tuple()
|
||||
if self.next_state is not None:
|
||||
if self.next_state in self.children.keys():
|
||||
return self.children[self.next_state].selection(simulator)
|
||||
@ -85,14 +113,15 @@ class ActionNode:
|
||||
else:
|
||||
return self.parent, self.action
|
||||
|
||||
def expansion(self, action_num):
|
||||
def expansion(self, evaluator, action_num):
|
||||
# TODO: Let users/evaluator give the prior
|
||||
if self.next_state is not None:
|
||||
prior = np.ones([action_num]) / action_num
|
||||
self.children[self.next_state] = UCTNode(self, self.action, self.next_state, action_num, prior)
|
||||
return True
|
||||
prior, value = evaluator(self.next_state)
|
||||
self.children[self.next_state] = UCTNode(self, self.action, self.origin_state, action_num, prior,
|
||||
self.parent.inverse)
|
||||
return value
|
||||
else:
|
||||
return False
|
||||
return 0
|
||||
|
||||
def backpropagation(self, value):
|
||||
self.reward += value
|
||||
@ -100,14 +129,16 @@ class ActionNode:
|
||||
|
||||
|
||||
class MCTS:
|
||||
def __init__(self, simulator, evaluator, root, action_num, prior, method="UCT", max_step=None, max_time=None):
|
||||
def __init__(self, simulator, evaluator, root, action_num, prior, method="UCT", inverse=False, max_step=None,
|
||||
max_time=None):
|
||||
self.simulator = simulator
|
||||
self.evaluator = evaluator
|
||||
self.action_num = action_num
|
||||
if method == "UCT":
|
||||
self.root = UCTNode(None, None, root, action_num, prior)
|
||||
self.root = UCTNode(None, None, root, action_num, prior, inverse)
|
||||
if method == "TS":
|
||||
self.root = TSNode(None, None, root, action_num, prior)
|
||||
self.root = TSNode(None, None, root, action_num, prior, inverse=inverse)
|
||||
self.inverse = inverse
|
||||
if max_step is not None:
|
||||
self.step = 0
|
||||
self.max_step = max_step
|
||||
@ -118,23 +149,15 @@ class MCTS:
|
||||
raise ValueError("Need a stop criteria!")
|
||||
while (max_step is not None and self.step < self.max_step or max_step is None) \
|
||||
and (max_time is not None and time.time() - self.start_time < self.max_time or max_time is None):
|
||||
print("Q={}".format(self.root.Q))
|
||||
print("N={}".format(self.root.N))
|
||||
print("W={}".format(self.root.W))
|
||||
print("UCB={}".format(self.root.ucb))
|
||||
print("\n")
|
||||
self.expand()
|
||||
if max_step is not None:
|
||||
self.step += 1
|
||||
|
||||
def expand(self):
|
||||
node, new_action = self.root.selection(self.simulator)
|
||||
success = node.children[new_action].expansion(self.action_num)
|
||||
if success:
|
||||
value = node.simulation(self.evaluator, node.children[new_action].next_state)
|
||||
value = node.children[new_action].expansion(self.evaluator, self.action_num)
|
||||
print("Value:{}".format(value))
|
||||
node.children[new_action].backpropagation(value + 0.)
|
||||
else:
|
||||
node.children[new_action].backpropagation(0.)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -24,7 +24,7 @@ class TestEnv:
|
||||
else:
|
||||
num = state[0] + 2 ** state[1] * action
|
||||
step = state[1] + 1
|
||||
new_state = (num, step)
|
||||
new_state = [num, step]
|
||||
if step == self.max_step:
|
||||
reward = int(np.random.uniform() < self.reward[num])
|
||||
else:
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Loading…
x
Reference in New Issue
Block a user