add random preprocess, modify the uniform sample from training data
This commit is contained in:
parent
c50ee8f029
commit
a787f73cf6
@ -1,7 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import random
|
import copy
|
||||||
import sys
|
|
||||||
import cPickle
|
import cPickle
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
|
||||||
@ -224,11 +223,21 @@ class ResNet(object):
|
|||||||
else:
|
else:
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
game_num = random.randint(0, self.window_length-1)
|
priority = self.training_data['length'] / sum(self.training_data['length'])
|
||||||
state_num = random.randint(0, self.training_data['length'][game_num]-1)
|
game_num = np.random.choice(self.window_length, 1, p=priority)
|
||||||
training_data['states'].append(np.expand_dims(self.training_data['states'][game_num][state_num], 0))
|
state_num = np.random.randint(self.training_data['length'][game_num])
|
||||||
training_data['probs'].append(np.expand_dims(self.training_data['probs'][game_num][state_num], 0))
|
rotate_times = np.random.randint(4)
|
||||||
training_data['winner'].append(np.expand_dims(self.training_data['winner'][game_num][state_num], 0))
|
reflect_times = np.random.randint(2)
|
||||||
|
reflect_orientation = np.random.randint(2)
|
||||||
|
training_data['states'].append(
|
||||||
|
self._preprocession(self.training_data['states'][game_num][state_num], reflect_times,
|
||||||
|
reflect_orientation, rotate_times))
|
||||||
|
training_data['probs'].append(
|
||||||
|
self._preprocession(self.training_data['probs'][game_num][state_num], reflect_times,
|
||||||
|
reflect_orientation, rotate_times))
|
||||||
|
training_data['winner'].append(
|
||||||
|
self._preprocession(self.training_data['winner'][game_num][state_num], reflect_times,
|
||||||
|
reflect_orientation, rotate_times))
|
||||||
value_loss, policy_loss, reg, _ = self.sess.run(
|
value_loss, policy_loss, reg, _ = self.sess.run(
|
||||||
[self.value_loss, self.policy_loss, self.reg, self.train_op],
|
[self.value_loss, self.policy_loss, self.reg, self.train_op],
|
||||||
feed_dict={self.x: np.concatenate(training_data['states'], axis=0),
|
feed_dict={self.x: np.concatenate(training_data['states'], axis=0),
|
||||||
@ -280,6 +289,55 @@ class ResNet(object):
|
|||||||
winner = np.concatenate(winner, axis=0)
|
winner = np.concatenate(winner, axis=0)
|
||||||
return states, probs, winner
|
return states, probs, winner
|
||||||
|
|
||||||
|
def _preprocession(self, board, reflect_times=0, reflect_orientation=0, rotate_times=0):
|
||||||
|
"""
|
||||||
|
preprocessing for augmentation
|
||||||
|
|
||||||
|
:param board: a ndarray, board to process
|
||||||
|
:param reflect_times: an integer, how many times to reflect
|
||||||
|
:param reflect_orientation: an integer, which orientation to reflect
|
||||||
|
:param rotate_times: an integer, how many times to rotate
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
|
||||||
|
new_board = copy.copy(board)
|
||||||
|
if new_board.ndim == 3:
|
||||||
|
np.expand_dims(new_board, axis=0)
|
||||||
|
|
||||||
|
new_board = self._board_reflection(new_board, reflect_times, reflect_orientation)
|
||||||
|
new_board = self._board_rotation(new_board, rotate_times)
|
||||||
|
|
||||||
|
return new_board
|
||||||
|
|
||||||
|
def _board_rotation(self, board, times):
|
||||||
|
"""
|
||||||
|
rotate the board for augmentation
|
||||||
|
note that board's shape should be [batch_size, board_size, board_size, channels]
|
||||||
|
|
||||||
|
:param board: a ndarray, shape [batch_size, board_size, board_size, channels]
|
||||||
|
:param times: an integer, how many times to rotate
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return np.rot90(board, times, (1, 2))
|
||||||
|
|
||||||
|
def _board_reflection(self, board, times, orientation):
|
||||||
|
"""
|
||||||
|
reflect the board for augmentation
|
||||||
|
note that board's shape should be [batch_size, board_size, board_size, channels]
|
||||||
|
|
||||||
|
:param board: a ndarray, shape [batch_size, board_size, board_size, channels]
|
||||||
|
:param times: an integer, how many times to reflect
|
||||||
|
:param orientation: an integer, which orientation to reflect
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
new_board = copy.copy(board)
|
||||||
|
for _ in range(times):
|
||||||
|
if orientation == 0:
|
||||||
|
new_board = new_board[:, ::-1]
|
||||||
|
if orientation == 1:
|
||||||
|
new_board = new_board[:, :, ::-1]
|
||||||
|
return new_board
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
model = ResNet(board_size=9, action_num=82, history_length=8)
|
model = ResNet(board_size=9, action_num=82, history_length=8)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user