implement the training process

This commit is contained in:
rtz19970824 2017-12-21 23:30:24 +08:00
parent 2dad8e4020
commit c11eccbc90
4 changed files with 114 additions and 23 deletions

1
.gitignore vendored
View File

@ -8,3 +8,4 @@ checkpoints
checkpoints_origin checkpoints_origin
*.json *.json
.DS_Store .DS_Store
data

View File

@ -60,7 +60,7 @@ class Game:
def think(self, latest_boards, color): def think(self, latest_boards, color):
mcts = MCTS(self.game_engine, self.evaluator, [latest_boards, color], self.size ** 2 + 1, inverse=True) mcts = MCTS(self.game_engine, self.evaluator, [latest_boards, color], self.size ** 2 + 1, inverse=True)
mcts.search(max_step=1) mcts.search(max_step=20)
temp = 1 temp = 1
prob = mcts.root.N ** temp / np.sum(mcts.root.N ** temp) prob = mcts.root.N ** temp / np.sum(mcts.root.N ** temp)
choice = np.random.choice(self.size ** 2 + 1, 1, p=prob).tolist()[0] choice = np.random.choice(self.size ** 2 + 1, 1, p=prob).tolist()[0]

View File

@ -2,6 +2,7 @@ import os
import time import time
import sys import sys
import cPickle import cPickle
from collections import deque
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
@ -71,6 +72,13 @@ def value_head(input, is_training):
return h return h
class Data(object):
    """
    Plain record of one self-play game.

    Presumably mirrors the object that the self-play script pickles, so that
    those pickles can be loaded here -- TODO confirm against that script.

    Attributes:
        boards: list with one board per recorded move.
        probs: list with one move-probability vector per recorded move.
        winner: game outcome; 0 until a result has been recorded.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Return the record to its empty state (no moves, no winner)."""
        # Fresh lists on every call so that instances never share storage.
        self.boards = []
        self.probs = []
        self.winner = 0
