implement a stochastic sample training method

This commit is contained in:
rtz19970824 2017-12-22 17:16:44 +08:00
parent ed96268454
commit 67ba76a04d
2 changed files with 27 additions and 21 deletions

View File

@ -31,7 +31,7 @@ class Game:
self.latest_boards = deque(maxlen=8)
for _ in range(8):
self.latest_boards.append(self.board)
self.evaluator = model.ResNet(self.size, self.size**2 + 1, history_length=8)
self.evaluator = model.ResNet(self.size, self.size**2 + 1, history_length=8, checkpoint_path=checkpoint_path)
# self.evaluator = lambda state: self.sess.run([tf.nn.softmax(self.net.p), self.net.v],
# feed_dict={self.net.x: state, self.net.is_training: False})
self.game_engine = go.Go(size=self.size, komi=self.komi)
@ -96,7 +96,7 @@ class Game:
sys.stdout.flush()
if __name__ == "__main__":
g = Game()
g = Game(checkpoint_path='./checkpoints/')
g.show_board()
g.think_play_move(1)
#file = open("debug.txt", "a")

View File

@ -1,5 +1,6 @@
import os
import time
import random
import sys
import cPickle
from collections import deque
@ -104,7 +105,7 @@ class ResNet(object):
self.window_length = 7000
self.save_freq = 5000
self.training_data = {'states': deque(maxlen=self.window_length), 'probs': deque(maxlen=self.window_length),
'winner': deque(maxlen=self.window_length)}
'winner': deque(maxlen=self.window_length), 'length': deque(maxlen=self.window_length)}
def _build_network(self, residual_block_num, checkpoint_path):
"""
@ -199,15 +200,15 @@ class ResNet(object):
new_file_list = []
all_file_list = []
training_data = {}
training_data = {'states': [], 'probs': [], 'winner': []}
iters = 0
while True:
new_file_list = list(set(os.listdir(data_path)).difference(all_file_list))
if new_file_list:
while new_file_list:
all_file_list = os.listdir(data_path)
new_file_list.sort(
key=lambda file: os.path.getmtime(data_path + file) if not os.path.isdir(data_path + file) else 0)
if new_file_list:
for file in new_file_list:
states, probs, winner = self._file_to_training_data(data_path + file)
assert states.shape[0] == probs.shape[0]
@ -215,32 +216,36 @@ class ResNet(object):
self.training_data['states'].append(states)
self.training_data['probs'].append(probs)
self.training_data['winner'].append(winner)
if len(self.training_data['states']) == self.window_length:
training_data['states'] = np.concatenate(self.training_data['states'], axis=0)
training_data['probs'] = np.concatenate(self.training_data['probs'], axis=0)
training_data['winner'] = np.concatenate(self.training_data['winner'], axis=0)
self.training_data['length'].append(states.shape[0])
new_file_list = list(set(os.listdir(data_path)).difference(all_file_list))
if len(self.training_data['states']) != self.window_length:
continue
else:
data_num = training_data['states'].shape[0]
index = np.arange(data_num)
np.random.shuffle(index)
start_time = time.time()
for i in range(batch_size):
game_num = random.randint(0, self.window_length-1)
state_num = random.randint(0, self.training_data['length'][game_num]-1)
training_data['states'].append(np.expand_dims(self.training_data['states'][game_num][state_num], 0))
training_data['probs'].append(np.expand_dims(self.training_data['probs'][game_num][state_num], 0))
training_data['winner'].append(np.expand_dims(self.training_data['winner'][game_num][state_num], 0))
value_loss, policy_loss, reg, _ = self.sess.run(
[self.value_loss, self.policy_loss, self.reg, self.train_op],
feed_dict={self.x: training_data['states'][index[:batch_size]],
self.z: training_data['winner'][index[:batch_size]],
self.pi: training_data['probs'][index[:batch_size]],
feed_dict={self.x: np.concatenate(training_data['states'], axis=0),
self.z: np.concatenate(training_data['winner'], axis=0),
self.pi: np.concatenate(training_data['probs'], axis=0),
self.is_training: True})
print("Iteration: {}, Time: {}, Value Loss: {}, Policy Loss: {}, Reg: {}".format(iters,
time.time() - start_time,
value_loss,
policy_loss, reg))
iters += 1
if iters % self.save_freq == 0:
save_path = "Iteration{}.ckpt".format(iters)
self.saver.save(self.sess, self.checkpoint_path + save_path)
for key in training_data.keys():
training_data[key] = []
iters += 1
def _file_to_training_data(self, file_name):
read = False
@ -250,6 +255,7 @@ class ResNet(object):
file.seek(0)
data = cPickle.load(file)
read = True
print("{} Loaded!".format(file_name))
except Exception as e:
print(e)
time.sleep(1)
@ -275,6 +281,6 @@ class ResNet(object):
return states, probs, winner
if __name__=="__main__":
model = ResNet(board_size=9, action_num=82)
if __name__ == "__main__":
model = ResNet(board_size=9, action_num=82, history_length=8)
model.train("file", data_path="./data/", batch_size=128, checkpoint_path="./checkpoint/")