implement a stochastic sample training method
This commit is contained in:
parent
ed96268454
commit
67ba76a04d
@ -31,7 +31,7 @@ class Game:
|
|||||||
self.latest_boards = deque(maxlen=8)
|
self.latest_boards = deque(maxlen=8)
|
||||||
for _ in range(8):
|
for _ in range(8):
|
||||||
self.latest_boards.append(self.board)
|
self.latest_boards.append(self.board)
|
||||||
self.evaluator = model.ResNet(self.size, self.size**2 + 1, history_length=8)
|
self.evaluator = model.ResNet(self.size, self.size**2 + 1, history_length=8, checkpoint_path=checkpoint_path)
|
||||||
# self.evaluator = lambda state: self.sess.run([tf.nn.softmax(self.net.p), self.net.v],
|
# self.evaluator = lambda state: self.sess.run([tf.nn.softmax(self.net.p), self.net.v],
|
||||||
# feed_dict={self.net.x: state, self.net.is_training: False})
|
# feed_dict={self.net.x: state, self.net.is_training: False})
|
||||||
self.game_engine = go.Go(size=self.size, komi=self.komi)
|
self.game_engine = go.Go(size=self.size, komi=self.komi)
|
||||||
@ -96,7 +96,7 @@ class Game:
|
|||||||
sys.stdout.flush()
|
sys.stdout.flush()
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
g = Game()
|
g = Game(checkpoint_path='./checkpoints/')
|
||||||
g.show_board()
|
g.show_board()
|
||||||
g.think_play_move(1)
|
g.think_play_move(1)
|
||||||
#file = open("debug.txt", "a")
|
#file = open("debug.txt", "a")
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
import random
|
||||||
import sys
|
import sys
|
||||||
import cPickle
|
import cPickle
|
||||||
from collections import deque
|
from collections import deque
|
||||||
@ -104,7 +105,7 @@ class ResNet(object):
|
|||||||
self.window_length = 7000
|
self.window_length = 7000
|
||||||
self.save_freq = 5000
|
self.save_freq = 5000
|
||||||
self.training_data = {'states': deque(maxlen=self.window_length), 'probs': deque(maxlen=self.window_length),
|
self.training_data = {'states': deque(maxlen=self.window_length), 'probs': deque(maxlen=self.window_length),
|
||||||
'winner': deque(maxlen=self.window_length)}
|
'winner': deque(maxlen=self.window_length), 'length': deque(maxlen=self.window_length)}
|
||||||
|
|
||||||
def _build_network(self, residual_block_num, checkpoint_path):
|
def _build_network(self, residual_block_num, checkpoint_path):
|
||||||
"""
|
"""
|
||||||
@ -199,15 +200,15 @@ class ResNet(object):
|
|||||||
|
|
||||||
new_file_list = []
|
new_file_list = []
|
||||||
all_file_list = []
|
all_file_list = []
|
||||||
training_data = {}
|
training_data = {'states': [], 'probs': [], 'winner': []}
|
||||||
|
|
||||||
iters = 0
|
iters = 0
|
||||||
while True:
|
while True:
|
||||||
new_file_list = list(set(os.listdir(data_path)).difference(all_file_list))
|
new_file_list = list(set(os.listdir(data_path)).difference(all_file_list))
|
||||||
if new_file_list:
|
while new_file_list:
|
||||||
all_file_list = os.listdir(data_path)
|
all_file_list = os.listdir(data_path)
|
||||||
new_file_list.sort(
|
new_file_list.sort(
|
||||||
key=lambda file: os.path.getmtime(data_path + file) if not os.path.isdir(data_path + file) else 0)
|
key=lambda file: os.path.getmtime(data_path + file) if not os.path.isdir(data_path + file) else 0)
|
||||||
if new_file_list:
|
|
||||||
for file in new_file_list:
|
for file in new_file_list:
|
||||||
states, probs, winner = self._file_to_training_data(data_path + file)
|
states, probs, winner = self._file_to_training_data(data_path + file)
|
||||||
assert states.shape[0] == probs.shape[0]
|
assert states.shape[0] == probs.shape[0]
|
||||||
@ -215,32 +216,36 @@ class ResNet(object):
|
|||||||
self.training_data['states'].append(states)
|
self.training_data['states'].append(states)
|
||||||
self.training_data['probs'].append(probs)
|
self.training_data['probs'].append(probs)
|
||||||
self.training_data['winner'].append(winner)
|
self.training_data['winner'].append(winner)
|
||||||
if len(self.training_data['states']) == self.window_length:
|
self.training_data['length'].append(states.shape[0])
|
||||||
training_data['states'] = np.concatenate(self.training_data['states'], axis=0)
|
new_file_list = list(set(os.listdir(data_path)).difference(all_file_list))
|
||||||
training_data['probs'] = np.concatenate(self.training_data['probs'], axis=0)
|
|
||||||
training_data['winner'] = np.concatenate(self.training_data['winner'], axis=0)
|
|
||||||
|
|
||||||
if len(self.training_data['states']) != self.window_length:
|
if len(self.training_data['states']) != self.window_length:
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
data_num = training_data['states'].shape[0]
|
|
||||||
index = np.arange(data_num)
|
|
||||||
np.random.shuffle(index)
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
for i in range(batch_size):
|
||||||
|
game_num = random.randint(0, self.window_length-1)
|
||||||
|
state_num = random.randint(0, self.training_data['length'][game_num]-1)
|
||||||
|
training_data['states'].append(np.expand_dims(self.training_data['states'][game_num][state_num], 0))
|
||||||
|
training_data['probs'].append(np.expand_dims(self.training_data['probs'][game_num][state_num], 0))
|
||||||
|
training_data['winner'].append(np.expand_dims(self.training_data['winner'][game_num][state_num], 0))
|
||||||
value_loss, policy_loss, reg, _ = self.sess.run(
|
value_loss, policy_loss, reg, _ = self.sess.run(
|
||||||
[self.value_loss, self.policy_loss, self.reg, self.train_op],
|
[self.value_loss, self.policy_loss, self.reg, self.train_op],
|
||||||
feed_dict={self.x: training_data['states'][index[:batch_size]],
|
feed_dict={self.x: np.concatenate(training_data['states'], axis=0),
|
||||||
self.z: training_data['winner'][index[:batch_size]],
|
self.z: np.concatenate(training_data['winner'], axis=0),
|
||||||
self.pi: training_data['probs'][index[:batch_size]],
|
self.pi: np.concatenate(training_data['probs'], axis=0),
|
||||||
self.is_training: True})
|
self.is_training: True})
|
||||||
|
|
||||||
print("Iteration: {}, Time: {}, Value Loss: {}, Policy Loss: {}, Reg: {}".format(iters,
|
print("Iteration: {}, Time: {}, Value Loss: {}, Policy Loss: {}, Reg: {}".format(iters,
|
||||||
time.time() - start_time,
|
time.time() - start_time,
|
||||||
value_loss,
|
value_loss,
|
||||||
policy_loss, reg))
|
policy_loss, reg))
|
||||||
iters += 1
|
|
||||||
if iters % self.save_freq == 0:
|
if iters % self.save_freq == 0:
|
||||||
save_path = "Iteration{}.ckpt".format(iters)
|
save_path = "Iteration{}.ckpt".format(iters)
|
||||||
self.saver.save(self.sess, self.checkpoint_path + save_path)
|
self.saver.save(self.sess, self.checkpoint_path + save_path)
|
||||||
|
for key in training_data.keys():
|
||||||
|
training_data[key] = []
|
||||||
|
iters += 1
|
||||||
|
|
||||||
def _file_to_training_data(self, file_name):
|
def _file_to_training_data(self, file_name):
|
||||||
read = False
|
read = False
|
||||||
@ -250,6 +255,7 @@ class ResNet(object):
|
|||||||
file.seek(0)
|
file.seek(0)
|
||||||
data = cPickle.load(file)
|
data = cPickle.load(file)
|
||||||
read = True
|
read = True
|
||||||
|
print("{} Loaded!".format(file_name))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
@ -275,6 +281,6 @@ class ResNet(object):
|
|||||||
return states, probs, winner
|
return states, probs, winner
|
||||||
|
|
||||||
|
|
||||||
if __name__=="__main__":
|
if __name__ == "__main__":
|
||||||
model = ResNet(board_size=9, action_num=82)
|
model = ResNet(board_size=9, action_num=82, history_length=8)
|
||||||
model.train("file", data_path="./data/", batch_size=128, checkpoint_path="./checkpoint/")
|
model.train("file", data_path="./data/", batch_size=128, checkpoint_path="./checkpoint/")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user