class ResNet(object): class ResNet(object):
def __init__(self, board_size, action_num, history_length=1, residual_block_num=20, checkpoint_path=None): def __init__(self, board_size, action_num, history_length=1, residual_block_num=20, checkpoint_path=None):
""" """
@ -85,11 +93,18 @@ class ResNet(object):
self.board_size = board_size self.board_size = board_size
self.action_num = action_num self.action_num = action_num
self.history_length = history_length self.history_length = history_length
self.checkpoint_path = checkpoint_path
self.x = tf.placeholder(tf.float32, shape=[None, self.board_size, self.board_size, 2 * self.history_length + 1]) self.x = tf.placeholder(tf.float32, shape=[None, self.board_size, self.board_size, 2 * self.history_length + 1])
self.is_training = tf.placeholder(tf.bool, shape=[]) self.is_training = tf.placeholder(tf.bool, shape=[])
self.z = tf.placeholder(tf.float32, shape=[None, 1]) self.z = tf.placeholder(tf.float32, shape=[None, 1])
self.pi = tf.placeholder(tf.float32, shape=[None, self.action_num]) self.pi = tf.placeholder(tf.float32, shape=[None, self.action_num])
self._build_network(residual_block_num, checkpoint_path) self._build_network(residual_block_num, self.checkpoint_path)
# training hyper-parameters:
self.window_length = 1000
self.save_freq = 1000
self.training_data = {'states': deque(maxlen=self.window_length), 'probs': deque(maxlen=self.window_length),
'winner': deque(maxlen=self.window_length)}
def _build_network(self, residual_block_num, checkpoint_path): def _build_network(self, residual_block_num, checkpoint_path):
""" """
@ -118,7 +133,7 @@ class ResNet(object):
with tf.control_dependencies(self.update_ops): with tf.control_dependencies(self.update_ops):
self.train_op = tf.train.AdamOptimizer(1e-4).minimize(self.total_loss) self.train_op = tf.train.AdamOptimizer(1e-4).minimize(self.total_loss)
self.var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) self.var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
self.saver = tf.train.Saver(max_to_keep=10, var_list=self.var_list) self.saver = tf.train.Saver(var_list=self.var_list)
self.sess = multi_gpu.create_session() self.sess = multi_gpu.create_session()
self.sess.run(tf.global_variables_initializer()) self.sess.run(tf.global_variables_initializer())
if checkpoint_path is not None: if checkpoint_path is not None:
@ -166,21 +181,90 @@ class ResNet(object):
state[0, :, :, 2 * self.history_length] = np.zeros([self.board_size, self.board_size]) state[0, :, :, 2 * self.history_length] = np.zeros([self.board_size, self.board_size])
return state return state
#TODO: design the interface between the environment and training # TODO: design the interface between the environment and training
def train(self, mode='memory', *args, **kwargs): def train(self, mode='memory', *args, **kwargs):
if mode == 'memory': if mode == 'memory':
pass pass
if mode == 'file': if mode == 'file':
self.train_with_file(data_path=kwargs['data_path'], checkpoint_path=kwargs['checkpoint_path']) self._train_with_file(data_path=kwargs['data_path'], batch_size=kwargs['batch_size'],
checkpoint_path=kwargs['checkpoint_path'])
def train_with_file(self, data_path, checkpoint_path): def _train_with_file(self, data_path, batch_size, checkpoint_path):
# check if the path is valid
if not os.path.exists(data_path): if not os.path.exists(data_path):
raise ValueError("{} doesn't exist".format(data_path)) raise ValueError("{} doesn't exist".format(data_path))
self.checkpoint_path = checkpoint_path
if not os.path.exists(self.checkpoint_path):
os.mkdir(self.checkpoint_path)
file_list = os.listdir(data_path) new_file_list = []
if file_list <= 50: all_file_list = []
time.sleep(1) training_data = {}
else: iters = 0
file_list.sort(key=lambda file: os.path.getmtime(data_path + file) if not os.path.isdir( while True:
data_path + file) else 0) new_file_list = list(set(os.listdir(data_path)).difference(all_file_list))
all_file_list = os.listdir(data_path)
new_file_list.sort(
key=lambda file: os.path.getmtime(data_path + file) if not os.path.isdir(data_path + file) else 0)
if new_file_list:
for file in new_file_list:
states, probs, winner = self._file_to_training_data(data_path + file)
assert states.shape[0] == probs.shape[0]
assert states.shape[0] == winner.shape[0]
self.training_data['states'].append(states)
self.training_data['probs'].append(probs)
self.training_data['winner'].append(winner)
training_data['states'] = np.concatenate(self.training_data['states'], axis=0)
training_data['probs'] = np.concatenate(self.training_data['probs'], axis=0)
training_data['winner'] = np.concatenate(self.training_data['winner'], axis=0)
if len(self.training_data['states']) != self.window_length:
continue
else:
data_num = training_data['states'].shape[0]
index = np.arange(data_num)
np.random.shuffle(index)
start_time = time.time()
value_loss, policy_loss, reg, _ = self.sess.run(
[self.value_loss, self.policy_loss, self.reg, self.train_op],
feed_dict={self.x: training_data['states'][index[:batch_size]],
self.z: training_data['winner'][index[:batch_size]],
self.pi: training_data['probs'][index[:batch_size]],
self.is_training: True})
print("Iteration: {}, Time: {}, Value Loss: {}, Policy Loss: {}, Reg: {}".format(iters,
time.time() - start_time,
value_loss,
policy_loss, reg))
iters += 1
if iters % self.save_freq == 0:
save_path = "Iteration{}.ckpt".format(iters)
self.saver.save(self.sess, self.checkpoint_path + save_path)
def _file_to_training_data(self, file_name):
    """
    Convert one pickled game record into arrays usable for training.

    :param file_name: path of a pickle file holding an object with ``boards``,
        ``probs`` and ``winner`` attributes.
    :return: tuple ``(states, probs, winner)`` of numpy arrays with one row per
        (board, prob) pair; ``probs`` has ``board_size ** 2 + 1`` columns and
        ``winner`` has a single column.
    :raises ValueError: if the record contains no (board, prob) pairs.
    """
    # NOTE(security): unpickling can execute arbitrary code; only load files
    # produced by a trusted self-play run.
    # Text mode is kept on purpose: the self-play script writes these files
    # with mode "w", and ASCII (protocol 0) pickles must be read in the same
    # mode they were written in.
    with open(file_name, 'r') as pickle_file:
        data = cPickle.load(pickle_file)

    history = deque(maxlen=self.history_length)
    for _ in range(self.history_length):
        # Pad the history with empty boards. The value 0 for "empty" is
        # hard-coded here; it should eventually come from a shared config.
        history.append([0] * self.board_size ** 2)

    # The outcome is the same for every position of the game, so build its
    # (1, 1) row once instead of once per move.
    winner_row = np.array(data.winner).reshape(1, 1)
    action_num = self.board_size ** 2 + 1

    states = []
    probs = []
    winner = []
    # The first recorded position is assigned color +1 (hard-coded as well).
    color = +1
    for board, prob in zip(data.boards, data.probs):
        history.append(board)
        states.append(self._history2state(history, color))
        probs.append(np.array(prob).reshape(1, action_num))
        winner.append(winner_row)
        color *= -1

    if not states:
        # np.concatenate would raise ValueError on an empty list anyway; raise
        # the same exception type with a message that names the file.
        raise ValueError("{} contains no (board, prob) pairs".format(file_name))

    states = np.concatenate(states, axis=0)
    probs = np.concatenate(probs, axis=0)
    winner = np.concatenate(winner, axis=0)
    return states, probs, winner
