From 498b55c0517a624656ebccf2847af9872405136a Mon Sep 17 00:00:00 2001
From: haoshengzou
Date: Sat, 10 Mar 2018 17:30:11 +0800
Subject: [PATCH] PPO with batch set also works! Now PPO improves steadily; DQN is still not so stable.

---
 examples/ppo_cartpole.py               | 16 ++++-
 tianshou/data/data_buffer/base.py      | 14 +++-
 tianshou/data/data_buffer/batch_set.py | 88 +++++++++++++++++++++++++-
 tianshou/data/data_buffer/vanilla.py   |  2 +-
 tianshou/data/data_collector.py        | 43 +++++++------
 5 files changed, 136 insertions(+), 27 deletions(-)

diff --git a/examples/ppo_cartpole.py b/examples/ppo_cartpole.py
index bd8ab72..94a8155 100755
--- a/examples/ppo_cartpole.py
+++ b/examples/ppo_cartpole.py
@@ -6,6 +6,8 @@ import gym
 import numpy as np
 import time
 import argparse
+import logging
+logging.basicConfig(level=logging.INFO)
 
 # our lib imports here! It's ok to append path in examples
 import sys
@@ -14,7 +16,7 @@ from tianshou.core import losses
 import tianshou.data.advantage_estimation as advantage_estimation
 import tianshou.core.policy.stochastic as policy
 
-from tianshou.data.data_buffer.vanilla import VanillaReplayBuffer
+from tianshou.data.data_buffer.batch_set import BatchSet
 from tianshou.data.data_collector import DataCollector
 
 
@@ -62,7 +64,15 @@ if __name__ == '__main__':
     train_op = optimizer.minimize(total_loss, var_list=pi.trainable_variables)
 
     ### 3. define data collection
-    data_collector = Batch(env, pi, [advantage_estimation.full_return], [pi], render=args.render)
+    data_buffer = BatchSet()
+
+    data_collector = DataCollector(
+        env=env,
+        policy=pi,
+        data_buffer=data_buffer,
+        process_functions=[advantage_estimation.full_return],
+        managed_networks=[pi],
+    )
 
     ### 4. start training
     config = tf.ConfigProto()
@@ -80,7 +90,7 @@ if __name__ == '__main__':
 
         # print current return
         print('Epoch {}:'.format(i))
-        data_collector.statistics()
+        data_buffer.statistics()
 
         # update network
         for _ in range(num_batches):
diff --git a/tianshou/data/data_buffer/base.py b/tianshou/data/data_buffer/base.py
index bf377d9..db131f6 100644
--- a/tianshou/data/data_buffer/base.py
+++ b/tianshou/data/data_buffer/base.py
@@ -11,4 +11,16 @@ class DataBufferBase(object):
         raise NotImplementedError()
 
     def sample(self, batch_size):
-        raise NotImplementedError()
\ No newline at end of file
+        prob_episode = np.array(self.index_lengths) * 1. / self.size
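+        # two-stage sampling: pick an episode with probability proportional
+        # to its number of valid data points, then pick a point uniformly
+        # within it, so every valid frame in the buffer is equally likely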
+        num_episodes = len(self.index)
+        sampled_index = [[] for _ in range(num_episodes)]
+
+        for _ in range(batch_size):
+            # sample which episode
+            sampled_episode_i = int(np.random.choice(num_episodes, p=prob_episode))
+
+            # sample which data point within the sampled episode
+            sampled_frame_i = int(np.random.randint(self.index_lengths[sampled_episode_i]))
+            sampled_index[sampled_episode_i].append(sampled_frame_i)
+
+        return sampled_index
diff --git a/tianshou/data/data_buffer/batch_set.py b/tianshou/data/data_buffer/batch_set.py
index 57810f2..3272c69 100644
--- a/tianshou/data/data_buffer/batch_set.py
+++ b/tianshou/data/data_buffer/batch_set.py
@@ -1,11 +1,21 @@
+import gc
+import numpy as np
+import logging
+
 from .base import DataBufferBase
 
+STATE = 0
+ACTION = 1
+REWARD = 2
+DONE = 3
 
 class BatchSet(DataBufferBase):
     """
     class for batched dataset as used in on-policy algos
     """
-    def __init__(self):
+    def __init__(self, nstep=None):
+        self.nstep = nstep or 1  # RL has to look ahead at least one timestep
+
         self.data = [[]]
         self.index = [[]]
         self.candidate_index = 0
@@ -17,8 +27,80 @@
     def add(self, frame):
         self.data[-1].append(frame)
 
+        has_enough_frames = len(self.data[-1]) > self.nstep
+        if frame[DONE]:  # episode terminates, all trailing frames become valid data points
+            trailing_index = list(range(self.candidate_index, len(self.data[-1])))
+            self.index[-1] += trailing_index
+            self.size += len(trailing_index)
+            self.index_lengths[-1] += len(trailing_index)
+
+            # prepare for the next episode
+            self.data.append([])
+            self.index.append([])
+            self.candidate_index = 0
+
+            self.index_lengths.append(0)
+
+        elif has_enough_frames:  # add one valid data point
+            self.index[-1].append(self.candidate_index)
+            self.candidate_index += 1
+            self.size += 1
+            self.index_lengths[-1] += 1
+
     def clear(self):
-        pass
+        del self.data
+        del self.index
+        del self.index_lengths
+
+        gc.collect()
+
+        self.data = [[]]
+        self.index = [[]]
+        self.candidate_index = 0
+        self.size = 0
+        self.index_lengths = [0]
 
     def sample(self, batch_size):
-        pass
+        # TODO: move unified properties and methods to base. But this depends on how to deal with nstep.
+
+        prob_episode = np.array(self.index_lengths) * 1. / self.size
+        num_episodes = len(self.index)
+        sampled_index = [[] for _ in range(num_episodes)]
+
+        for _ in range(batch_size):
+            # sample which episode
+            sampled_episode_i = int(np.random.choice(num_episodes, p=prob_episode))
+
+            # sample which data point within the sampled episode
+            sampled_frame_i = int(np.random.randint(self.index_lengths[sampled_episode_i]))
+            sampled_index[sampled_episode_i].append(sampled_frame_i)
+
+        return sampled_index
+
+    def statistics(self, discount_factor=0.99):
+        returns = []
+        undiscounted_returns = []
+
+        if len(self.data) == 1:
+            data = self.data
+            logging.warning('The first episode in BatchSet is still not finished. '
+                            'Logging its return anyway.')
+        else:
+            data = self.data[:-1]
+
+        for episode in data:
+            current_return = 0.
+            current_undiscounted_return = 0.
+            current_discount = 1.
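+            # accumulate the discounted return and the plain sum of rewards
+            # for this episode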
+            for frame in episode:
+                current_return += frame[REWARD] * current_discount
+                current_undiscounted_return += frame[REWARD]
+                current_discount *= discount_factor
+            returns.append(current_return)
+            undiscounted_returns.append(current_undiscounted_return)
+
+        mean_return = np.mean(returns)
+        mean_undiscounted_return = np.mean(undiscounted_returns)
+
+        logging.info('Mean return: {}'.format(mean_return))
+        logging.info('Mean undiscounted return: {}'.format(mean_undiscounted_return))
diff --git a/tianshou/data/data_buffer/vanilla.py b/tianshou/data/data_buffer/vanilla.py
index 40d88dd..5feb55e 100644
--- a/tianshou/data/data_buffer/vanilla.py
+++ b/tianshou/data/data_buffer/vanilla.py
@@ -8,7 +8,7 @@ ACTION = 1
 REWARD = 2
 DONE = 3
 
-# TODO: valid data points could be fewer than `nstep` timesteps
+# TODO: valid data points could be fewer than `nstep` timesteps. Check the prioritized experience replay paper!
 class VanillaReplayBuffer(ReplayBufferBase):
     """
     vanilla replay buffer as used in (Mnih, et al., 2015).
diff --git a/tianshou/data/data_collector.py b/tianshou/data/data_collector.py
index c3e1879..c887b78 100644
--- a/tianshou/data/data_collector.py
+++ b/tianshou/data/data_collector.py
@@ -1,10 +1,10 @@
 import numpy as np
 import logging
 import itertools
-import sys
 
 from .data_buffer.replay_buffer_base import ReplayBufferBase
 from .data_buffer.batch_set import BatchSet
+from .utils import internal_key_match
 
 class DataCollector(object):
     """
@@ -32,7 +32,7 @@
 
         self.current_observation = self.env.reset()
 
-    def collect(self, num_timesteps=1, num_episodes=0, my_feed_dict={}, auto_clear=True):
+    def collect(self, num_timesteps=0, num_episodes=0, my_feed_dict={}, auto_clear=True):
         assert sum([num_timesteps > 0, num_episodes > 0]) == 1,\
             "One and only one collection number specification permitted!"
 
@@ -76,26 +76,34 @@
         feed_dict = {}
         frame_key_map = {'observation': 0, 'action': 1, 'reward': 2, 'done_flag': 3}
         for key, placeholder in self.required_placeholders.items():
-            if key in frame_key_map.keys():  # access raw_data
-                frame_index = frame_key_map[key]
+            # check raw_data first
+            found, matched_key = internal_key_match(key, frame_key_map.keys())
+            if found:
+                frame_index = frame_key_map[matched_key]
                 flattened = []
                 for index_episode, data_episode in zip(sampled_index, self.data_buffer.data):
                     for i in index_episode:
                         flattened.append(data_episode[i][frame_index])
                 feed_dict[placeholder] = np.array(flattened)
-            elif key in self.data_batch.keys():  # access processed minibatch data
-                flattened = list(itertools.chain.from_iterable(self.data_batch[key]))
-                feed_dict[placeholder] = np.array(flattened)
-            elif key in self.data.keys():  # access processed full data
-                flattened = [0.] * batch_size  # float
-                i_in_batch = 0
-                for index_episode, data_episode in zip(sampled_index, self.data[key]):
-                    for i in index_episode:
-                        flattened[i_in_batch] = data_episode[i]
-                        i_in_batch += 1
-                feed_dict[placeholder] = np.array(flattened)
             else:
-                raise TypeError('Placeholder {} has no value to feed!'.format(str(placeholder.name)))
+                # then check processed minibatch data
+                found, matched_key = internal_key_match(key, self.data_batch.keys())
+                if found:
+                    flattened = list(itertools.chain.from_iterable(self.data_batch[matched_key]))
+                    feed_dict[placeholder] = np.array(flattened)
+                else:
+                    # finally check processed full data
+                    found, matched_key = internal_key_match(key, self.data.keys())
+                    if found:
+                        flattened = [0.] * batch_size  # float
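+                        # fill the preallocated flat list by walking the
+                        # sampled per-episode indices in order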
+                        i_in_batch = 0
+                        for index_episode, data_episode in zip(sampled_index, self.data[matched_key]):
+                            for i in index_episode:
+                                flattened[i_in_batch] = data_episode[i]
+                                i_in_batch += 1
+                        feed_dict[placeholder] = np.array(flattened)
+                    else:
+                        raise TypeError('Placeholder {} has no value to feed!'.format(str(placeholder.name)))
 
         auto_standardize = (standardize_advantage is None) and self.require_advantage
         if standardize_advantage or auto_standardize:
@@ -108,6 +116,3 @@
             feed_dict[self.required_placeholders['advantage']] = (advantage_value - advantage_mean) / advantage_std
 
         return feed_dict
-
-    def statistics(self):
-        pass
\ No newline at end of file
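
A quick sanity-check sketch of driving the new BatchSet by hand, appended as a usage note rather than as part of the patch. The toy frames and the two fixed-length episodes are made up; only the BatchSet API and the (observation, action, reward, done) frame layout indexed by the STATE/ACTION/REWARD/DONE constants come from the code above:

    import logging
    import numpy as np

    from tianshou.data.data_buffer.batch_set import BatchSet

    logging.basicConfig(level=logging.INFO)

    buffer = BatchSet(nstep=1)

    # two hand-made 10-step episodes of (observation, action, reward, done)
    for step in range(20):
        done = (step % 10 == 9)
        buffer.add((np.zeros(4), 0, 1., done))

    buffer.statistics()         # logs mean discounted and undiscounted returns
    sampled = buffer.sample(8)  # one list of frame indices per stored episode
    print(sampled)              # e.g. [[3, 7], [0, 2, 5, 9, 1, 4], []]
    buffer.clear()              # drop everything, ready for the next iteration

sample() weights episodes by their number of valid data points and returns one index list per episode, which is exactly the shape the DataCollector walks (zip over sampled_index and the buffer's per-episode data) when it builds feed_dict.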