commit 906ced84a3 (parent b687241a7d)

    self play
@@ -190,7 +190,7 @@ class Game:
         self.executor = Executor(game=self)
         self.history = []
         self.past = deque(maxlen=8)
-        for i in range(8):
+        for _ in range(8):
             self.past.append(self.board)

     def _flatten(self, vertex):
@@ -205,6 +205,9 @@ class Game:

     def clear(self):
         self.board = [utils.EMPTY] * (self.size * self.size)
+        self.history = []
+        for _ in range(8):
+            self.past.append(self.board)

     def set_size(self, n):
         self.size = n
@@ -225,7 +228,7 @@ class Game:
     def gen_move(self, color):
        # move = self.strategy.gen_move(color)
        # return move
-        move = self.strategy.gen_move(self.past, color)
+        move, self.prob = self.strategy.gen_move(self.past, color)
         self.do_move(color, move)
         return move

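With gen_move now also surfacing the MCTS probability vector (kept as self.prob), every self-play position can record the visit distribution pi as its policy target and attach the game outcome z afterwards, in the (pi, z) style of AlphaGo Zero training. A purely illustrative sketch of that bookkeeping, with all array shapes and names hypothetical:

    import numpy as np

    # Illustrative only: how (state, pi, z) triples accumulate during self play.
    positions, pis = [], []
    for step in range(3):                 # pretend three moves were played
        state = np.zeros((1, 9, 9, 17))   # stand-in for the encoded position
        pi = np.full(82, 1.0 / 82)        # stand-in for game.prob after gen_move
        positions.append(state)
        pis.append(pi)
    z = np.ones([len(positions), 1])      # final result broadcast to every position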
@@ -59,11 +59,12 @@ class Network(object):
         self.build_network()

     def build_network(self):
-        h = layers.conv2d(self.x, 256, kernel_size=3, stride=1, activation_fn=tf.nn.relu, normalizer_fn=layers.batch_norm,
+        h = layers.conv2d(self.x, 256, kernel_size=3, stride=1, activation_fn=tf.nn.relu,
+                          normalizer_fn=layers.batch_norm,
                           normalizer_params={'is_training': self.is_training,
                                              'updates_collections': tf.GraphKeys.UPDATE_OPS},
                           weights_regularizer=layers.l2_regularizer(1e-4))
-        for i in range(19):
+        for i in range(4):
             h = residual_block(h, self.is_training)
         self.v = value_heads(h, self.is_training)
         self.p = policy_heads(h, self.is_training)
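residual_block, value_heads, and policy_heads are not part of this diff. For reference, a standard AlphaGo-Zero-style residual block in the same tf.contrib.layers idiom might look like the sketch below; this is an assumption about the shape of the block, not the repo's actual implementation:

    import tensorflow as tf
    from tensorflow.contrib import layers

    def residual_block(h, is_training):
        # Two 3x3 convolutions with batch norm and a skip connection,
        # matching the 256-filter trunk built in build_network above.
        shortcut = h
        params = {'is_training': is_training,
                  'updates_collections': tf.GraphKeys.UPDATE_OPS}
        h = layers.conv2d(h, 256, kernel_size=3, stride=1, activation_fn=tf.nn.relu,
                          normalizer_fn=layers.batch_norm, normalizer_params=params,
                          weights_regularizer=layers.l2_regularizer(1e-4))
        h = layers.conv2d(h, 256, kernel_size=3, stride=1, activation_fn=None,
                          normalizer_fn=layers.batch_norm, normalizer_params=params,
                          weights_regularizer=layers.l2_regularizer(1e-4))
        return tf.nn.relu(h + shortcut)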
@@ -115,9 +116,9 @@ class Network(object):
                 feed_dict={self.x: boards[
                     index[iter * batch_size:(iter + 1) * batch_size]],
                     self.z: wins[index[
                         iter * batch_size:(iter + 1) * batch_size]],
                     self.pi: ps[index[
                         iter * batch_size:(iter + 1) * batch_size]],
                     self.is_training: True})
             value_losses.append(lv)
             policy_losses.append(lp)
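The feed above uses the usual shuffled-minibatch indexing: one shared permutation slices boards, wins, and ps identically so that each (x, z, pi) triple stays aligned. The same pattern in isolation, with stand-in sizes and arrays:

    import numpy as np

    n, batch_size = 1024, 32            # stand-in sizes
    boards = np.zeros((n, 9, 9, 17))    # stand-in for boards / wins / ps
    index = np.arange(n)
    np.random.shuffle(index)
    for it in range(n // batch_size):
        batch = index[it * batch_size:(it + 1) * batch_size]
        minibatch = boards[batch]       # the same slice feeds x, z, and pi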
@@ -137,53 +138,53 @@ class Network(object):
         del data, boards, wins, ps


 # def forward(call_number):
 #     # checkpoint_path = "/home/yama/rl/tianshou/AlphaGo/checkpoints"
 #     checkpoint_path = "/home/jialian/stuGo/tianshou/stuGo/checkpoints/"
 #     board_file = np.genfromtxt("/home/jialian/stuGo/tianshou/leela-zero/src/mcts_nn_files/board_" + call_number,
 #                                dtype='str');
 #     human_board = np.zeros((17, 19, 19))
 #
 #     # TODO : is it ok to ignore the last channel?
 #     for i in range(17):
 #         human_board[i] = np.array(list(board_file[i])).reshape(19, 19)
 #     # print("============================")
 #     # print("human board sum : " + str(np.sum(human_board[-1])))
 #     # print("============================")
 #     # print(human_board)
 #     # print("============================")
 #     # rint(human_board)
 #     feed_board = human_board.transpose(1, 2, 0).reshape(1, 19, 19, 17)
 #     # print(feed_board[:,:,:,-1])
 #     # print(feed_board.shape)
 #
 #     # npz_board = np.load("/home/yama/rl/tianshou/AlphaGo/data/7f83928932f64a79bc1efdea268698ae.npz")
 #     # print(npz_board["boards"].shape)
 #     # feed_board = npz_board["boards"][10].reshape(-1, 19, 19, 17)
 #     ##print(feed_board)
 #     # show_board = feed_board[0].transpose(2, 0, 1)
 #     # print("board shape : ", show_board.shape)
 #     # print(show_board)
 #
 #     itflag = False
 #     with multi_gpu.create_session() as sess:
 #         sess.run(tf.global_variables_initializer())
 #         ckpt_file = tf.train.latest_checkpoint(checkpoint_path)
 #         if ckpt_file is not None:
 #             # print('Restoring model from {}...'.format(ckpt_file))
 #             saver.restore(sess, ckpt_file)
 #         else:
 #             raise ValueError("No model loaded")
 #         res = sess.run([tf.nn.softmax(p), v], feed_dict={x: feed_board, is_training: itflag})
 #         # res = sess.run([tf.nn.softmax(p),v], feed_dict={x:fix_board["boards"][300].reshape(-1, 19, 19, 17), is_training:False})
 #         # res = sess.run([tf.nn.softmax(p),v], feed_dict={x:fix_board["boards"][50].reshape(-1, 19, 19, 17), is_training:True})
 #         # print(np.argmax(res[0]))
 #         np.savetxt(sys.stdout, res[0][0], fmt="%.6f", newline=" ")
 #         np.savetxt(sys.stdout, res[1][0], fmt="%.6f", newline=" ")
 #         pv_file = "/home/jialian/stuGotianshou/leela-zero/src/mcts_nn_files/policy_value"
 #         np.savetxt(pv_file, np.concatenate((res[0][0], res[1][0])), fmt="%.6f", newline=" ")
 #         # np.savetxt(pv_file, res[1][0], fmt="%.6f", newline=" ")
 #     return res


     def forward(self):
         # checkpoint_path = "/home/tongzheng/tianshou/AlphaGo/checkpoints/"
@@ -1,40 +1,98 @@
 from game import Game
 from engine import GTPEngine
 import re
+import numpy as np
+from collections import deque
+import utils
+import argparse

-g = Game()
+parser = argparse.ArgumentParser()
+parser.add_argument('--result_path', type=str, default='./part1')
+args = parser.parse_args()
+
+game = Game()
+engine = GTPEngine(game_obj=game)
+history = deque(maxlen=8)
+for i in range(8):
+    history.append(game.board)
+state = []
+prob = []
+winner = []
 pattern = "[A-Z]{1}[0-9]{1}"
-g.show_board()
-e = GTPEngine(game_obj=g)
+game.show_board()
+
+
+def history2state(history, color):
+    state = np.zeros([1, game.size, game.size, 17])
+    for i in range(8):
+        state[0, :, :, i] = np.array(np.array(history[i]) == np.ones(game.size ** 2)).reshape(game.size, game.size)
+        state[0, :, :, i + 8] = np.array(np.array(history[i]) == -np.ones(game.size ** 2)).reshape(game.size, game.size)
+    if color == utils.BLACK:
+        state[0, :, :, 16] = np.ones([game.size, game.size])
+    if color == utils.WHITE:
+        state[0, :, :, 16] = np.zeros([game.size, game.size])
+    return state
+
+
 num = 0
+game_num = 0
 black_pass = False
 white_pass = False
-while not (black_pass and white_pass):
-    if num % 2 == 0:
-        res = e.run_cmd(str(num) + " genmove BLACK")
-        num += 1
-        # print(res)
-        match = re.search(pattern, res)
-        if match is not None:
-            print(match.group())
-        else:
-            print("pass")
-        if re.search("pass", res) is not None:
-            black_pass = True
-        else:
-            black_pass = False
-    else:
-        res = e.run_cmd(str(num) + " genmove WHITE")
-        num += 1
-        match = re.search(pattern, res)
-        if match is not None:
-            print(match.group())
-        else:
-            print("pass")
-        if re.search("pass", res) is not None:
-            white_pass = True
-        else:
-            white_pass = False
-g.show_board()
+while True:
+    while not (black_pass and white_pass) and num < game.size ** 2 * 2:
+        if num % 2 == 0:
+            color = utils.BLACK
+            new_state = history2state(history, color)
+            state.append(new_state)
+            result = engine.run_cmd(str(num) + " genmove BLACK")
+            num += 1
+            match = re.search(pattern, result)
+            if match is not None:
+                print(match.group())
+            else:
+                print("pass")
+            if re.search("pass", result) is not None:
+                black_pass = True
+            else:
+                black_pass = False
+        else:
+            color = utils.WHITE
+            new_state = history2state(history, color)
+            state.append(new_state)
+            result = engine.run_cmd(str(num) + " genmove WHITE")
+            num += 1
+            match = re.search(pattern, result)
+            if match is not None:
+                print(match.group())
+            else:
+                print("pass")
+            if re.search("pass", result) is not None:
+                white_pass = True
+            else:
+                white_pass = False
+        game.show_board()
+        prob.append(np.array(game.prob).reshape(-1, game.size ** 2 + 1))
+    print("Finished")
+    score = game.executor.get_score()
+    if score > 0:
+        winner = utils.BLACK
+    else:
+        winner = utils.WHITE
+    state = np.concatenate(state, axis=0)
+    prob = np.concatenate(prob, axis=0)
+    winner = np.ones([num, 1]) * winner
+    assert state.shape[0] == prob.shape[0]
+    assert state.shape[0] == winner.shape[0]
+    np.savez(args.result_path + "/game" + str(game_num), state=state, prob=prob, winner=winner)
+    state = []
+    prob = []
+    winner = []
+    num = 0
+    black_pass = False
+    white_pass = False
+    engine.run_cmd(str(num) + " clear_board")
+    history.clear()
+    for _ in range(8):
+        history.append(game.board)
+    game.show_board()
+    game_num += 1
@@ -198,28 +198,27 @@ class GoEnv:
         id_ = self._flatten(vertex)
         if self.board[id_] == utils.EMPTY:
             self.board[id_] = color
-            self.history.append(copy.copy(self.board))
             return True
         else:
             return False

     def step_forward(self, state, action):
         if state[0, 0, 0, -1] == 1:
-            color = 1
+            color = utils.BLACK
         else:
-            color = -1
-        if action == 81:
-            vertex = (0, 0)
+            color = utils.WHITE
+        if action == self.size ** 2:
+            vertex = utils.PASS
         else:
-            vertex = (action % 9 + 1, action / 9 + 1)
+            vertex = (action % self.size + 1, action / self.size + 1)
         # print(vertex)
         # print(self.board)
         self.board = (state[:, :, :, 7] - state[:, :, :, 15]).reshape(-1).tolist()
         self.do_move(color, vertex)
         new_state = np.concatenate(
-            [state[:, :, :, 1:8], (np.array(self.board) == 1).reshape(1, 9, 9, 1),
-             state[:, :, :, 9:16], (np.array(self.board) == -1).reshape(1, 9, 9, 1),
-             np.array(1 - state[:, :, :, -1]).reshape(1, 9, 9, 1)],
+            [state[:, :, :, 1:8], (np.array(self.board) == utils.BLACK).reshape(1, self.size, self.size, 1),
+             state[:, :, :, 9:16], (np.array(self.board) == utils.WHITE).reshape(1, self.size, self.size, 1),
+             np.array(1 - state[:, :, :, -1]).reshape(1, self.size, self.size, 1)],
             axis=3)
         return new_state, 0

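The np.concatenate at the end of step_forward is the history-plane roll: the oldest black and white planes drop off, the post-move board joins as the newest plane on each side, and the colour-to-move plane flips. The same update in isolation, again assuming BLACK == 1 and WHITE == -1:

    import numpy as np

    size = 9
    state = np.zeros((1, size, size, 17))   # current 17-plane state
    board = np.zeros(size * size)           # post-move board, flattened

    new_state = np.concatenate(
        [state[:, :, :, 1:8],                         # black planes shifted down
         (board == 1).reshape(1, size, size, 1),      # newest black plane
         state[:, :, :, 9:16],                        # white planes shifted down
         (board == -1).reshape(1, size, size, 1),     # newest white plane
         (1 - state[:, :, :, -1]).reshape(1, size, size, 1)],  # flip side to move
        axis=3)
    assert new_state.shape == (1, size, size, 17)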
@@ -233,26 +232,26 @@ class strategy(object):
             feed_dict={self.net.x: state, self.net.is_training: False})

     def data_process(self, history, color):
-        state = np.zeros([1, 9, 9, 17])
+        state = np.zeros([1, self.simulator.size, self.simulator.size, 17])
         for i in range(8):
-            state[0, :, :, i] = np.array(np.array(history[i]) == np.ones(81)).reshape(9, 9)
-            state[0, :, :, i + 8] = np.array(np.array(history[i]) == -np.ones(81)).reshape(9, 9)
-        if color == 1:
-            state[0, :, :, 16] = np.ones([9, 9])
-        if color == -1:
-            state[0, :, :, 16] = np.zeros([9, 9])
+            state[0, :, :, i] = np.array(np.array(history[i]) == np.ones(self.simulator.size ** 2)).reshape(self.simulator.size, self.simulator.size)
+            state[0, :, :, i + 8] = np.array(np.array(history[i]) == -np.ones(self.simulator.size ** 2)).reshape(self.simulator.size, self.simulator.size)
+        if color == utils.BLACK:
+            state[0, :, :, 16] = np.ones([self.simulator.size, self.simulator.size])
+        if color == utils.WHITE:
+            state[0, :, :, 16] = np.zeros([self.simulator.size, self.simulator.size])
         return state

     def gen_move(self, history, color):
         self.simulator.history = copy.copy(history)
         self.simulator.board = copy.copy(history[-1])
         state = self.data_process(self.simulator.history, color)
-        mcts = MCTS(self.simulator, self.evaluator, state, 82, inverse=True, max_step=10)
+        mcts = MCTS(self.simulator, self.evaluator, state, self.simulator.size ** 2 + 1, inverse=True, max_step=100)
         temp = 1
-        p = mcts.root.N ** temp / np.sum(mcts.root.N ** temp)
-        choice = np.random.choice(82, 1, p=p).tolist()[0]
-        if choice == 81:
-            move = (0, 0)
+        prob = mcts.root.N ** temp / np.sum(mcts.root.N ** temp)
+        choice = np.random.choice(self.simulator.size ** 2 + 1, 1, p=prob).tolist()[0]
+        if choice == self.simulator.size ** 2:
+            move = utils.PASS
         else:
-            move = (choice % 9 + 1, choice / 9 + 1)
-        return move
+            move = (choice % self.simulator.size + 1, choice / self.simulator.size + 1)
+        return move, prob
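gen_move now samples the move from the root visit counts. With temperature tau, AlphaGo Zero plays pi(a) proportional to N(a)**(1/tau); with temp = 1 as here (where the code computes N ** temp) that reduces to sampling proportional to visit counts. The same selection in isolation, with a made-up visit-count vector for 9x9 (81 points plus pass); note the repo code runs under Python 2, where its `/` is integer division:

    import numpy as np

    N = np.random.randint(0, 50, size=82).astype(np.float64)  # hypothetical root visits
    N[0] += 1                                 # avoid an all-zero distribution
    tau = 1.0
    prob = N ** (1.0 / tau) / np.sum(N ** (1.0 / tau))
    choice = np.random.choice(82, 1, p=prob)[0]
    move = "pass" if choice == 81 else (choice % 9 + 1, choice // 9 + 1)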

AlphaGo/test.py (new file, 14 lines)
@@ -0,0 +1,14 @@
+import sys
+from game import Game
+from engine import GTPEngine
+# import utils
+
+game = Game()
+engine = GTPEngine(game_obj=game, name='tianshou')
+cmd = raw_input
+
+while not engine.disconnect:
+    command = cmd()
+    result = engine.run_cmd(command)
+    sys.stdout.write(result)
+    sys.stdout.flush()
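AlphaGo/test.py is a minimal GTP front end: it reads one command per line (raw_input, so a Python 2 runtime) and writes the engine's reply to stdout. A hypothetical smoke test over a pipe, assuming GTPEngine understands genmove and a quit command that sets engine.disconnect:

    # Python 2 assumed, matching test.py's use of raw_input.
    import subprocess

    proc = subprocess.Popen(["python", "AlphaGo/test.py"],
                            stdin=subprocess.PIPE, stdout=subprocess.PIPE)
    out, _ = proc.communicate("1 genmove BLACK\n2 quit\n")
    print(out)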

bin/activate (new symbolic link)
@@ -0,0 +1 @@
+/home/tongzheng/anaconda2/bin/activate

bin/deactivate (new symbolic link)
@@ -0,0 +1 @@
+/home/tongzheng/anaconda2/bin/deactivate