implement a stochastic sample training method

This commit is contained in:
rtz19970824 2017-12-22 17:16:44 +08:00
parent ed96268454
commit 67ba76a04d
2 changed files with 27 additions and 21 deletions

View File

@ -31,7 +31,7 @@ class Game:
self.latest_boards = deque(maxlen=8) self.latest_boards = deque(maxlen=8)
for _ in range(8): for _ in range(8):
self.latest_boards.append(self.board) self.latest_boards.append(self.board)
self.evaluator = model.ResNet(self.size, self.size**2 + 1, history_length=8) self.evaluator = model.ResNet(self.size, self.size**2 + 1, history_length=8, checkpoint_path=checkpoint_path)
# self.evaluator = lambda state: self.sess.run([tf.nn.softmax(self.net.p), self.net.v], # self.evaluator = lambda state: self.sess.run([tf.nn.softmax(self.net.p), self.net.v],
# feed_dict={self.net.x: state, self.net.is_training: False}) # feed_dict={self.net.x: state, self.net.is_training: False})
self.game_engine = go.Go(size=self.size, komi=self.komi) self.game_engine = go.Go(size=self.size, komi=self.komi)
@ -96,7 +96,7 @@ class Game:
sys.stdout.flush() sys.stdout.flush()
if __name__ == "__main__": if __name__ == "__main__":
g = Game() g = Game(checkpoint_path='./checkpoints/')
g.show_board() g.show_board()
g.think_play_move(1) g.think_play_move(1)
#file = open("debug.txt", "a") #file = open("debug.txt", "a")

View File

@ -1,5 +1,6 @@
import os import os
import time import time
import random
import sys import sys
import cPickle import cPickle
from collections import deque from collections import deque
@ -104,7 +105,7 @@ class ResNet(object):
self.window_length = 7000 self.window_length = 7000
self.save_freq = 5000 self.save_freq = 5000
self.training_data = {'states': deque(maxlen=self.window_length), 'probs': deque(maxlen=self.window_length), self.training_data = {'states': deque(maxlen=self.window_length), 'probs': deque(maxlen=self.window_length),
'winner': deque(maxlen=self.window_length)} 'winner': deque(maxlen=self.window_length), 'length': deque(maxlen=self.window_length)}
def _build_network(self, residual_block_num, checkpoint_path): def _build_network(self, residual_block_num, checkpoint_path):
""" """
@ -199,15 +200,15 @@ class ResNet(object):
new_file_list = [] new_file_list = []
all_file_list = [] all_file_list = []
training_data = {} training_data = {'states': [], 'probs': [], 'winner': []}
iters = 0 iters = 0
while True: while True:
new_file_list = list(set(os.listdir(data_path)).difference(all_file_list)) new_file_list = list(set(os.listdir(data_path)).difference(all_file_list))
if new_file_list: while new_file_list:
all_file_list = os.listdir(data_path) all_file_list = os.listdir(data_path)
new_file_list.sort( new_file_list.sort(
key=lambda file: os.path.getmtime(data_path + file) if not os.path.isdir(data_path + file) else 0) key=lambda file: os.path.getmtime(data_path + file) if not os.path.isdir(data_path + file) else 0)
if new_file_list:
for file in new_file_list: for file in new_file_list:
states, probs, winner = self._file_to_training_data(data_path + file) states, probs, winner = self._file_to_training_data(data_path + file)
assert states.shape[0] == probs.shape[0] assert states.shape[0] == probs.shape[0]
@ -215,32 +216,36 @@ class ResNet(object):
self.training_data['states'].append(states) self.training_data['states'].append(states)
self.training_data['probs'].append(probs) self.training_data['probs'].append(probs)
self.training_data['winner'].append(winner) self.training_data['winner'].append(winner)
if len(self.training_data['states']) == self.window_length: self.training_data['length'].append(states.shape[0])
training_data['states'] = np.concatenate(self.training_data['states'], axis=0) new_file_list = list(set(os.listdir(data_path)).difference(all_file_list))
training_data['probs'] = np.concatenate(self.training_data['probs'], axis=0)
training_data['winner'] = np.concatenate(self.training_data['winner'], axis=0)
if len(self.training_data['states']) != self.window_length: if len(self.training_data['states']) != self.window_length:
continue continue
else: else:
data_num = training_data['states'].shape[0]
index = np.arange(data_num)
np.random.shuffle(index)
start_time = time.time() start_time = time.time()
for i in range(batch_size):
game_num = random.randint(0, self.window_length-1)
state_num = random.randint(0, self.training_data['length'][game_num]-1)
training_data['states'].append(np.expand_dims(self.training_data['states'][game_num][state_num], 0))
training_data['probs'].append(np.expand_dims(self.training_data['probs'][game_num][state_num], 0))
training_data['winner'].append(np.expand_dims(self.training_data['winner'][game_num][state_num], 0))
value_loss, policy_loss, reg, _ = self.sess.run( value_loss, policy_loss, reg, _ = self.sess.run(
[self.value_loss, self.policy_loss, self.reg, self.train_op], [self.value_loss, self.policy_loss, self.reg, self.train_op],
feed_dict={self.x: training_data['states'][index[:batch_size]], feed_dict={self.x: np.concatenate(training_data['states'], axis=0),
self.z: training_data['winner'][index[:batch_size]], self.z: np.concatenate(training_data['winner'], axis=0),
self.pi: training_data['probs'][index[:batch_size]], self.pi: np.concatenate(training_data['probs'], axis=0),
self.is_training: True}) self.is_training: True})
print("Iteration: {}, Time: {}, Value Loss: {}, Policy Loss: {}, Reg: {}".format(iters, print("Iteration: {}, Time: {}, Value Loss: {}, Policy Loss: {}, Reg: {}".format(iters,
time.time() - start_time, time.time() - start_time,
value_loss, value_loss,
policy_loss, reg)) policy_loss, reg))
iters += 1
if iters % self.save_freq == 0: if iters % self.save_freq == 0:
save_path = "Iteration{}.ckpt".format(iters) save_path = "Iteration{}.ckpt".format(iters)
self.saver.save(self.sess, self.checkpoint_path + save_path) self.saver.save(self.sess, self.checkpoint_path + save_path)
for key in training_data.keys():
training_data[key] = []
iters += 1
def _file_to_training_data(self, file_name): def _file_to_training_data(self, file_name):
read = False read = False
@ -250,6 +255,7 @@ class ResNet(object):
file.seek(0) file.seek(0)
data = cPickle.load(file) data = cPickle.load(file)
read = True read = True
print("{} Loaded!".format(file_name))
except Exception as e: except Exception as e:
print(e) print(e)
time.sleep(1) time.sleep(1)
@ -275,6 +281,6 @@ class ResNet(object):
return states, probs, winner return states, probs, winner
if __name__=="__main__": if __name__ == "__main__":
model = ResNet(board_size=9, action_num=82) model = ResNet(board_size=9, action_num=82, history_length=8)
model.train("file", data_path="./data/", batch_size=128, checkpoint_path="./checkpoint/") model.train("file", data_path="./data/", batch_size=128, checkpoint_path="./checkpoint/")