AlphaGo update
This commit is contained in:
parent
e727ce4d9b
commit
ca0021083f
@ -12,6 +12,7 @@ import time
|
|||||||
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
|
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
|
||||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
||||||
|
|
||||||
|
|
||||||
def residual_block(input, is_training):
|
def residual_block(input, is_training):
|
||||||
normalizer_params = {'is_training': is_training,
|
normalizer_params = {'is_training': is_training,
|
||||||
'updates_collections': tf.GraphKeys.UPDATE_OPS}
|
'updates_collections': tf.GraphKeys.UPDATE_OPS}
|
||||||
@ -129,52 +130,68 @@ def train():
|
|||||||
saver.save(sess, result_path + save_path)
|
saver.save(sess, result_path + save_path)
|
||||||
del data, boards, wins, ps
|
del data, boards, wins, ps
|
||||||
|
|
||||||
def forward(call_number):
|
|
||||||
#checkpoint_path = "/home/yama/rl/tianshou/AlphaGo/checkpoints"
|
|
||||||
checkpoint_path = "/home/jialian/stuGo/tianshou/stuGo/checkpoints/"
|
|
||||||
board_file = np.genfromtxt("/home/jialian/stuGo/tianshou/leela-zero/src/mcts_nn_files/board_" + call_number, dtype='str');
|
|
||||||
human_board = np.zeros((17, 19, 19))
|
|
||||||
|
|
||||||
#TODO : is it ok to ignore the last channel?
|
# def forward(call_number):
|
||||||
for i in range(17):
|
# # checkpoint_path = "/home/yama/rl/tianshou/AlphaGo/checkpoints"
|
||||||
human_board[i] = np.array(list(board_file[i])).reshape(19, 19)
|
# checkpoint_path = "/home/jialian/stuGo/tianshou/stuGo/checkpoints/"
|
||||||
#print("============================")
|
# board_file = np.genfromtxt("/home/jialian/stuGo/tianshou/leela-zero/src/mcts_nn_files/board_" + call_number,
|
||||||
#print("human board sum : " + str(np.sum(human_board[-1])))
|
# dtype='str');
|
||||||
#print("============================")
|
# human_board = np.zeros((17, 19, 19))
|
||||||
#print(human_board)
|
#
|
||||||
#print("============================")
|
# # TODO : is it ok to ignore the last channel?
|
||||||
#rint(human_board)
|
# for i in range(17):
|
||||||
feed_board = human_board.transpose(1, 2, 0).reshape(1, 19, 19, 17)
|
# human_board[i] = np.array(list(board_file[i])).reshape(19, 19)
|
||||||
#print(feed_board[:,:,:,-1])
|
# # print("============================")
|
||||||
#print(feed_board.shape)
|
# # print("human board sum : " + str(np.sum(human_board[-1])))
|
||||||
|
# # print("============================")
|
||||||
|
# # print(human_board)
|
||||||
|
# # print("============================")
|
||||||
|
# # rint(human_board)
|
||||||
|
# feed_board = human_board.transpose(1, 2, 0).reshape(1, 19, 19, 17)
|
||||||
|
# # print(feed_board[:,:,:,-1])
|
||||||
|
# # print(feed_board.shape)
|
||||||
|
#
|
||||||
|
# # npz_board = np.load("/home/yama/rl/tianshou/AlphaGo/data/7f83928932f64a79bc1efdea268698ae.npz")
|
||||||
|
# # print(npz_board["boards"].shape)
|
||||||
|
# # feed_board = npz_board["boards"][10].reshape(-1, 19, 19, 17)
|
||||||
|
# ##print(feed_board)
|
||||||
|
# # show_board = feed_board[0].transpose(2, 0, 1)
|
||||||
|
# # print("board shape : ", show_board.shape)
|
||||||
|
# # print(show_board)
|
||||||
|
#
|
||||||
|
# itflag = False
|
||||||
|
# with multi_gpu.create_session() as sess:
|
||||||
|
# sess.run(tf.global_variables_initializer())
|
||||||
|
# ckpt_file = tf.train.latest_checkpoint(checkpoint_path)
|
||||||
|
# if ckpt_file is not None:
|
||||||
|
# # print('Restoring model from {}...'.format(ckpt_file))
|
||||||
|
# saver.restore(sess, ckpt_file)
|
||||||
|
# else:
|
||||||
|
# raise ValueError("No model loaded")
|
||||||
|
# res = sess.run([tf.nn.softmax(p), v], feed_dict={x: feed_board, is_training: itflag})
|
||||||
|
# # res = sess.run([tf.nn.softmax(p),v], feed_dict={x:fix_board["boards"][300].reshape(-1, 19, 19, 17), is_training:False})
|
||||||
|
# # res = sess.run([tf.nn.softmax(p),v], feed_dict={x:fix_board["boards"][50].reshape(-1, 19, 19, 17), is_training:True})
|
||||||
|
# # print(np.argmax(res[0]))
|
||||||
|
# np.savetxt(sys.stdout, res[0][0], fmt="%.6f", newline=" ")
|
||||||
|
# np.savetxt(sys.stdout, res[1][0], fmt="%.6f", newline=" ")
|
||||||
|
# pv_file = "/home/jialian/stuGotianshou/leela-zero/src/mcts_nn_files/policy_value"
|
||||||
|
# np.savetxt(pv_file, np.concatenate((res[0][0], res[1][0])), fmt="%.6f", newline=" ")
|
||||||
|
# # np.savetxt(pv_file, res[1][0], fmt="%.6f", newline=" ")
|
||||||
|
# return res
|
||||||
|
|
||||||
#npz_board = np.load("/home/yama/rl/tianshou/AlphaGo/data/7f83928932f64a79bc1efdea268698ae.npz")
|
def forward(state):
|
||||||
#print(npz_board["boards"].shape)
|
checkpoint_path = "/home/tongzheng/tianshou/AlphaGo/checkpoints/"
|
||||||
#feed_board = npz_board["boards"][10].reshape(-1, 19, 19, 17)
|
|
||||||
##print(feed_board)
|
|
||||||
#show_board = feed_board[0].transpose(2, 0, 1)
|
|
||||||
#print("board shape : ", show_board.shape)
|
|
||||||
#print(show_board)
|
|
||||||
|
|
||||||
itflag = False
|
|
||||||
with multi_gpu.create_session() as sess:
|
with multi_gpu.create_session() as sess:
|
||||||
sess.run(tf.global_variables_initializer())
|
sess.run(tf.global_variables_initializer())
|
||||||
ckpt_file = tf.train.latest_checkpoint(checkpoint_path)
|
ckpt_file = tf.train.latest_checkpoint(checkpoint_path)
|
||||||
if ckpt_file is not None:
|
if ckpt_file is not None:
|
||||||
#print('Restoring model from {}...'.format(ckpt_file))
|
print('Restoring model from {}...'.format(ckpt_file))
|
||||||
saver.restore(sess, ckpt_file)
|
saver.restore(sess, ckpt_file)
|
||||||
else:
|
else:
|
||||||
raise ValueError("No model loaded")
|
raise ValueError("No model loaded")
|
||||||
res = sess.run([tf.nn.softmax(p),v], feed_dict={x:feed_board, is_training:itflag})
|
prior, value = sess.run([tf.nn.softmax(p), v], feed_dict={x: state, is_training: False})
|
||||||
#res = sess.run([tf.nn.softmax(p),v], feed_dict={x:fix_board["boards"][300].reshape(-1, 19, 19, 17), is_training:False})
|
return prior, value
|
||||||
#res = sess.run([tf.nn.softmax(p),v], feed_dict={x:fix_board["boards"][50].reshape(-1, 19, 19, 17), is_training:True})
|
|
||||||
#print(np.argmax(res[0]))
|
|
||||||
np.savetxt(sys.stdout, res[0][0], fmt="%.6f", newline=" ")
|
|
||||||
np.savetxt(sys.stdout, res[1][0], fmt="%.6f", newline=" ")
|
|
||||||
pv_file = "/home/jialian/stuGotianshou/leela-zero/src/mcts_nn_files/policy_value"
|
|
||||||
np.savetxt(pv_file, np.concatenate((res[0][0], res[1][0])), fmt="%.6f", newline=" ")
|
|
||||||
#np.savetxt(pv_file, res[1][0], fmt="%.6f", newline=" ")
|
|
||||||
return res
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
np.set_printoptions(threshold='nan')
|
np.set_printoptions(threshold='nan')
|
||||||
|
|||||||
@ -132,6 +132,8 @@ def train():
|
|||||||
del save_path
|
del save_path
|
||||||
del data, boards, wins, ps, batch_num, index
|
del data, boards, wins, ps, batch_num, index
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
|
|
||||||
def forward(board):
|
def forward(board):
|
||||||
result_path = "./checkpoints"
|
result_path = "./checkpoints"
|
||||||
itflag = False
|
itflag = False
|
||||||
|
|||||||
@ -8,6 +8,7 @@
|
|||||||
from game import Game
|
from game import Game
|
||||||
import utils
|
import utils
|
||||||
|
|
||||||
|
|
||||||
class GTPEngine():
|
class GTPEngine():
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
self.size = 19
|
self.size = 19
|
||||||
@ -27,7 +28,6 @@ class GTPEngine():
|
|||||||
except:
|
except:
|
||||||
self._version = 2
|
self._version = 2
|
||||||
|
|
||||||
|
|
||||||
self.disconnect = False
|
self.disconnect = False
|
||||||
|
|
||||||
self.known_commands = [
|
self.known_commands = [
|
||||||
@ -42,9 +42,6 @@ class GTPEngine():
|
|||||||
x, y = vertex
|
x, y = vertex
|
||||||
return "{}{}".format("ABCDEFGHJKLMNOPQRSTYVWYZ"[x - 1], y)
|
return "{}{}".format("ABCDEFGHJKLMNOPQRSTYVWYZ"[x - 1], y)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _vertex_string2point(self, s):
|
def _vertex_string2point(self, s):
|
||||||
if s is None:
|
if s is None:
|
||||||
return False
|
return False
|
||||||
@ -62,7 +59,6 @@ class GTPEngine():
|
|||||||
return False
|
return False
|
||||||
return (x, y)
|
return (x, y)
|
||||||
|
|
||||||
|
|
||||||
def _parse_color(self, color):
|
def _parse_color(self, color):
|
||||||
if color.lower() in ["b", "black"]:
|
if color.lower() in ["b", "black"]:
|
||||||
color = utils.BLACK
|
color = utils.BLACK
|
||||||
@ -72,7 +68,6 @@ class GTPEngine():
|
|||||||
color = None
|
color = None
|
||||||
return color
|
return color
|
||||||
|
|
||||||
|
|
||||||
def _parse_move(self, move_string):
|
def _parse_move(self, move_string):
|
||||||
color, move = move_string.split(" ", 1)
|
color, move = move_string.split(" ", 1)
|
||||||
color = self._parse_color(color)
|
color = self._parse_color(color)
|
||||||
@ -84,8 +79,6 @@ class GTPEngine():
|
|||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_res(self, res, id_=None, success=True):
|
def _parse_res(self, res, id_=None, success=True):
|
||||||
if success:
|
if success:
|
||||||
if id_:
|
if id_:
|
||||||
@ -98,7 +91,6 @@ class GTPEngine():
|
|||||||
else:
|
else:
|
||||||
return '? {}\n\n'.format(res)
|
return '? {}\n\n'.format(res)
|
||||||
|
|
||||||
|
|
||||||
def _parse_cmd(self, message):
|
def _parse_cmd(self, message):
|
||||||
try:
|
try:
|
||||||
m = message.strip().split(" ", 1)
|
m = message.strip().split(" ", 1)
|
||||||
@ -130,8 +122,6 @@ class GTPEngine():
|
|||||||
else:
|
else:
|
||||||
return self._parse_res("unknown command", id_, False)
|
return self._parse_res("unknown command", id_, False)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def cmd_protocol_version(self, args, **kwargs):
|
def cmd_protocol_version(self, args, **kwargs):
|
||||||
return 2, True
|
return 2, True
|
||||||
|
|
||||||
@ -172,7 +162,6 @@ class GTPEngine():
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
raise ValueError("syntax error")
|
raise ValueError("syntax error")
|
||||||
|
|
||||||
|
|
||||||
def cmd_play(self, args, **kwargs):
|
def cmd_play(self, args, **kwargs):
|
||||||
move = self._parse_move(args)
|
move = self._parse_move(args)
|
||||||
if move:
|
if move:
|
||||||
@ -191,7 +180,3 @@ class GTPEngine():
|
|||||||
return self._vertex_point2string(move), True
|
return self._vertex_point2string(move), True
|
||||||
else:
|
else:
|
||||||
return 'unknown player', False
|
return 'unknown player', False
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
69
AlphaGo/game.py
Normal file
69
AlphaGo/game.py
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# vim:fenc=utf-8
|
||||||
|
# $File: game.py
|
||||||
|
# $Date: Fri Nov 17 15:0745 2017 +0800
|
||||||
|
# $Author: renyong15 © <mails.tsinghua.edu.cn>
|
||||||
|
#
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import utils
|
||||||
|
import Network
|
||||||
|
from strategy import strategy
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
|
|
||||||
|
class Game:
|
||||||
|
def __init__(self, size=19, komi=6.5):
|
||||||
|
self.size = size
|
||||||
|
self.komi = 6.5
|
||||||
|
self.board = [utils.EMPTY] * (self.size * self.size)
|
||||||
|
self.strategy = strategy(Network.forward)
|
||||||
|
self.history = deque(maxlen=8)
|
||||||
|
for i in range(8):
|
||||||
|
self.history.append(self.board)
|
||||||
|
|
||||||
|
def _flatten(self, vertex):
|
||||||
|
x, y = vertex
|
||||||
|
return (x - 1) * self.size + (y - 1)
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
self.board = [utils.EMPTY] * (self.size * self.size)
|
||||||
|
|
||||||
|
def set_size(self, n):
|
||||||
|
self.size = n
|
||||||
|
self.clear()
|
||||||
|
|
||||||
|
def set_komi(self, k):
|
||||||
|
self.komi = k
|
||||||
|
|
||||||
|
def do_move(self, color, vertex):
|
||||||
|
if vertex == utils.PASS:
|
||||||
|
return True
|
||||||
|
|
||||||
|
id_ = self._flatten(vertex)
|
||||||
|
if self.board[id_] == utils.EMPTY:
|
||||||
|
self.board[id_] = color
|
||||||
|
self.history.append(self.board)
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def step_forward(self, state, action):
|
||||||
|
if state[0, 0, 0, -1] == 1:
|
||||||
|
color = 1
|
||||||
|
else:
|
||||||
|
color = -1
|
||||||
|
if action == 361:
|
||||||
|
vertex = (0, 0)
|
||||||
|
else:
|
||||||
|
vertex = (action / 19 + 1, action % 19)
|
||||||
|
self.do_move(color, vertex)
|
||||||
|
new_state = np.concatenate([state[:, :, :, 1:8], self.board == 1, state[:, :, :, 9:16], 1 - state[:, :, :, -1]],
|
||||||
|
axis=3)
|
||||||
|
return new_state, 0
|
||||||
|
|
||||||
|
def gen_move(self, color):
|
||||||
|
move = self.strategy.gen_move(self.history, color)
|
||||||
|
return move
|
||||||
|
# return utils.PASS
|
||||||
|
|
||||||
@ -3,7 +3,6 @@ import go
|
|||||||
import utils
|
import utils
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def translate_gtp_colors(gtp_color):
|
def translate_gtp_colors(gtp_color):
|
||||||
if gtp_color == gtp.BLACK:
|
if gtp_color == gtp.BLACK:
|
||||||
return go.BLACK
|
return go.BLACK
|
||||||
@ -12,6 +11,7 @@ def translate_gtp_colors(gtp_color):
|
|||||||
else:
|
else:
|
||||||
return go.EMPTY
|
return go.EMPTY
|
||||||
|
|
||||||
|
|
||||||
class GtpInterface(object):
|
class GtpInterface(object):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.size = 9
|
self.size = 9
|
||||||
@ -68,19 +68,3 @@ class GtpInterface(object):
|
|||||||
|
|
||||||
def suggest_move(self, position):
|
def suggest_move(self, position):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def make_gtp_instance(strategy_name, read_file):
|
|
||||||
n = PolicyNetwork(use_cpu=True)
|
|
||||||
n.initialize_variables(read_file)
|
|
||||||
if strategy_name == 'random':
|
|
||||||
instance = RandomPlayer()
|
|
||||||
elif strategy_name == 'policy':
|
|
||||||
instance = GreedyPolicyPlayer(n)
|
|
||||||
elif strategy_name == 'randompolicy':
|
|
||||||
instance = RandomPolicyPlayer(n)
|
|
||||||
elif strategy_name == 'mcts':
|
|
||||||
instance = MCTSPlayer(n)
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
gtp_engine = gtp.Engine(instance)
|
|
||||||
return gtp_engine
|
|
||||||
78
AlphaGo/strategy.py
Normal file
78
AlphaGo/strategy.py
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
import numpy as np
|
||||||
|
import utils
|
||||||
|
from collections import deque
|
||||||
|
from tianshou.core.mcts.mcts import MCTS
|
||||||
|
|
||||||
|
|
||||||
|
class GoEnv:
|
||||||
|
def __init__(self, size=19, komi=6.5):
|
||||||
|
self.size = size
|
||||||
|
self.komi = 6.5
|
||||||
|
self.board = [utils.EMPTY] * (self.size * self.size)
|
||||||
|
self.history = deque(maxlen=8)
|
||||||
|
|
||||||
|
def _flatten(self, vertex):
|
||||||
|
x, y = vertex
|
||||||
|
return (x - 1) * self.size + (y - 1)
|
||||||
|
|
||||||
|
def do_move(self, color, vertex):
|
||||||
|
if vertex == utils.PASS:
|
||||||
|
return True
|
||||||
|
|
||||||
|
id_ = self._flatten(vertex)
|
||||||
|
if self.board[id_] == utils.EMPTY:
|
||||||
|
self.board[id_] = color
|
||||||
|
self.history.append(self.board)
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def step_forward(self, state, action):
|
||||||
|
# print(state)
|
||||||
|
if state[0, 0, 0, -1] == 1:
|
||||||
|
color = 1
|
||||||
|
else:
|
||||||
|
color = -1
|
||||||
|
if action == 361:
|
||||||
|
vertex = (0, 0)
|
||||||
|
else:
|
||||||
|
vertex = (action / 19 + 1, action % 19)
|
||||||
|
self.do_move(color, vertex)
|
||||||
|
new_state = np.concatenate(
|
||||||
|
[state[:, :, :, 1:8], (np.array(self.board) == 1).reshape(1, 19, 19, 1),
|
||||||
|
state[:, :, :, 9:16], (np.array(self.board) == -1).reshape(1, 19, 19, 1),
|
||||||
|
np.array(1 - state[:, :, :, -1]).reshape(1, 19, 19, 1)],
|
||||||
|
axis=3)
|
||||||
|
return new_state, 0
|
||||||
|
|
||||||
|
|
||||||
|
class strategy(object):
|
||||||
|
def __init__(self, evaluator):
|
||||||
|
self.simulator = GoEnv()
|
||||||
|
self.evaluator = evaluator
|
||||||
|
|
||||||
|
def data_process(self, history, color):
|
||||||
|
state = np.zeros([1, 19, 19, 17])
|
||||||
|
for i in range(8):
|
||||||
|
state[0, :, :, i] = history[i] == 1
|
||||||
|
state[0, :, :, i + 8] = history[i] == -1
|
||||||
|
if color == 1:
|
||||||
|
state[0, :, :, 16] = np.ones([19, 19])
|
||||||
|
if color == -1:
|
||||||
|
state[0, :, :, 16] = np.zeros([19, 19])
|
||||||
|
return state
|
||||||
|
|
||||||
|
def gen_move(self, history, color):
|
||||||
|
self.simulator.history = history
|
||||||
|
self.simulator.board = history[-1]
|
||||||
|
state = self.data_process(history, color)
|
||||||
|
prior = self.evaluator(state)[0]
|
||||||
|
mcts = MCTS(self.simulator, self.evaluator, state, 362, prior, inverse=True, max_step=20)
|
||||||
|
temp = 1
|
||||||
|
p = mcts.root.N ** temp / np.sum(mcts.root.N ** temp)
|
||||||
|
choice = np.random.choice(362, 1, p=p).tolist()[0]
|
||||||
|
if choice == 361:
|
||||||
|
move = (0, 0)
|
||||||
|
else:
|
||||||
|
move = (choice / 19 + 1, choice % 19 + 1)
|
||||||
|
return move
|
||||||
@ -8,8 +8,6 @@
|
|||||||
from game import Game
|
from game import Game
|
||||||
from engine import GTPEngine
|
from engine import GTPEngine
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
g = Game()
|
g = Game()
|
||||||
e = GTPEngine(game_obj=g)
|
e = GTPEngine(game_obj=g)
|
||||||
res = e.run_cmd('1 protocol_version')
|
res = e.run_cmd('1 protocol_version')
|
||||||
@ -37,4 +35,5 @@ print(res)
|
|||||||
res = e.run_cmd('8 genmove BLACK')
|
res = e.run_cmd('8 genmove BLACK')
|
||||||
print(res)
|
print(res)
|
||||||
|
|
||||||
|
res = e.run_cmd('9 genmove WHITE')
|
||||||
|
print(res)
|
||||||
119
AlphaGo/utils.py
Normal file
119
AlphaGo/utils.py
Normal file
@ -0,0 +1,119 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# vim:fenc=utf-8
|
||||||
|
# $File: utils.py
|
||||||
|
# $Date: Fri Nov 17 10:2407 2017 +0800
|
||||||
|
# $Author: renyong15 © <mails.tsinghua.edu.cn>
|
||||||
|
#
|
||||||
|
|
||||||
|
WHITE = -1
|
||||||
|
BLACK = +1
|
||||||
|
EMPTY = 0
|
||||||
|
|
||||||
|
PASS = (0, 0)
|
||||||
|
RESIGN = "resign"
|
||||||
|
|
||||||
|
from collections import defaultdict
|
||||||
|
import functools
|
||||||
|
import itertools
|
||||||
|
import operator
|
||||||
|
import random
|
||||||
|
import re
|
||||||
|
import time
|
||||||
|
|
||||||
|
import gtp
|
||||||
|
import go
|
||||||
|
|
||||||
|
KGS_COLUMNS = 'ABCDEFGHJKLMNOPQRST'
|
||||||
|
SGF_COLUMNS = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||||
|
|
||||||
|
def parse_sgf_to_flat(sgf):
|
||||||
|
return flatten_coords(parse_sgf_coords(sgf))
|
||||||
|
|
||||||
|
def flatten_coords(c):
|
||||||
|
return go.N * c[0] + c[1]
|
||||||
|
|
||||||
|
def unflatten_coords(f):
|
||||||
|
return divmod(f, go.N)
|
||||||
|
|
||||||
|
def parse_sgf_coords(s):
|
||||||
|
'Interprets coords. aa is top left corner; sa is top right corner'
|
||||||
|
if s is None or s == '':
|
||||||
|
return None
|
||||||
|
return SGF_COLUMNS.index(s[1]), SGF_COLUMNS.index(s[0])
|
||||||
|
|
||||||
|
def unparse_sgf_coords(c):
|
||||||
|
if c is None:
|
||||||
|
return ''
|
||||||
|
return SGF_COLUMNS[c[1]] + SGF_COLUMNS[c[0]]
|
||||||
|
|
||||||
|
def parse_kgs_coords(s):
|
||||||
|
'Interprets coords. A1 is bottom left; A9 is top left.'
|
||||||
|
if s == 'pass':
|
||||||
|
return None
|
||||||
|
s = s.upper()
|
||||||
|
col = KGS_COLUMNS.index(s[0])
|
||||||
|
row_from_bottom = int(s[1:]) - 1
|
||||||
|
return go.N - row_from_bottom - 1, col
|
||||||
|
|
||||||
|
def parse_pygtp_coords(vertex):
|
||||||
|
'Interprets coords. (1, 1) is bottom left; (1, 9) is top left.'
|
||||||
|
if vertex in (gtp.PASS, gtp.RESIGN):
|
||||||
|
return None
|
||||||
|
return go.N - vertex[1], vertex[0] - 1
|
||||||
|
|
||||||
|
def unparse_pygtp_coords(c):
|
||||||
|
if c is None:
|
||||||
|
return gtp.PASS
|
||||||
|
return c[1] + 1, go.N - c[0]
|
||||||
|
|
||||||
|
def parse_game_result(result):
|
||||||
|
if re.match(r'[bB]\+', result):
|
||||||
|
return go.BLACK
|
||||||
|
elif re.match(r'[wW]\+', result):
|
||||||
|
return go.WHITE
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def product(numbers):
|
||||||
|
return functools.reduce(operator.mul, numbers)
|
||||||
|
|
||||||
|
def take_n(n, iterable):
|
||||||
|
return list(itertools.islice(iterable, n))
|
||||||
|
|
||||||
|
def iter_chunks(chunk_size, iterator):
|
||||||
|
while True:
|
||||||
|
next_chunk = take_n(chunk_size, iterator)
|
||||||
|
# If len(iterable) % chunk_size == 0, don't return an empty chunk.
|
||||||
|
if next_chunk:
|
||||||
|
yield next_chunk
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
def shuffler(iterator, pool_size=10**5, refill_threshold=0.9):
|
||||||
|
yields_between_refills = round(pool_size * (1 - refill_threshold))
|
||||||
|
# initialize pool; this step may or may not exhaust the iterator.
|
||||||
|
pool = take_n(pool_size, iterator)
|
||||||
|
while True:
|
||||||
|
random.shuffle(pool)
|
||||||
|
for i in range(yields_between_refills):
|
||||||
|
yield pool.pop()
|
||||||
|
next_batch = take_n(yields_between_refills, iterator)
|
||||||
|
if not next_batch:
|
||||||
|
break
|
||||||
|
pool.extend(next_batch)
|
||||||
|
# finish consuming whatever's left - no need for further randomization.
|
||||||
|
yield from pool
|
||||||
|
|
||||||
|
class timer(object):
|
||||||
|
all_times = defaultdict(float)
|
||||||
|
def __init__(self, label):
|
||||||
|
self.label = label
|
||||||
|
def __enter__(self):
|
||||||
|
self.tick = time.time()
|
||||||
|
def __exit__(self, type, value, traceback):
|
||||||
|
self.tock = time.time()
|
||||||
|
self.all_times[self.label] += self.tock - self.tick
|
||||||
|
@classmethod
|
||||||
|
def print_times(cls):
|
||||||
|
for k, v in cls.all_times.items():
|
||||||
|
print("%s: %.3f" % (k, v))
|
||||||
BIN
GTP/.game.py.swp
BIN
GTP/.game.py.swp
Binary file not shown.
BIN
GTP/.test.py.swp
BIN
GTP/.test.py.swp
Binary file not shown.
Binary file not shown.
@ -1,7 +0,0 @@
|
|||||||
# -*- coding: utf-8 -*-
|
|
||||||
# vim:fenc=utf-8
|
|
||||||
# $File: __init__.py
|
|
||||||
# $Date: Thu Nov 16 14:1006 2017 +0800
|
|
||||||
# $Author: renyong15 © <mails.tsinghua.edu.cn>
|
|
||||||
#
|
|
||||||
|
|
||||||
50
GTP/game.py
50
GTP/game.py
@ -1,50 +0,0 @@
|
|||||||
# -*- coding: utf-8 -*-
|
|
||||||
# vim:fenc=utf-8
|
|
||||||
# $File: game.py
|
|
||||||
# $Date: Fri Nov 17 15:0745 2017 +0800
|
|
||||||
# $Author: renyong15 © <mails.tsinghua.edu.cn>
|
|
||||||
#
|
|
||||||
|
|
||||||
import utils
|
|
||||||
|
|
||||||
|
|
||||||
class Game:
|
|
||||||
def __init__(self, size=19, komi=6.5):
|
|
||||||
self.size = size
|
|
||||||
self.komi = 6.5
|
|
||||||
self.board = [utils.EMPTY] * (self.size * self.size)
|
|
||||||
self.strategy = None
|
|
||||||
|
|
||||||
def _flatten(self, vertex):
|
|
||||||
x,y = vertex
|
|
||||||
return (x-1) * self.size + (y-1)
|
|
||||||
|
|
||||||
|
|
||||||
def clear(self):
|
|
||||||
self.board = [utils.EMPTY] * (self.size * self.size)
|
|
||||||
|
|
||||||
def set_size(self, n):
|
|
||||||
self.size = n
|
|
||||||
self.clear()
|
|
||||||
|
|
||||||
def set_komi(self, k):
|
|
||||||
self.komi = k
|
|
||||||
|
|
||||||
def do_move(self, color, vertex):
|
|
||||||
if vertex == utils.PASS:
|
|
||||||
return True
|
|
||||||
|
|
||||||
id_ = self._flatten(vertex)
|
|
||||||
if self.board[id_] == utils.EMPTY:
|
|
||||||
self.board[id_] = color
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
|
|
||||||
def gen_move(self, color):
|
|
||||||
move = self.strategy.gen_move(color)
|
|
||||||
return move
|
|
||||||
#return utils.PASS
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
16
GTP/utils.py
16
GTP/utils.py
@ -1,16 +0,0 @@
|
|||||||
# -*- coding: utf-8 -*-
|
|
||||||
# vim:fenc=utf-8
|
|
||||||
# $File: utils.py
|
|
||||||
# $Date: Fri Nov 17 10:2407 2017 +0800
|
|
||||||
# $Author: renyong15 © <mails.tsinghua.edu.cn>
|
|
||||||
#
|
|
||||||
|
|
||||||
WHITE = -1
|
|
||||||
BLACK = +1
|
|
||||||
EMPTY = 0
|
|
||||||
|
|
||||||
PASS = (0,0)
|
|
||||||
RESIGN = "resign"
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -25,4 +25,4 @@ class rollout_policy(evaluator):
|
|||||||
action = np.random.randint(0, self.action_num)
|
action = np.random.randint(0, self.action_num)
|
||||||
state, reward = self.env.step_forward(state, action)
|
state, reward = self.env.step_forward(state, action)
|
||||||
total_reward += reward
|
total_reward += reward
|
||||||
return total_reward
|
return np.ones([self.action_num])/self.action_num, total_reward
|
||||||
|
|||||||
@ -5,14 +5,29 @@ import time
|
|||||||
c_puct = 5
|
c_puct = 5
|
||||||
|
|
||||||
|
|
||||||
|
def list2tuple(list):
|
||||||
|
try:
|
||||||
|
return tuple(list2tuple(sub) for sub in list)
|
||||||
|
except TypeError:
|
||||||
|
return list
|
||||||
|
|
||||||
|
|
||||||
|
def tuple2list(tuple):
|
||||||
|
try:
|
||||||
|
return list(tuple2list(sub) for sub in tuple)
|
||||||
|
except TypeError:
|
||||||
|
return tuple
|
||||||
|
|
||||||
|
|
||||||
class MCTSNode(object):
|
class MCTSNode(object):
|
||||||
def __init__(self, parent, action, state, action_num, prior):
|
def __init__(self, parent, action, state, action_num, prior, inverse=False):
|
||||||
self.parent = parent
|
self.parent = parent
|
||||||
self.action = action
|
self.action = action
|
||||||
self.children = {}
|
self.children = {}
|
||||||
self.state = state
|
self.state = state
|
||||||
self.action_num = action_num
|
self.action_num = action_num
|
||||||
self.prior = prior
|
self.prior = prior
|
||||||
|
self.inverse = inverse
|
||||||
|
|
||||||
def selection(self, simulator):
|
def selection(self, simulator):
|
||||||
raise NotImplementedError("Need to implement function selection")
|
raise NotImplementedError("Need to implement function selection")
|
||||||
@ -20,13 +35,10 @@ class MCTSNode(object):
|
|||||||
def backpropagation(self, action):
|
def backpropagation(self, action):
|
||||||
raise NotImplementedError("Need to implement function backpropagation")
|
raise NotImplementedError("Need to implement function backpropagation")
|
||||||
|
|
||||||
def simulation(self, state, evaluator):
|
|
||||||
raise NotImplementedError("Need to implement function simulation")
|
|
||||||
|
|
||||||
|
|
||||||
class UCTNode(MCTSNode):
|
class UCTNode(MCTSNode):
|
||||||
def __init__(self, parent, action, state, action_num, prior):
|
def __init__(self, parent, action, state, action_num, prior, inverse=False):
|
||||||
super(UCTNode, self).__init__(parent, action, state, action_num, prior)
|
super(UCTNode, self).__init__(parent, action, state, action_num, prior, inverse)
|
||||||
self.Q = np.zeros([action_num])
|
self.Q = np.zeros([action_num])
|
||||||
self.W = np.zeros([action_num])
|
self.W = np.zeros([action_num])
|
||||||
self.N = np.zeros([action_num])
|
self.N = np.zeros([action_num])
|
||||||
@ -49,16 +61,15 @@ class UCTNode(MCTSNode):
|
|||||||
self.Q[i] = (self.W[i] + 0.) / self.N[i]
|
self.Q[i] = (self.W[i] + 0.) / self.N[i]
|
||||||
self.ucb = self.Q + c_puct * self.prior * math.sqrt(np.sum(self.N)) / (self.N + 1.)
|
self.ucb = self.Q + c_puct * self.prior * math.sqrt(np.sum(self.N)) / (self.N + 1.)
|
||||||
if self.parent is not None:
|
if self.parent is not None:
|
||||||
|
if self.inverse:
|
||||||
|
self.parent.backpropagation(-self.children[action].reward)
|
||||||
|
else:
|
||||||
self.parent.backpropagation(self.children[action].reward)
|
self.parent.backpropagation(self.children[action].reward)
|
||||||
|
|
||||||
def simulation(self, evaluator, state):
|
|
||||||
value = evaluator(state)
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
class TSNode(MCTSNode):
|
class TSNode(MCTSNode):
|
||||||
def __init__(self, parent, action, state, action_num, prior, method="Gaussian"):
|
def __init__(self, parent, action, state, action_num, prior, method="Gaussian", inverse=False):
|
||||||
super(TSNode, self).__init__(parent, action, state, action_num, prior)
|
super(TSNode, self).__init__(parent, action, state, action_num, prior, inverse)
|
||||||
if method == "Beta":
|
if method == "Beta":
|
||||||
self.alpha = np.ones([action_num])
|
self.alpha = np.ones([action_num])
|
||||||
self.beta = np.ones([action_num])
|
self.beta = np.ones([action_num])
|
||||||
@ -73,10 +84,27 @@ class ActionNode:
|
|||||||
self.action = action
|
self.action = action
|
||||||
self.children = {}
|
self.children = {}
|
||||||
self.next_state = None
|
self.next_state = None
|
||||||
|
self.origin_state = None
|
||||||
|
self.state_type = None
|
||||||
self.reward = 0
|
self.reward = 0
|
||||||
|
|
||||||
|
def type_conversion_to_tuple(self):
|
||||||
|
if type(self.next_state) is np.ndarray:
|
||||||
|
self.next_state = self.next_state.tolist()
|
||||||
|
if type(self.next_state) is list:
|
||||||
|
self.next_state = list2tuple(self.next_state)
|
||||||
|
|
||||||
|
def type_conversion_to_origin(self):
|
||||||
|
if self.state_type is np.ndarray:
|
||||||
|
self.next_state = np.array(self.next_state)
|
||||||
|
if self.state_type is list:
|
||||||
|
self.next_state = tuple2list(self.next_state)
|
||||||
|
|
||||||
def selection(self, simulator):
|
def selection(self, simulator):
|
||||||
self.next_state, self.reward = simulator.step_forward(self.parent.state, self.action)
|
self.next_state, self.reward = simulator.step_forward(self.parent.state, self.action)
|
||||||
|
self.origin_state = self.next_state
|
||||||
|
self.state_type = type(self.next_state)
|
||||||
|
self.type_conversion_to_tuple()
|
||||||
if self.next_state is not None:
|
if self.next_state is not None:
|
||||||
if self.next_state in self.children.keys():
|
if self.next_state in self.children.keys():
|
||||||
return self.children[self.next_state].selection(simulator)
|
return self.children[self.next_state].selection(simulator)
|
||||||
@ -85,14 +113,15 @@ class ActionNode:
|
|||||||
else:
|
else:
|
||||||
return self.parent, self.action
|
return self.parent, self.action
|
||||||
|
|
||||||
def expansion(self, action_num):
|
def expansion(self, evaluator, action_num):
|
||||||
# TODO: Let users/evaluator give the prior
|
# TODO: Let users/evaluator give the prior
|
||||||
if self.next_state is not None:
|
if self.next_state is not None:
|
||||||
prior = np.ones([action_num]) / action_num
|
prior, value = evaluator(self.next_state)
|
||||||
self.children[self.next_state] = UCTNode(self, self.action, self.next_state, action_num, prior)
|
self.children[self.next_state] = UCTNode(self, self.action, self.origin_state, action_num, prior,
|
||||||
return True
|
self.parent.inverse)
|
||||||
|
return value
|
||||||
else:
|
else:
|
||||||
return False
|
return 0
|
||||||
|
|
||||||
def backpropagation(self, value):
|
def backpropagation(self, value):
|
||||||
self.reward += value
|
self.reward += value
|
||||||
@ -100,14 +129,16 @@ class ActionNode:
|
|||||||
|
|
||||||
|
|
||||||
class MCTS:
|
class MCTS:
|
||||||
def __init__(self, simulator, evaluator, root, action_num, prior, method="UCT", max_step=None, max_time=None):
|
def __init__(self, simulator, evaluator, root, action_num, prior, method="UCT", inverse=False, max_step=None,
|
||||||
|
max_time=None):
|
||||||
self.simulator = simulator
|
self.simulator = simulator
|
||||||
self.evaluator = evaluator
|
self.evaluator = evaluator
|
||||||
self.action_num = action_num
|
self.action_num = action_num
|
||||||
if method == "UCT":
|
if method == "UCT":
|
||||||
self.root = UCTNode(None, None, root, action_num, prior)
|
self.root = UCTNode(None, None, root, action_num, prior, inverse)
|
||||||
if method == "TS":
|
if method == "TS":
|
||||||
self.root = TSNode(None, None, root, action_num, prior)
|
self.root = TSNode(None, None, root, action_num, prior, inverse=inverse)
|
||||||
|
self.inverse = inverse
|
||||||
if max_step is not None:
|
if max_step is not None:
|
||||||
self.step = 0
|
self.step = 0
|
||||||
self.max_step = max_step
|
self.max_step = max_step
|
||||||
@ -118,23 +149,15 @@ class MCTS:
|
|||||||
raise ValueError("Need a stop criteria!")
|
raise ValueError("Need a stop criteria!")
|
||||||
while (max_step is not None and self.step < self.max_step or max_step is None) \
|
while (max_step is not None and self.step < self.max_step or max_step is None) \
|
||||||
and (max_time is not None and time.time() - self.start_time < self.max_time or max_time is None):
|
and (max_time is not None and time.time() - self.start_time < self.max_time or max_time is None):
|
||||||
print("Q={}".format(self.root.Q))
|
|
||||||
print("N={}".format(self.root.N))
|
|
||||||
print("W={}".format(self.root.W))
|
|
||||||
print("UCB={}".format(self.root.ucb))
|
|
||||||
print("\n")
|
|
||||||
self.expand()
|
self.expand()
|
||||||
if max_step is not None:
|
if max_step is not None:
|
||||||
self.step += 1
|
self.step += 1
|
||||||
|
|
||||||
def expand(self):
|
def expand(self):
|
||||||
node, new_action = self.root.selection(self.simulator)
|
node, new_action = self.root.selection(self.simulator)
|
||||||
success = node.children[new_action].expansion(self.action_num)
|
value = node.children[new_action].expansion(self.evaluator, self.action_num)
|
||||||
if success:
|
print("Value:{}".format(value))
|
||||||
value = node.simulation(self.evaluator, node.children[new_action].next_state)
|
|
||||||
node.children[new_action].backpropagation(value + 0.)
|
node.children[new_action].backpropagation(value + 0.)
|
||||||
else:
|
|
||||||
node.children[new_action].backpropagation(0.)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@ -24,7 +24,7 @@ class TestEnv:
|
|||||||
else:
|
else:
|
||||||
num = state[0] + 2 ** state[1] * action
|
num = state[0] + 2 ** state[1] * action
|
||||||
step = state[1] + 1
|
step = state[1] + 1
|
||||||
new_state = (num, step)
|
new_state = [num, step]
|
||||||
if step == self.max_step:
|
if step == self.max_step:
|
||||||
reward = int(np.random.uniform() < self.reward[num])
|
reward = int(np.random.uniform() < self.reward[num])
|
||||||
else:
|
else:
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
Loading…
x
Reference in New Issue
Block a user