diff --git a/AlphaGo/game.py b/AlphaGo/game.py
index 8706572..df08c0a 100644
--- a/AlphaGo/game.py
+++ b/AlphaGo/game.py
@@ -31,7 +31,7 @@ class Game:
         self.latest_boards = deque(maxlen=8)
         for _ in range(8):
             self.latest_boards.append(self.board)
-        self.evaluator = model.ResNet(self.size, self.size**2 + 1, history_length=8)
+        self.evaluator = model.ResNet(self.size, self.size**2 + 1, history_length=8, checkpoint_path=checkpoint_path)
         # self.evaluator = lambda state: self.sess.run([tf.nn.softmax(self.net.p), self.net.v],
         #                                              feed_dict={self.net.x: state, self.net.is_training: False})
         self.game_engine = go.Go(size=self.size, komi=self.komi)
@@ -96,7 +96,7 @@ class Game:
         sys.stdout.flush()
 
 if __name__ == "__main__":
-    g = Game()
+    g = Game(checkpoint_path='./checkpoints/')
     g.show_board()
     g.think_play_move(1)
     #file = open("debug.txt", "a")
diff --git a/AlphaGo/model.py b/AlphaGo/model.py
index 764ba5f..22e8626 100644
--- a/AlphaGo/model.py
+++ b/AlphaGo/model.py
@@ -1,5 +1,6 @@
 import os
 import time
+import random
 import sys
 import cPickle
 from collections import deque
@@ -104,7 +105,7 @@ class ResNet(object):
         self.window_length = 7000
         self.save_freq = 5000
         self.training_data = {'states': deque(maxlen=self.window_length), 'probs': deque(maxlen=self.window_length),
-                              'winner': deque(maxlen=self.window_length)}
+                              'winner': deque(maxlen=self.window_length), 'length': deque(maxlen=self.window_length)}
 
     def _build_network(self, residual_block_num, checkpoint_path):
         """
@@ -199,15 +200,15 @@ class ResNet(object):
         new_file_list = []
         all_file_list = []
-        training_data = {}
+        training_data = {'states': [], 'probs': [], 'winner': []}
+        iters = 0
         while True:
             new_file_list = list(set(os.listdir(data_path)).difference(all_file_list))
-            if new_file_list:
+            while new_file_list:
                 all_file_list = os.listdir(data_path)
-                new_file_list.sort(
-                    key=lambda file: os.path.getmtime(data_path + file) if not os.path.isdir(data_path + file) else 0)
-            if new_file_list:
+                new_file_list.sort(
+                    key=lambda file: os.path.getmtime(data_path + file) if not os.path.isdir(data_path + file) else 0)
                 for file in new_file_list:
                     states, probs, winner = self._file_to_training_data(data_path + file)
                     assert states.shape[0] == probs.shape[0]
@@ -215,32 +216,36 @@ class ResNet(object):
                     self.training_data['states'].append(states)
                     self.training_data['probs'].append(probs)
                     self.training_data['winner'].append(winner)
-            if len(self.training_data['states']) == self.window_length:
-                training_data['states'] = np.concatenate(self.training_data['states'], axis=0)
-                training_data['probs'] = np.concatenate(self.training_data['probs'], axis=0)
-                training_data['winner'] = np.concatenate(self.training_data['winner'], axis=0)
+                    self.training_data['length'].append(states.shape[0])
+                new_file_list = list(set(os.listdir(data_path)).difference(all_file_list))
             if len(self.training_data['states']) != self.window_length:
                 continue
             else:
-                data_num = training_data['states'].shape[0]
-                index = np.arange(data_num)
-                np.random.shuffle(index)
                 start_time = time.time()
+                for i in range(batch_size):
+                    game_num = random.randint(0, self.window_length-1)
+                    state_num = random.randint(0, self.training_data['length'][game_num]-1)
+                    training_data['states'].append(np.expand_dims(self.training_data['states'][game_num][state_num], 0))
+                    training_data['probs'].append(np.expand_dims(self.training_data['probs'][game_num][state_num], 0))
+                    training_data['winner'].append(np.expand_dims(self.training_data['winner'][game_num][state_num], 0))
                 value_loss, policy_loss, reg, _ = self.sess.run(
                     [self.value_loss, self.policy_loss, self.reg, self.train_op],
-                    feed_dict={self.x: training_data['states'][index[:batch_size]],
-                               self.z: training_data['winner'][index[:batch_size]],
-                               self.pi: training_data['probs'][index[:batch_size]],
+                    feed_dict={self.x: np.concatenate(training_data['states'], axis=0),
+                               self.z: np.concatenate(training_data['winner'], axis=0),
+                               self.pi: np.concatenate(training_data['probs'], axis=0),
                                self.is_training: True})
+                print("Iteration: {}, Time: {}, Value Loss: {}, Policy Loss: {}, Reg: {}".format(iters, time.time() - start_time, value_loss, policy_loss, reg))
-                iters += 1
                 if iters % self.save_freq == 0:
                     save_path = "Iteration{}.ckpt".format(iters)
                     self.saver.save(self.sess, self.checkpoint_path + save_path)
+                for key in training_data.keys():
+                    training_data[key] = []
+                iters += 1
 
     def _file_to_training_data(self, file_name):
         read = False
@@ -250,6 +255,7 @@ class ResNet(object):
                 file.seek(0)
                 data = cPickle.load(file)
                 read = True
+                print("{} Loaded!".format(file_name))
             except Exception as e:
                 print(e)
                 time.sleep(1)
@@ -275,6 +281,6 @@ class ResNet(object):
         return states, probs, winner
 
-if __name__=="__main__":
-    model = ResNet(board_size=9, action_num=82)
+if __name__ == "__main__":
+    model = ResNet(board_size=9, action_num=82, history_length=8)
     model.train("file", data_path="./data/", batch_size=128, checkpoint_path="./checkpoint/")
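The substantive change in model.py swaps the old epoch-style pipeline (concatenate the whole window, shuffle an index array, slice off a batch) for per-batch sampling straight from the deques of stored games. Below is a minimal, self-contained sketch of that sampling scheme. The window size and batch size are shrunk for illustration, and the 17-plane state shape is an assumption; the 9x9 board and 82-entry move distribution match the __main__ configuration in the patch, but all array contents are placeholder zeros, not real self-play data.

import random
from collections import deque

import numpy as np

# Stand-in sizes; model.py uses window_length=7000 and batch_size=128.
window_length = 4
batch_size = 8

# One deque entry per finished game; 'length' caches each game's move count
# so a position index can be drawn without touching the arrays.
buffer = {'states': deque(maxlen=window_length),
          'probs': deque(maxlen=window_length),
          'winner': deque(maxlen=window_length),
          'length': deque(maxlen=window_length)}

# Fill the window with placeholder games of varying length. Only the first
# axis (one entry per position) matters for the sampling logic.
for _ in range(window_length):
    n = random.randint(10, 20)
    buffer['states'].append(np.zeros((n, 17, 9, 9)))  # assumed plane count
    buffer['probs'].append(np.zeros((n, 82)))
    buffer['winner'].append(np.zeros((n, 1)))
    buffer['length'].append(n)

# Draw one minibatch: pick a game uniformly, then a position uniformly
# within that game, exactly as the new training loop does.
batch = {'states': [], 'probs': [], 'winner': []}
for _ in range(batch_size):
    game_num = random.randint(0, window_length - 1)
    state_num = random.randint(0, buffer['length'][game_num] - 1)
    for key in batch:
        batch[key].append(np.expand_dims(buffer[key][game_num][state_num], 0))

minibatch = {key: np.concatenate(val, axis=0) for key, val in batch.items()}
print(minibatch['states'].shape)  # -> (8, 17, 9, 9)

Note one property of this two-stage draw: it weights games equally rather than positions, so a position from a short game is sampled more often in expectation than one from a long game, unlike the flat shuffle over all concatenated positions that it replaces.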