bug fix of model.py

This commit is contained in:
Dong Yan 2017-12-23 17:43:33 +08:00
parent 4589fcf521
commit 919784e88b

View File

@@ -101,7 +101,7 @@ class ResNet(object):
self._build_network(residual_block_num, self.checkpoint_path)
# training hyper-parameters:
self.window_length = 7000
self.window_length = 3
self.save_freq = 5000
self.training_data = {'states': deque(maxlen=self.window_length), 'probs': deque(maxlen=self.window_length),
'winner': deque(maxlen=self.window_length), 'length': deque(maxlen=self.window_length)}
@@ -223,8 +223,8 @@ class ResNet(object):
else:
start_time = time.time()
for i in range(batch_size):
priority = self.training_data['length'] / sum(self.training_data['length'])
game_num = np.random.choice(self.window_length, 1, p=priority)
priority = np.array(self.training_data['length']) / (0.0 + np.sum(np.array(self.training_data['length'])))
game_num = np.random.choice(self.window_length, 1, p=priority)[0]
state_num = np.random.randint(self.training_data['length'][game_num])
rotate_times = np.random.randint(4)
reflect_times = np.random.randint(2)
@@ -232,12 +232,10 @@ class ResNet(object):
training_data['states'].append(
self._preprocession(self.training_data['states'][game_num][state_num], reflect_times,
reflect_orientation, rotate_times))
training_data['probs'].append(
self._preprocession(self.training_data['probs'][game_num][state_num], reflect_times,
reflect_orientation, rotate_times))
training_data['winner'].append(
self._preprocession(self.training_data['winner'][game_num][state_num], reflect_times,
reflect_orientation, rotate_times))
training_data['probs'].append(np.concatenate(
[self._preprocession(self.training_data['probs'][game_num][state_num][:-1].reshape(self.board_size, self.board_size, 1), reflect_times,
reflect_orientation, rotate_times).reshape(1, self.board_size**2), self.training_data['probs'][game_num][state_num][-1].reshape(1,1)], axis=1))
training_data['winner'].append(self.training_data['winner'][game_num][state_num].reshape(1, 1))
value_loss, policy_loss, reg, _ = self.sess.run(
[self.value_loss, self.policy_loss, self.reg, self.train_op],
feed_dict={self.x: np.concatenate(training_data['states'], axis=0),
@@ -302,7 +300,7 @@ class ResNet(object):
new_board = copy.copy(board)
if new_board.ndim == 3:
np.expand_dims(new_board, axis=0)
new_board = np.expand_dims(new_board, axis=0)
new_board = self._board_reflection(new_board, reflect_times, reflect_orientation)
new_board = self._board_rotation(new_board, rotate_times)