bug fix of model.py
This commit is contained in:
parent
4589fcf521
commit
919784e88b
@ -101,7 +101,7 @@ class ResNet(object):
|
||||
self._build_network(residual_block_num, self.checkpoint_path)
|
||||
|
||||
# training hyper-parameters:
|
||||
self.window_length = 7000
|
||||
self.window_length = 3
|
||||
self.save_freq = 5000
|
||||
self.training_data = {'states': deque(maxlen=self.window_length), 'probs': deque(maxlen=self.window_length),
|
||||
'winner': deque(maxlen=self.window_length), 'length': deque(maxlen=self.window_length)}
|
||||
@ -223,8 +223,8 @@ class ResNet(object):
|
||||
else:
|
||||
start_time = time.time()
|
||||
for i in range(batch_size):
|
||||
priority = self.training_data['length'] / sum(self.training_data['length'])
|
||||
game_num = np.random.choice(self.window_length, 1, p=priority)
|
||||
priority = np.array(self.training_data['length']) / (0.0 + np.sum(np.array(self.training_data['length'])))
|
||||
game_num = np.random.choice(self.window_length, 1, p=priority)[0]
|
||||
state_num = np.random.randint(self.training_data['length'][game_num])
|
||||
rotate_times = np.random.randint(4)
|
||||
reflect_times = np.random.randint(2)
|
||||
@ -232,12 +232,10 @@ class ResNet(object):
|
||||
training_data['states'].append(
|
||||
self._preprocession(self.training_data['states'][game_num][state_num], reflect_times,
|
||||
reflect_orientation, rotate_times))
|
||||
training_data['probs'].append(
|
||||
self._preprocession(self.training_data['probs'][game_num][state_num], reflect_times,
|
||||
reflect_orientation, rotate_times))
|
||||
training_data['winner'].append(
|
||||
self._preprocession(self.training_data['winner'][game_num][state_num], reflect_times,
|
||||
reflect_orientation, rotate_times))
|
||||
training_data['probs'].append(np.concatenate(
|
||||
[self._preprocession(self.training_data['probs'][game_num][state_num][:-1].reshape(self.board_size, self.board_size, 1), reflect_times,
|
||||
reflect_orientation, rotate_times).reshape(1, self.board_size**2), self.training_data['probs'][game_num][state_num][-1].reshape(1,1)], axis=1))
|
||||
training_data['winner'].append(self.training_data['winner'][game_num][state_num].reshape(1, 1))
|
||||
value_loss, policy_loss, reg, _ = self.sess.run(
|
||||
[self.value_loss, self.policy_loss, self.reg, self.train_op],
|
||||
feed_dict={self.x: np.concatenate(training_data['states'], axis=0),
|
||||
@ -302,7 +300,7 @@ class ResNet(object):
|
||||
|
||||
new_board = copy.copy(board)
|
||||
if new_board.ndim == 3:
|
||||
np.expand_dims(new_board, axis=0)
|
||||
new_board = np.expand_dims(new_board, axis=0)
|
||||
|
||||
new_board = self._board_reflection(new_board, reflect_times, reflect_orientation)
|
||||
new_board = self._board_rotation(new_board, rotate_times)
|
||||
|
Loading…
x
Reference in New Issue
Block a user