diff --git a/examples/ppo_cartpole.py b/examples/ppo_cartpole.py
index 074331d..0054c9a 100755
--- a/examples/ppo_cartpole.py
+++ b/examples/ppo_cartpole.py
@@ -59,7 +59,7 @@ if __name__ == '__main__':
     train_op = optimizer.minimize(total_loss, var_list=pi.trainable_variables)
 
     ### 3. define data collection
-    training_data = Batch(env, pi, [advantage_estimation.full_return], [pi], render = args.render)
+    training_data = Batch(env, pi, [advantage_estimation.full_return], [pi], render=args.render)
 
     ### 4. start training
     config = tf.ConfigProto()
diff --git a/internal_keys.md b/internal_keys.md
index 9f7e4bd..78571a0 100644
--- a/internal_keys.md
+++ b/internal_keys.md
@@ -10,6 +10,6 @@ data_collector.data.keys()
 
 ['reward']
-['start_flag']
+['done_flag']
 ['advantage']
 
 > ['return'] # they may appear simultaneously
\ No newline at end of file
diff --git a/tianshou/data/advantage_estimation.py b/tianshou/data/advantage_estimation.py
index 8d9c25f..eafc30d 100644
--- a/tianshou/data/advantage_estimation.py
+++ b/tianshou/data/advantage_estimation.py
@@ -1,39 +1,48 @@
-import numpy as np
+import logging
 
 
-def full_return(raw_data):
+STATE = 0
+ACTION = 1
+REWARD = 2
+DONE = 3
+
+# modified for new interfaces
+def full_return(buffer, index=None):
     """
     naively compute full return
-    :param raw_data: dict of specified keys and values.
+    :param buffer: buffer with property index and data. index determines the current content in `buffer`.
+    :param index: (sampled) index to be computed. Defaults to all the data in `buffer`. Not necessarily in order within
+        each episode.
+    :return: dict with key 'return' and value the computed returns corresponding to `index`.
     """
-    observations = raw_data['observation']
-    actions = raw_data['action']
-    rewards = raw_data['reward']
-    episode_start_flags = raw_data['end_flag']
-    num_timesteps = rewards.shape[0]
+    index = index or buffer.index
+    raw_data = buffer.data
 
-    data = {}
+    returns = []
+    for i_episode in range(len(index)):
+        index_this = index[i_episode]
+        if index_this:
+            episode = raw_data[i_episode]
+            if not episode[-1][DONE]:
+                logging.warning('Computing full return on episode {} with no terminal state.'.format(i_episode))
 
-    returns = rewards.copy()
-    episode_start_idx = 0
-    for i in range(1, num_timesteps):
-        if episode_start_flags[i] or (
-                i == num_timesteps - 1):  # found the start of next episode or the end of all episodes
-            if i < rewards.shape[0] - 1:
-                t = i - 1
-            else:
-                t = i
-            Gt = 0
-            while t >= episode_start_idx:
-                Gt += rewards[t]
-                returns[t] = Gt
-                t -= 1
+            episode_length = len(episode)
+            returns_episode = [0.] * episode_length
+            returns_this = [0.] * len(index_this)
+            return_ = 0.
+            index_min = min(index_this)
+            for i, frame in zip(range(episode_length - 1, index_min - 1, -1), reversed(episode[index_min:])):
+                return_ += frame[REWARD]
+                returns_episode[i] = return_
 
-            episode_start_idx = i
+            for i in range(len(index_this)):
+                returns_this[i] = returns_episode[index_this[i]]
 
-    data['return'] = returns
+            returns.append(returns_this)
+        else:
+            returns.append([])
 
-    return data
+    return {'return': returns}
 
 
 class gae_lambda:
@@ -44,16 +53,14 @@ class gae_lambda:
         self.T = T
         self.value_function = value_function
 
-    def __call__(self, raw_data):
-        reward = raw_data['reward']
-        observation = raw_data['observation']
-
-        state_value = self.value_function.eval_value(observation)
-
-        # wrong version of advantage just to run
-        advantage = reward + state_value
-
-        return {'advantage': advantage}
+    def __call__(self, buffer, index=None):
+        """
+        :param buffer: buffer with property index and data. index determines the current content in `buffer`.
+        :param index: (sampled) index to be computed. Defaults to all the data in `buffer`. Not necessarily in order within
+            each episode.
+        :return: dict with key 'advantage' and value the computed advantages corresponding to `index`.
+        """
+        pass
 
 
 class nstep_return:
@@ -64,16 +71,15 @@ class nstep_return:
         self.n = n
         self.value_function = value_function
 
-    def __call__(self, raw_data):
-        reward = raw_data['reward']
-        observation = raw_data['observation']
-
-        state_value = self.value_function.eval_value(observation)
-
-        # wrong version of return just to run
-        return_ = reward + state_value
-
-        return {'return': return_}
+    def __call__(self, buffer, index=None):
+        """
+        compute the n-step return
+        :param buffer: buffer with property index and data. index determines the current content in `buffer`.
+        :param index: (sampled) index to be computed. Defaults to all the data in `buffer`. Not necessarily in order within
+            each episode.
+        :return: dict with key 'return' and value the computed returns corresponding to `index`.
+        """
+        pass
 
 
 class ddpg_return:
@@ -85,20 +91,15 @@ class ddpg_return:
         self.critic = critic
         self.use_target_network = use_target_network
 
-    def __call__(self, raw_data):
-        observation = raw_data['observation']
-        reward = raw_data['reward']
-
-        if self.use_target_network:
-            action_target = self.actor.eval_action_old(observation)
-            value_target = self.critic.eval_value_old(observation, action_target)
-        else:
-            action_target = self.actor.eval_action(observation)
-            value_target = self.critic.eval_value(observation, action_target)
-
-        return_ = reward + value_target
-
-        return {'return': return_}
+    def __call__(self, buffer, index=None):
+        """
+        compute the return target for DDPG
+        :param buffer: buffer with property index and data. index determines the current content in `buffer`.
+        :param index: (sampled) index to be computed. Defaults to all the data in `buffer`. Not necessarily in order within
+            each episode.
+        :return: dict with key 'return' and value the computed returns corresponding to `index`.
+        """
+        pass
 
 
 class nstep_q_return:
@@ -110,80 +111,12 @@ class nstep_q_return:
         self.action_value = action_value
         self.use_target_network = use_target_network
 
-    def __call__(self, raw_data):
-        # raw_data should contain 'next_observation' from replay memory...?
-        # maybe the main difference between Batch and Replay is the stored data format?
-        reward = raw_data['reward']
-        observation = raw_data['observation']
-
-        if self.use_target_network:
-            action_value_all_actions = self.action_value.eval_value_all_actions_old(observation)
-        else:
-            action_value_all_actions = self.action_value.eval_value_all_actions(observation)
-
-        action_value_max = np.max(action_value_all_actions, axis=1)
-
-        return_ = reward + action_value_max
-
-        return {'return': return_}
-
-
-class QLearningTarget:
-    def __init__(self, policy, gamma):
-        self._policy = policy
-        self._gamma = gamma
-
-    def __call__(self, raw_data):
-        data = dict()
-        observations = list()
-        actions = list()
-        rewards = list()
-        wi = list()
-        all_data, data_wi, data_index = raw_data
-
-        for i in range(0, all_data.shape[0]):
-            current_data = all_data[i]
-            current_wi = data_wi[i]
-            current_index = data_index[i]
-            observations.append(current_data['observation'])
-            actions.append(current_data['action'])
-            next_max_qvalue = np.max(self._policy.values(current_data['observation']))
-            current_qvalue = self._policy.values(current_data['previous_observation'])[current_data['previous_action']]
-            reward = current_data['reward'] + next_max_qvalue - current_qvalue
-            rewards.append(reward)
-            wi.append(current_wi)
-
-        data['observations'] = np.array(observations)
-        data['actions'] = np.array(actions)
-        data['rewards'] = np.array(rewards)
-
-        return data
-
-
-class ReplayMemoryQReturn:
-    """
-    compute the n-step return for Q-learning targets
-    """
-    def __init__(self, n, action_value, use_target_network=True):
-        self.n = n
-        self._action_value = action_value
-        self._use_target_network = use_target_network
-
-    def __call__(self, raw_data):
-        reward = raw_data['reward']
-        observation = raw_data['observation']
-
-        if self._use_target_network:
-            # print(observation.shape)
-            # print((observation.reshape((1,) + observation.shape)))
-            action_value_all_actions = self._action_value.eval_value_all_actions_old(observation.reshape((1,) + observation.shape))
-        else:
-            # print(observation.shape)
-            # print((observation.reshape((1,) + observation.shape)))
-            action_value_all_actions = self._action_value.eval_value_all_actions(observation.reshape((1,) + observation.shape))
-
-        action_value_max = np.max(action_value_all_actions, axis=1)
-
-        return_ = reward + action_value_max
-
-        return {'return': return_}
+    def __call__(self, buffer, index=None):
+        """
+        compute the n-step return for Q-learning targets
+        :param buffer: buffer with property index and data. index determines the current content in `buffer`.
+        :param index: (sampled) index to be computed. Defaults to all the data in `buffer`. Not necessarily in order within
+            each episode.
+        :return: dict with key 'return' and value the computed returns corresponding to `index`.
+        """
+        pass
diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py
index d559ded..a1684ee 100644
--- a/tianshou/data/batch.py
+++ b/tianshou/data/batch.py
@@ -156,11 +156,19 @@ class Batch(object):
     def next_batch(self, batch_size, standardize_advantage=True):
         rand_idx = np.random.choice(self.raw_data['observation'].shape[0], batch_size)
 
+        # maybe re-compute advantage here, but only on rand_idx
+        # but how to construct the feed_dict?
+        if self.online:
+            self.data_batch.update(self.apply_advantage_estimation_function(rand_idx))
+
+
         feed_dict = {}
         for key, placeholder in self.required_placeholders.items():
+            feed_dict[placeholder] = utils.get_batch(self, key, rand_idx)
+
             found, data_key = utils.internal_key_match(key, self.raw_data.keys())
             if found:
-                feed_dict[placeholder] = self.raw_data[data_key][rand_idx]
+                feed_dict[placeholder] = utils.get_batch(self.raw_data[data_key], rand_idx)  # self.raw_data[data_key][rand_idx]
             else:
                 found, data_key = utils.internal_key_match(key, self.data.keys())
                 if found:
diff --git a/tianshou/data/test_advantage_estimation.py b/tianshou/data/test_advantage_estimation.py
new file mode 100644
index 0000000..47589db
--- /dev/null
+++ b/tianshou/data/test_advantage_estimation.py
@@ -0,0 +1,29 @@
+
+
+from advantage_estimation import *
+
+class ReplayBuffer(object):
+    def __init__(self):
+        self.index = [
+            [0, 1, 2],
+            [0, 1, 2, 3],
+            [0, 1],
+        ]
+        self.data = [
+            [(0, 0, 10, False), (0, 0, 1, False), (0, 0, -100, True)],
+            [(0, 0, 1, False), (0, 0, 1, False), (0, 0, 1, False), (0, 0, 5, False)],
+            [(0, 0, 9, False), (0, 0, -2, True)],
+        ]
+
+
+buffer = ReplayBuffer()
+sample_index = [
+    [0, 2, 0],
+    [1, 2, 1, 3],
+    [],
+]
+
+data = full_return(buffer)
+print(data['return'])
+data = full_return(buffer, sample_index)
+print(data['return'])
\ No newline at end of file
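
Note on the new test script: its expected output can be worked out by hand, since full_return accumulates undiscounted rewards backwards within each episode and then picks out the sampled indices (episode 1 has no terminal frame, so full_return also logs a warning). The standalone cross-check below is illustrative only; the helper name naive_full_return is hypothetical and not part of the patch.

# Hypothetical cross-check for the expected output of test_advantage_estimation.py.
# Re-implements undiscounted full returns independently; names are illustrative only.
def naive_full_return(episodes, index):
    returns = []
    for episode, idx in zip(episodes, index):
        if not idx:
            returns.append([])        # empty sample index -> no returns for this episode
            continue
        g, full = 0., [0.] * len(episode)
        for t in reversed(range(len(episode))):
            g += episode[t][2]        # frame layout assumed to be (state, action, reward, done)
            full[t] = g
        returns.append([full[t] for t in idx])
    return returns

episodes = [
    [(0, 0, 10, False), (0, 0, 1, False), (0, 0, -100, True)],
    [(0, 0, 1, False), (0, 0, 1, False), (0, 0, 1, False), (0, 0, 5, False)],  # no terminal frame
    [(0, 0, 9, False), (0, 0, -2, True)],
]

# full index, as used by full_return(buffer):
print(naive_full_return(episodes, [[0, 1, 2], [0, 1, 2, 3], [0, 1]]))
# expected: [[-89.0, -99.0, -100.0], [8.0, 7.0, 6.0, 5.0], [7.0, -2.0]]

# sampled index, as used by full_return(buffer, sample_index):
print(naive_full_return(episodes, [[0, 2, 0], [1, 2, 1, 3], []]))
# expected: [[-89.0, -100.0, -89.0], [7.0, 6.0, 7.0, 5.0], []]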
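
The gae_lambda, nstep_return, ddpg_return and nstep_q_return stubs above only fix the new (buffer, index) interface; their bodies are still pass. Purely as an illustration of how an estimator might consume that interface, here is a speculative GAE(lambda) sketch. The function name gae_lambda_sketch, the value_function callable, and the gamma/lambda_ defaults are assumptions for this example and not part of the patch.

# Speculative sketch only, not the commit's implementation. Assumes the buffer layout
# introduced above: buffer.data is a list of episodes, each a list of
# (state, action, reward, done) frames, and buffer.index mirrors that nesting.
def gae_lambda_sketch(buffer, value_function, gamma=0.99, lambda_=0.95, index=None):
    STATE, ACTION, REWARD, DONE = 0, 1, 2, 3
    index = index or buffer.index
    advantages = []
    for i_episode, index_this in enumerate(index):
        if not index_this:
            advantages.append([])
            continue
        episode = buffer.data[i_episode]
        # V(s_t) for every frame, plus one bootstrap value after the last frame;
        # the next observation is not stored, so the last value is reused as a crude bootstrap
        values = [value_function(frame[STATE]) for frame in episode]
        values.append(0. if episode[-1][DONE] else values[-1])
        advantage_episode = [0.] * len(episode)
        gae = 0.
        for t in reversed(range(len(episode))):
            not_done = 0. if episode[t][DONE] else 1.
            delta = episode[t][REWARD] + gamma * values[t + 1] * not_done - values[t]
            gae = delta + gamma * lambda_ * not_done * gae
            advantage_episode[t] = gae
        advantages.append([advantage_episode[t] for t in index_this])
    return {'advantage': advantages}

With gamma = lambda_ = 1 and a zero value function (value_function=lambda s: 0.), the advantages reduce to the same undiscounted full returns that full_return computes, which gives a quick sanity check against the test data above.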