implement a stochastic sample training method
parent e72fd52913
commit d8c0eae6a3
@@ -31,7 +31,7 @@ class Game:
         self.latest_boards = deque(maxlen=8)
         for _ in range(8):
             self.latest_boards.append(self.board)
-        self.evaluator = model.ResNet(self.size, self.size**2 + 1, history_length=8)
+        self.evaluator = model.ResNet(self.size, self.size**2 + 1, history_length=8, checkpoint_path=checkpoint_path)
         # self.evaluator = lambda state: self.sess.run([tf.nn.softmax(self.net.p), self.net.v],
         #                                              feed_dict={self.net.x: state, self.net.is_training: False})
         self.game_engine = go.Go(size=self.size, komi=self.komi)
@@ -96,7 +96,7 @@ class Game:
         sys.stdout.flush()


 if __name__ == "__main__":
-    g = Game()
+    g = Game(checkpoint_path='./checkpoints/')
     g.show_board()
     g.think_play_move(1)
     #file = open("debug.txt", "a")

@@ -1,5 +1,6 @@
 import os
+import time
 import random
 import sys
 import cPickle
 from collections import deque
@@ -104,7 +105,7 @@ class ResNet(object):
         self.window_length = 7000
         self.save_freq = 5000
         self.training_data = {'states': deque(maxlen=self.window_length), 'probs': deque(maxlen=self.window_length),
-                              'winner': deque(maxlen=self.window_length)}
+                              'winner': deque(maxlen=self.window_length), 'length': deque(maxlen=self.window_length)}

     def _build_network(self, residual_block_num, checkpoint_path):
         """
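For context on the new 'length' deque: each stored self-play game appends one entry to all four deques, so once the window of window_length games is full, pushing a new game evicts the oldest game's states, probs, winner and length together. A minimal sketch of that behaviour, assuming a toy window of 3 games and hypothetical array shapes (the real code keeps 7000 games):

from collections import deque

import numpy as np

window_length = 3  # toy value; the commit uses 7000
buffer = {key: deque(maxlen=window_length)
          for key in ('states', 'probs', 'winner', 'length')}

for _ in range(5):                                           # pretend 5 self-play games arrive
    n_moves = np.random.randint(10, 20)                      # positions in this game
    buffer['states'].append(np.zeros((n_moves, 9, 9, 17)))   # hypothetical feature planes
    buffer['probs'].append(np.zeros((n_moves, 82)))          # 9*9 moves + pass
    buffer['winner'].append(np.zeros((n_moves, 1)))
    buffer['length'].append(n_moves)

# only the 3 most recent games remain, and the four deques stay aligned
assert len(buffer['states']) == window_length
assert [s.shape[0] for s in buffer['states']] == list(buffer['length'])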
@@ -199,15 +200,15 @@ class ResNet(object):

         new_file_list = []
         all_file_list = []
-        training_data = {}
+        training_data = {'states': [], 'probs': [], 'winner': []}

         iters = 0
         while True:
             new_file_list = list(set(os.listdir(data_path)).difference(all_file_list))
-            if new_file_list:
+            while new_file_list:
                 all_file_list = os.listdir(data_path)
                 new_file_list.sort(
                     key=lambda file: os.path.getmtime(data_path + file) if not os.path.isdir(data_path + file) else 0)
-            if new_file_list:
                 for file in new_file_list:
                     states, probs, winner = self._file_to_training_data(data_path + file)
                     assert states.shape[0] == probs.shape[0]
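Changing the inner if new_file_list: to while new_file_list: keeps draining the self-play directory until no unseen files remain before a training step runs. A standalone sketch of that polling idea, with a hypothetical poll_new_files helper and directory (no unpickling here):

import os


def poll_new_files(data_path, seen_files):
    """Return files that appeared since the last poll, oldest first by mtime."""
    new_files = list(set(os.listdir(data_path)).difference(seen_files))
    new_files.sort(key=lambda f: os.path.getmtime(os.path.join(data_path, f))
                   if not os.path.isdir(os.path.join(data_path, f)) else 0)
    return new_files


data_path = "."                        # hypothetical directory of self-play records
seen = []
new = poll_new_files(data_path, seen)
while new:                             # mirrors the while loop in the hunk above
    seen = os.listdir(data_path)
    for name in new:
        pass                           # load the game and append it to the window here
    new = poll_new_files(data_path, seen)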
@@ -215,32 +216,36 @@ class ResNet(object):
                     self.training_data['states'].append(states)
                     self.training_data['probs'].append(probs)
                     self.training_data['winner'].append(winner)
-            if len(self.training_data['states']) == self.window_length:
-                training_data['states'] = np.concatenate(self.training_data['states'], axis=0)
-                training_data['probs'] = np.concatenate(self.training_data['probs'], axis=0)
-                training_data['winner'] = np.concatenate(self.training_data['winner'], axis=0)
+                    self.training_data['length'].append(states.shape[0])
+                new_file_list = list(set(os.listdir(data_path)).difference(all_file_list))

             if len(self.training_data['states']) != self.window_length:
                 continue
             else:
-                data_num = training_data['states'].shape[0]
-                index = np.arange(data_num)
-                np.random.shuffle(index)
                 start_time = time.time()
+                for i in range(batch_size):
+                    game_num = random.randint(0, self.window_length-1)
+                    state_num = random.randint(0, self.training_data['length'][game_num]-1)
+                    training_data['states'].append(np.expand_dims(self.training_data['states'][game_num][state_num], 0))
+                    training_data['probs'].append(np.expand_dims(self.training_data['probs'][game_num][state_num], 0))
+                    training_data['winner'].append(np.expand_dims(self.training_data['winner'][game_num][state_num], 0))
                 value_loss, policy_loss, reg, _ = self.sess.run(
                     [self.value_loss, self.policy_loss, self.reg, self.train_op],
-                    feed_dict={self.x: training_data['states'][index[:batch_size]],
-                               self.z: training_data['winner'][index[:batch_size]],
-                               self.pi: training_data['probs'][index[:batch_size]],
+                    feed_dict={self.x: np.concatenate(training_data['states'], axis=0),
+                               self.z: np.concatenate(training_data['winner'], axis=0),
+                               self.pi: np.concatenate(training_data['probs'], axis=0),
                                self.is_training: True})

                 print("Iteration: {}, Time: {}, Value Loss: {}, Policy Loss: {}, Reg: {}".format(iters,
                                                                                                   time.time() - start_time,
                                                                                                   value_loss,
                                                                                                   policy_loss, reg))
-                iters += 1
                 if iters % self.save_freq == 0:
                     save_path = "Iteration{}.ckpt".format(iters)
                     self.saver.save(self.sess, self.checkpoint_path + save_path)
+                for key in training_data.keys():
+                    training_data[key] = []
+                iters += 1

     def _file_to_training_data(self, file_name):
         read = False
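This hunk is the "stochastic sample" of the commit title: rather than concatenating the whole window and slicing a shuffled index, each minibatch now draws batch_size positions by first picking a game at random and then a position inside it, using the per-game 'length' bookkeeping. A self-contained sketch of that sampling step on toy arrays, assuming hypothetical shapes and skipping the actual sess.run:

import random

import numpy as np

# toy replay window: 4 games of different lengths (hypothetical shapes)
lengths = [12, 30, 18, 25]
states = [np.random.rand(n, 9, 9, 17) for n in lengths]
probs = [np.random.rand(n, 82) for n in lengths]
winners = [np.random.choice([-1.0, 1.0], size=(n, 1)) for n in lengths]

batch_size = 8
batch = {'states': [], 'probs': [], 'winner': []}
for _ in range(batch_size):
    game_num = random.randint(0, len(lengths) - 1)         # pick a game uniformly
    state_num = random.randint(0, lengths[game_num] - 1)    # then a position inside it
    batch['states'].append(np.expand_dims(states[game_num][state_num], 0))
    batch['probs'].append(np.expand_dims(probs[game_num][state_num], 0))
    batch['winner'].append(np.expand_dims(winners[game_num][state_num], 0))

x = np.concatenate(batch['states'], axis=0)    # would be fed to self.x
pi = np.concatenate(batch['probs'], axis=0)    # ... to self.pi
z = np.concatenate(batch['winner'], axis=0)    # ... to self.z
assert x.shape[0] == pi.shape[0] == z.shape[0] == batch_size

One consequence of sampling games uniformly first is that positions from short games are slightly over-represented compared to sampling positions uniformly over the whole window.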
@@ -250,6 +255,7 @@ class ResNet(object):
                 file.seek(0)
                 data = cPickle.load(file)
                 read = True
+                print("{} Loaded!".format(file_name))
             except Exception as e:
                 print(e)
                 time.sleep(1)
@@ -276,5 +282,5 @@ class ResNet(object):


 if __name__ == "__main__":
-    model = ResNet(board_size=9, action_num=82)
+    model = ResNet(board_size=9, action_num=82, history_length=8)
     model.train("file", data_path="./data/", batch_size=128, checkpoint_path="./checkpoint/")