if __name__ == "__main__":
    # Script entry point: build the network for a 9x9 board and train it from
    # the game files found under ./data/.
    # 82 == 9 * 9 + 1 actions (the extra action is presumably "pass" -- TODO confirm).
    train_options = dict(data_path="./data/", batch_size=128, checkpoint_path="./checkpoint/")
    model = ResNet(board_size=9, action_num=82)
    model.train("file", **train_options)

View File

@ -76,6 +76,7 @@ if __name__ == '__main__':
color = ['b', 'w'] color = ['b', 'w']
pattern = "[A-Z]{1}[0-9]{1}" pattern = "[A-Z]{1}[0-9]{1}"
space = re.compile("\s+")
size = 9 size = 9
show = ['.', 'X', 'O'] show = ['.', 'X', 'O']
@ -83,12 +84,20 @@ if __name__ == '__main__':
game_num = 0 game_num = 0
try: try:
while True: while True:
start_time = time.time()
num = 0 num = 0
pass_flag = [False, False] pass_flag = [False, False]
print("Start game {}".format(game_num)) print("Start game {}".format(game_num))
# end the game if both players chose to pass, or if too many turns have been played # end the game if both players chose to pass, or if too many turns have been played
while not (pass_flag[0] and pass_flag[1]) and num < size ** 2 * 2: while not (pass_flag[0] and pass_flag[1]) and num < size ** 2 * 2:
turn = num % 2 turn = num % 2
board = player[turn].run_cmd(str(num) + ' show_board')
board = eval(board[board.index('['):board.index(']') + 1])
for i in range(size):
for j in range(size):
print show[board[i * size + j]] + " ",
print "\n",
data.boards.append(board)
move = player[turn].run_cmd(str(num) + ' genmove ' + color[turn] + '\n') move = player[turn].run_cmd(str(num) + ' genmove ' + color[turn] + '\n')
print role[turn] + " : " + str(move), print role[turn] + " : " + str(move),
num += 1 num += 1
@ -102,21 +111,18 @@ if __name__ == '__main__':
play_or_pass = ' PASS' play_or_pass = ' PASS'
pass_flag[turn] = True pass_flag[turn] = True
result = player[1 - turn].run_cmd(str(num) + ' play ' + color[turn] + ' ' + play_or_pass + '\n') result = player[1 - turn].run_cmd(str(num) + ' play ' + color[turn] + ' ' + play_or_pass + '\n')
board = player[turn].run_cmd(str(num) + ' show_board')
board = eval(board[board.index('['):board.index(']') + 1])
for i in range(size):
for j in range(size):
print show[board[i * size + j]] + " ",
print "\n",
data.boards.append(board)
prob = player[turn].run_cmd(str(num) + ' get_prob') prob = player[turn].run_cmd(str(num) + ' get_prob')
prob = space.sub(',', prob[prob.index('['):prob.index(']') + 1])
prob = prob.replace('[,', '[')
prob = prob.replace('],', ']')
prob = eval(prob)
data.probs.append(prob) data.probs.append(prob)
score = player[turn].run_cmd(str(num) + ' get_score') score = player[turn].run_cmd(str(num) + ' get_score')
print "Finished : ", score.split(" ")[1] print "Finished : ", score.split(" ")[1]
# TODO: generalize the player # TODO: generalize the player
if score > 0: if eval(score.split(" ")[1]) > 0:
data.winner = 1 data.winner = 1
if score < 0: if eval(score.split(" ")[1]) < 0:
data.winner = -1 data.winner = -1
player[0].run_cmd(str(num) + ' clear_board') player[0].run_cmd(str(num) + ' clear_board')
player[1].run_cmd(str(num) + ' clear_board') player[1].run_cmd(str(num) + ' clear_board')
@ -127,12 +133,12 @@ if __name__ == '__main__':
file_list.sort(key=lambda file: os.path.getmtime(args.result_path + file) if not os.path.isdir( file_list.sort(key=lambda file: os.path.getmtime(args.result_path + file) if not os.path.isdir(
args.result_path + file) else 0) args.result_path + file) else 0)
data_num = eval(file_list[-1][:-4]) + 1 data_num = eval(file_list[-1][:-4]) + 1
print(file_list)
with open("./data/" + str(data_num) + ".pkl", "w") as file: with open("./data/" + str(data_num) + ".pkl", "w") as file:
picklestring = cPickle.dump(data, file) picklestring = cPickle.dump(data, file)
data.reset() data.reset()
game_num += 1 game_num += 1
except KeyboardInterrupt: print("Time {}".format(time.time()-start_time))
except Exception:
subprocess.call(["kill", "-9", str(agent_v0.pid)]) subprocess.call(["kill", "-9", str(agent_v0.pid)])
subprocess.call(["kill", "-9", str(agent_v1.pid)]) subprocess.call(["kill", "-9", str(agent_v1.pid)])
print "Kill all player, finish all game." print "Kill all player, finish all game."