From 4589fcf52194eccc219f82e36345573541511674 Mon Sep 17 00:00:00 2001
From: rtz19970824 <1289226405@qq.com>
Date: Sat, 23 Dec 2017 16:27:09 +0800
Subject: [PATCH] add random preprocess, modify the uniform sample from
 training data

---
 AlphaGo/model.py | 72 +++++++++++++++++++++++++++++++++++++++++++-----
 1 file changed, 65 insertions(+), 7 deletions(-)

diff --git a/AlphaGo/model.py b/AlphaGo/model.py
index 22e8626..68973ac 100644
--- a/AlphaGo/model.py
+++ b/AlphaGo/model.py
@@ -1,7 +1,6 @@
 import os
 import time
-import random
-import sys
+import copy
 import cPickle
 from collections import deque
 
@@ -224,11 +223,21 @@ class ResNet(object):
         else:
             start_time = time.time()
             for i in range(batch_size):
-                game_num = random.randint(0, self.window_length-1)
-                state_num = random.randint(0, self.training_data['length'][game_num]-1)
-                training_data['states'].append(np.expand_dims(self.training_data['states'][game_num][state_num], 0))
-                training_data['probs'].append(np.expand_dims(self.training_data['probs'][game_num][state_num], 0))
-                training_data['winner'].append(np.expand_dims(self.training_data['winner'][game_num][state_num], 0))
+                priority = self.training_data['length'] / sum(self.training_data['length'])
+                game_num = np.random.choice(self.window_length, 1, p=priority)
+                state_num = np.random.randint(self.training_data['length'][game_num])
+                rotate_times = np.random.randint(4)
+                reflect_times = np.random.randint(2)
+                reflect_orientation = np.random.randint(2)
+                training_data['states'].append(
+                    self._preprocession(self.training_data['states'][game_num][state_num], reflect_times,
+                                        reflect_orientation, rotate_times))
+                training_data['probs'].append(
+                    self._preprocession(self.training_data['probs'][game_num][state_num], reflect_times,
+                                        reflect_orientation, rotate_times))
+                training_data['winner'].append(
+                    self._preprocession(self.training_data['winner'][game_num][state_num], reflect_times,
+                                        reflect_orientation, rotate_times))
             value_loss, policy_loss, reg, _ = self.sess.run(
                 [self.value_loss, self.policy_loss, self.reg, self.train_op],
                 feed_dict={self.x: np.concatenate(training_data['states'], axis=0),
@@ -280,6 +289,55 @@ class ResNet(object):
         winner = np.concatenate(winner, axis=0)
         return states, probs, winner
 
+    def _preprocession(self, board, reflect_times=0, reflect_orientation=0, rotate_times=0):
+        """
+        preprocessing for augmentation
+
+        :param board: a ndarray, board to process
+        :param reflect_times: an integer, how many times to reflect
+        :param reflect_orientation: an integer, which orientation to reflect
+        :param rotate_times: an integer, how many times to rotate
+        :return:
+        """
+
+        new_board = copy.copy(board)
+        if new_board.ndim == 3:
+            np.expand_dims(new_board, axis=0)
+
+        new_board = self._board_reflection(new_board, reflect_times, reflect_orientation)
+        new_board = self._board_rotation(new_board, rotate_times)
+
+        return new_board
+
+    def _board_rotation(self, board, times):
+        """
+        rotate the board for augmentation
+        note that board's shape should be [batch_size, board_size, board_size, channels]
+
+        :param board: a ndarray, shape [batch_size, board_size, board_size, channels]
+        :param times: an integer, how many times to rotate
+        :return:
+        """
+        return np.rot90(board, times, (1, 2))
+
+    def _board_reflection(self, board, times, orientation):
+        """
+        reflect the board for augmentation
+        note that board's shape should be [batch_size, board_size, board_size, channels]
+
+        :param board: a ndarray, shape [batch_size, board_size, board_size, channels]
+        :param times: an integer, how many times to reflect
+        :param orientation: an integer, which orientation to reflect
+        :return:
+        """
+        new_board = copy.copy(board)
+        for _ in range(times):
+            if orientation == 0:
+                new_board = new_board[:, ::-1]
+            if orientation == 1:
+                new_board = new_board[:, :, ::-1]
+        return new_board
+
 
 if __name__ == "__main__":
     model = ResNet(board_size=9, action_num=82, history_length=8)
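
For reference, here is a minimal standalone sketch of the sampling and augmentation scheme the patch introduces: a game is drawn with probability proportional to its length, a position is drawn uniformly within that game, and the position is then randomly reflected and rotated before being batched. The board shapes, the game lengths, and the augment helper below are illustrative assumptions for this sketch, not code taken from the repository.

import numpy as np

def augment(board, rotate_times=0, reflect_times=0, reflect_orientation=0):
    # Mirror of the patch's _board_reflection/_board_rotation behaviour on a
    # board of shape [batch_size, board_size, board_size, channels].
    new_board = np.array(board, copy=True)
    for _ in range(reflect_times):
        if reflect_orientation == 0:
            new_board = new_board[:, ::-1]      # flip along the row axis
        else:
            new_board = new_board[:, :, ::-1]   # flip along the column axis
    return np.rot90(new_board, rotate_times, axes=(1, 2))

# Illustrative replay window: 5 self-play games of different lengths,
# each position stored as a 9x9 board with 17 feature planes (assumed shape).
lengths = np.array([30, 45, 60, 50, 40], dtype=np.float64)
states = [np.random.rand(int(n), 9, 9, 17) for n in lengths]

# Draw a game with probability proportional to its length, then a position
# uniformly within that game.
priority = lengths / lengths.sum()
game_num = int(np.random.choice(len(lengths), p=priority))
state_num = np.random.randint(int(lengths[game_num]))

# Draw a random element of the 8-fold board symmetry group.
rotate_times = np.random.randint(4)
reflect_times = np.random.randint(2)
reflect_orientation = np.random.randint(2)

sample = states[game_num][state_num][np.newaxis]   # shape (1, 9, 9, 17)
sample = augment(sample, rotate_times, reflect_times, reflect_orientation)
print(sample.shape)                                # still (1, 9, 9, 17)

Drawing games in proportion to their length and then a position uniformly within the chosen game is equivalent to drawing uniformly over all stored positions in the window, which removes the bias toward positions from short games that per-game uniform sampling had; this is the change the commit title refers to.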