From 919784e88b011028ff5e8b8e226974a9bbf8d75c Mon Sep 17 00:00:00 2001
From: Dong Yan
Date: Sat, 23 Dec 2017 17:43:33 +0800
Subject: [PATCH] bug fix of model.py

---
 AlphaGo/model.py | 18 ++++++++----------
 1 file changed, 8 insertions(+), 10 deletions(-)

diff --git a/AlphaGo/model.py b/AlphaGo/model.py
index 68973ac..2dc1ef0 100644
--- a/AlphaGo/model.py
+++ b/AlphaGo/model.py
@@ -101,7 +101,7 @@ class ResNet(object):
         self._build_network(residual_block_num, self.checkpoint_path)
 
         # training hyper-parameters:
-        self.window_length = 7000
+        self.window_length = 3
         self.save_freq = 5000
         self.training_data = {'states': deque(maxlen=self.window_length), 'probs': deque(maxlen=self.window_length),
                               'winner': deque(maxlen=self.window_length), 'length': deque(maxlen=self.window_length)}
@@ -223,8 +223,8 @@ class ResNet(object):
         else:
             start_time = time.time()
             for i in range(batch_size):
-                priority = self.training_data['length'] / sum(self.training_data['length'])
-                game_num = np.random.choice(self.window_length, 1, p=priority)
+                priority = np.array(self.training_data['length']) / (0.0 + np.sum(np.array(self.training_data['length'])))
+                game_num = np.random.choice(self.window_length, 1, p=priority)[0]
                 state_num = np.random.randint(self.training_data['length'][game_num])
                 rotate_times = np.random.randint(4)
                 reflect_times = np.random.randint(2)
@@ -232,12 +232,10 @@ class ResNet(object):
                 training_data['states'].append(
                     self._preprocession(self.training_data['states'][game_num][state_num], reflect_times,
                                         reflect_orientation, rotate_times))
-                training_data['probs'].append(
-                    self._preprocession(self.training_data['probs'][game_num][state_num], reflect_times,
-                                        reflect_orientation, rotate_times))
-                training_data['winner'].append(
-                    self._preprocession(self.training_data['winner'][game_num][state_num], reflect_times,
-                                        reflect_orientation, rotate_times))
+                training_data['probs'].append(np.concatenate(
+                    [self._preprocession(self.training_data['probs'][game_num][state_num][:-1].reshape(self.board_size, self.board_size, 1), reflect_times,
+                                         reflect_orientation, rotate_times).reshape(1, self.board_size**2), self.training_data['probs'][game_num][state_num][-1].reshape(1,1)], axis=1))
+                training_data['winner'].append(self.training_data['winner'][game_num][state_num].reshape(1, 1))
             value_loss, policy_loss, reg, _ = self.sess.run(
                 [self.value_loss, self.policy_loss, self.reg, self.train_op],
                 feed_dict={self.x: np.concatenate(training_data['states'], axis=0),
@@ -302,7 +300,7 @@ class ResNet(object):
 
         new_board = copy.copy(board)
         if new_board.ndim == 3:
-            np.expand_dims(new_board, axis=0)
+            new_board = np.expand_dims(new_board, axis=0)
 
         new_board = self._board_reflection(new_board, reflect_times, reflect_orientation)
         new_board = self._board_rotation(new_board, rotate_times)
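
A minimal standalone sketch of the two NumPy pitfalls this patch works around, not the repository's actual code: the names training_data and window_length mirror the patch, while the toy deque contents and the (9, 9, 17) board shape are hypothetical stand-ins.

import numpy as np
from collections import deque

window_length = 3
training_data = {'length': deque([5, 10, 20], maxlen=window_length)}

# Dividing a deque (or list) by its sum raises TypeError, which is why the
# old line failed; casting to an array and forcing float division (the
# patch's "0.0 +" trick, needed for Python 2 integer division) produces a
# valid probability vector for np.random.choice.
lengths = np.array(training_data['length'])
priority = lengths / (0.0 + np.sum(lengths))
assert np.isclose(priority.sum(), 1.0)

# np.random.choice(n, 1, p=...) returns a length-1 array; the trailing [0]
# added by the patch unwraps it into a scalar index.
game_num = np.random.choice(window_length, 1, p=priority)[0]
state_num = np.random.randint(training_data['length'][game_num])

# The second fix: np.expand_dims returns a new array rather than modifying
# its argument in place, so the result must be assigned back.
board = np.zeros((9, 9, 17))       # hypothetical single-board tensor
if board.ndim == 3:
    board = np.expand_dims(board, axis=0)
print(board.shape)                 # (1, 9, 9, 17)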