From 5f979caf58b6a31f335e999dcbd1211dc668023a Mon Sep 17 00:00:00 2001 From: haoshengzou Date: Sun, 15 Apr 2018 17:41:43 +0800 Subject: [PATCH] finish all API docs, first version. --- tianshou/core/losses.py | 12 +- tianshou/core/opt.py | 2 +- tianshou/core/policy/deterministic.py | 4 +- tianshou/core/random.py | 13 ++ tianshou/core/utils.py | 6 +- tianshou/core/value_function/action_value.py | 12 +- tianshou/data/README.md | 26 ---- tianshou/data/advantage_estimation.py | 130 +++++++++++++----- tianshou/data/data_buffer/__init__.py | 8 +- tianshou/data/data_buffer/base.py | 34 +++-- tianshou/data/data_buffer/batch_set.py | 41 +++++- .../data/data_buffer/replay_buffer_base.py | 9 +- tianshou/data/data_buffer/vanilla.py | 44 +++--- tianshou/data/data_collector.py | 68 +++++++-- tianshou/data/tester.py | 25 +++- 15 files changed, 296 insertions(+), 138 deletions(-) delete mode 100644 tianshou/data/README.md diff --git a/tianshou/core/losses.py b/tianshou/core/losses.py index 5d4ff93..1c329de 100644 --- a/tianshou/core/losses.py +++ b/tianshou/core/losses.py @@ -4,9 +4,11 @@ import tensorflow as tf def ppo_clip(policy, clip_param): """ Builds the graph of clipped loss :math:`L^{CLIP}` as in the - `Link PPO paper `_, which is basically - :math:`-\min(r_t(\\theta)Aˆt, clip(r_t(\\theta), 1 - \epsilon, 1 + \epsilon)Aˆt)`. + `PPO paper `_, which is basically + :math:`-\min(r_t(\\theta)A_t, \mathrm{clip}(r_t(\\theta), 1 - \epsilon, 1 + \epsilon)A_t)`. We minimize the objective instead of maximizing, hence the leading negative sign. + It creates an action placeholder and an advantage placeholder and adds them into the ``managed_placeholders`` + of the ``policy``. :param policy: A :class:`tianshou.core.policy` to be optimized. :param clip_param: A float or Tensor of type float. The :math:`\epsilon` in the loss equation. @@ -29,8 +31,10 @@ def ppo_clip(policy, clip_param): def REINFORCE(policy): """ Builds the graph of the loss function as used in vanilla policy gradient algorithms, i.e., REINFORCE. - The loss is basically :math:`\log \pi(a|s) A^t`. + The loss is basically :math:`-\log \pi(a|s) A_t`. We minimize the objective instead of maximizing, hence the leading negative sign. + It creates an action placeholder and an advantage placeholder and adds them into the ``managed_placeholders`` + of the ``policy``. :param policy: A :class:`tianshou.core.policy` to be optimized. @@ -50,6 +54,8 @@ def REINFORCE(policy): def value_mse(value_function): """ Builds the graph of L2 loss on value functions for, e.g., training critics or DQN. + It creates a placeholder for the target value and adds it into the ``managed_placeholders`` + of the ``value_function``. :param value_function: A :class:`tianshou.core.value_function` to be optimized. diff --git a/tianshou/core/opt.py b/tianshou/core/opt.py index dd36cd4..96f263a 100644 --- a/tianshou/core/opt.py +++ b/tianshou/core/opt.py @@ -3,7 +3,7 @@ import tensorflow as tf def DPG(policy, action_value): """ - Constructs the gradient Tensor of `Link deterministic policy gradient `_. + Constructs the gradient Tensor of `deterministic policy gradient `_. :param policy: A :class:`tianshou.core.policy.Deterministic` to be optimized. :param action_value: A :class:`tianshou.core.value_function.ActionValue` to guide the optimization of `policy`.
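For illustration, the loss constructors documented above are intended to be used roughly as follows. This is a minimal sketch only: ``policy`` and ``critic`` are assumed to be already-built :class:`tianshou.core.policy` and :class:`tianshou.core.value_function` objects, and the import path simply mirrors the file being documented::

    import tensorflow as tf
    import tianshou.core.losses as losses

    # assumes `policy` and `critic` have been constructed elsewhere
    pi_loss = losses.ppo_clip(policy, clip_param=0.2)  # adds action/advantage placeholders to policy.managed_placeholders
    critic_loss = losses.value_mse(critic)             # adds a target-value placeholder to critic.managed_placeholders
    total_loss = pi_loss + critic_loss
    # variable scoping is omitted; minimizing over all trainable variables is only for illustration
    train_op = tf.train.AdamOptimizer(1e-4).minimize(total_loss)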
diff --git a/tianshou/core/policy/deterministic.py b/tianshou/core/policy/deterministic.py index 793fcac..5ea3683 100644 --- a/tianshou/core/policy/deterministic.py +++ b/tianshou/core/policy/deterministic.py @@ -146,7 +146,8 @@ class Deterministic(PolicyBase): """ sess = tf.get_default_session() - feed_dict = {self.observation_placeholder: observation} + feed_dict = {self.observation_placeholder: observation} + feed_dict.update(my_feed_dict) action = sess.run(self.action, feed_dict=feed_dict) return action @@ -164,7 +165,8 @@ class Deterministic(PolicyBase): """ sess = tf.get_default_session() - feed_dict = {self.observation_placeholder: observation} + feed_dict = {self.observation_placeholder: observation} + feed_dict.update(my_feed_dict) action = sess.run(self.action_old, feed_dict=feed_dict) return action \ No newline at end of file diff --git a/tianshou/core/random.py b/tianshou/core/random.py index 4a00914..5807670 100644 --- a/tianshou/core/random.py +++ b/tianshou/core/random.py @@ -66,6 +66,11 @@ class GaussianWhiteNoiseProcess(AnnealedGaussianProcess): self.size = size def sample(self): + """ + Draws one sample from the random process. + + :return: A numpy array. The drawn sample. + """ sample = np.random.normal(self.mu, self.current_sigma, self.size) self.n_steps += 1 return sample @@ -102,6 +107,11 @@ class OrnsteinUhlenbeckProcess(AnnealedGaussianProcess): self.reset_states() def sample(self): + """ + Draws one sample from the random process. + + :return: A numpy array. The drawn sample. + """ x = self.x_prev + self.theta * (self.mu - self.x_prev) * self.dt + \ self.current_sigma * np.sqrt(self.dt) * np.random.normal(size=self.size) self.x_prev = x @@ -109,4 +119,7 @@ class OrnsteinUhlenbeckProcess(AnnealedGaussianProcess): return x def reset_states(self): + """ + Resets ``self.x_prev`` to ``self.x0``. + """ self.x_prev = self.x0 if self.x0 is not None else np.zeros(self.size) diff --git a/tianshou/core/utils.py b/tianshou/core/utils.py index a8fcc9d..749e72f 100644 --- a/tianshou/core/utils.py +++ b/tianshou/core/utils.py @@ -17,11 +17,11 @@ def identify_dependent_variables(tensor, candidate_variables): def get_soft_update_op(update_fraction, including_nets, excluding_nets=None): """ Builds the graph op to softly update the "old net" of policies and value_functions, as suggested in - `Link DDPG `_. It updates the :class:`tf.Variable` s in the old net, + `DDPG `_. It updates the :class:`tf.Variable` s in the old net, :math:`\\theta'` with the :class:`tf.Variable` s in the current network, :math:`\\theta` as - :math:`\\theta' = \tau \\theta + (1 - \tau) \\theta'`. + :math:`\\theta' = \\tau \\theta + (1 - \\tau) \\theta'`. - :param update_fraction: A float in range :math:`[0, 1]`. Corresponding to the :math:`\tau` in the update equation. + :param update_fraction: A float in range :math:`[0, 1]`. Corresponding to the :math:`\\tau` in the update equation. :param including_nets: A list of policies and/or value_functions. All :class:`tf.Variable` s in these networks are included for update. Shared Variables are updated only once in case of layer sharing among the networks. :param excluding_nets: Optional. A list of policies and/or value_functions defaulting to ``None``.
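The soft update op documented above is meant to be run alongside the training step; a minimal sketch, assuming ``actor`` and ``critic`` were built with their "old nets" (e.g., ``has_old_net=True``) and that the returned op can be run directly in a session::

    from tianshou.core.utils import get_soft_update_op

    # tau = 0.001 as commonly used in DDPG; `actor` and `critic` are assumed pre-built
    soft_update_op = get_soft_update_op(1e-3, [actor, critic])
    sess.run(soft_update_op)  # run once after each gradient update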
diff --git a/tianshou/core/value_function/action_value.py b/tianshou/core/value_function/action_value.py index d28d0c8..599b1da 100644 --- a/tianshou/core/value_function/action_value.py +++ b/tianshou/core/value_function/action_value.py @@ -110,13 +110,13 @@ class ActionValue(ValueFunctionBase): class DQN(ValueFunctionBase): """ Class for the special action value function DQN. Instead of feeding s and a to the network to get a value, - DQN feeds s to the network and gets at the last layer Q(s, *) for all actions under this state. Still, as + DQN feeds s to the network and gets at the last layer Q(s, \*) for all actions under this state. Still, as :class:`ActionValue`, this class still builds the Q(s, a) value Tensor. It can only be used with discrete (and finite) action spaces. :param network_callable: A Python callable returning (action head, value head). When called it builds - the tf graph and returns a Tensor of Q(s, *) on the value head. - :param observation_placeholder: A :class:`tf.placeholder`. The observation placeholder for s in Q(s, *) + the tf graph and returns a Tensor of Q(s, \*) on the value head. + :param observation_placeholder: A :class:`tf.placeholder`. The observation placeholder for s in Q(s, \*) in the network graph. :param has_old_net: A bool defaulting to ``False``. If true this class will create another graph with another set of :class:`tf.Variable` s to be the "old net". The "old net" could be the target networks as in DQN @@ -218,12 +218,12 @@ class DQN(ValueFunctionBase): @property def value_tensor_all_actions(self): - """The Tensor for Q(s, *)""" + """The Tensor for Q(s, \*)""" return self._value_tensor_all_actions def eval_value_all_actions(self, observation, my_feed_dict={}): """ - Evaluate values Q(s, *) in minibatch using the current network. + Evaluate values Q(s, \*) in minibatch using the current network. :param observation: An array-like, of shape (batch_size,) + observation_shape. :param my_feed_dict: Optional. A dict defaulting to empty. @@ -236,7 +236,7 @@ class DQN(ValueFunctionBase): def eval_value_all_actions_old(self, observation, my_feed_dict={}): """ - Evaluate values Q(s, *) in minibatch using the old net. + Evaluate values Q(s, \*) in minibatch using the old net. :param observation: An array-like, of shape (batch_size,) + observation_shape. :param my_feed_dict: Optional. A dict defaulting to empty. diff --git a/tianshou/data/README.md b/tianshou/data/README.md deleted file mode 100644 index e9e6374..0000000 --- a/tianshou/data/README.md +++ /dev/null @@ -1,26 +0,0 @@ -# TODO: - -Notice that we will separate actor and critic, and batch will collect data for optimizing policy while replay will collect data for optimizing critic. - -# Batch - -YouQiaoben - -fix as stated in ppo_example.py - - - -# Replay - -ShihongSong - -a Replay.py file. must have collect() and next_batch() methods for training. - -integrate previous ReplayBuffer codes. - - -# adv_estimate - -YouQiaoben (gae_lambda), ShihongSong(dqn after policy.DQN) - -seems to be direct python functions. also may write it in a functional form. \ No newline at end of file diff --git a/tianshou/data/advantage_estimation.py b/tianshou/data/advantage_estimation.py index 010d68c..5818a90 100644 --- a/tianshou/data/advantage_estimation.py +++ b/tianshou/data/advantage_estimation.py @@ -10,11 +10,17 @@ DONE = 3 # TODO: add discount_factor... maybe make it to be a global config? 
def full_return(buffer, indexes=None): """ - naively compute full return - :param buffer: buffer with property index and data. index determines the current content in `buffer`. - :param indexes: (sampled) index to be computed. Defaults to all the data in `buffer`. Not necessarily in order within - each episode. - :return: dict with key 'return' and value the computed returns corresponding to `index`. + Naively compute full undiscounted return on episodic data, :math:`G_t = \sum_{t'=t}^{T} r_{t'}`. + This function will print a warning when some of the episodes + in ``buffer`` have not yet terminated. + + :param buffer: A :class:`tianshou.data.data_buffer`. + :param indexes: Optional. Indexes of data points on which the full return should be computed. + If not set, it defaults to all the data points in ``buffer``. + Note that if it's the index of a sampled minibatch, it doesn't have to be in order within + each episode. + + :return: A dict with key 'return' and value the computed returns corresponding to ``indexes``. """ indexes = indexes or buffer.index raw_data = buffer.data @@ -46,27 +52,20 @@ def full_return(buffer, indexes=None): return {'return': returns} -class gae_lambda: - """ - Generalized Advantage Estimation (Schulman, 15) to compute advantage - """ - def __init__(self, T, value_function): - self.T = T - self.value_function = value_function - - def __call__(self, buffer, index=None): - """ - :param buffer: buffer with property index and data. index determines the current content in `buffer`. - :param index: (sampled) index to be computed. Defaults to all the data in `buffer`. Not necessarily in order within - each episode. - :return: dict with key 'advantage' and value the computed advantages corresponding to `index`. - """ - pass - - class nstep_return: """ - compute the n-step return from n-step rewards and bootstrapped value function + Compute the n-step return from n-step rewards and bootstrapped state value function V(s), + :math:`G_t = r_t + \gamma r_{t+1} + ... + \gamma^{n-1} r_{t+n-1} + \gamma^n V(s_{t+n})`. + + :param n: An int. The number of steps to look ahead, where :math:`n=1` will directly apply V(s) to + the next observation, as in the above equation. + :param value_function: A :class:`tianshou.core.value_function.StateValue`. The V(s) as in the + above equation. + :param return_advantage: Optional. A bool defaulting to ``False``. If ``True`` then this callable + also returns the advantage function + :math:`A(s_t) = r_t + \gamma r_{t+1} + ... + \gamma^{n-1} r_{t+n-1} + \gamma^n V(s_{t+n}) - V(s_t)` when called. + :param discount_factor: Optional. A float in range :math:`[0, 1]` defaulting to 0.99. The discount + factor :math:`\gamma` as in the above equation. """ def __init__(self, n, value_function, return_advantage=False, discount_factor=0.99): self.n = n @@ -76,10 +75,15 @@ class nstep_return: def __call__(self, buffer, indexes=None): """ - :param buffer: buffer with property index and data. index determines the current content in `buffer`. - :param indexes: (sampled) index to be computed. Defaults to all the data in `buffer`. Not necessarily in order within - each episode. - :return: dict with key 'return' and value the computed returns corresponding to `index`. + :param buffer: A :class:`tianshou.data.data_buffer`. + :param indexes: Optional. Indexes of data points on which the specified return should be computed. + If not set, it defaults to all the data points in ``buffer``.
+ Note that if it's the index of a sampled minibatch, it doesn't have to be in order within + each episode. + + :return: A dict with key 'return' and value the computed returns corresponding to ``indexes``. + If ``return_advantage`` is set to ``True``, the dict also has a key 'advantage' whose value is the corresponding + advantages. """ indexes = indexes or buffer.index episodes = buffer.data @@ -125,7 +129,16 @@ class nstep_return: class ddpg_return: """ - compute the return as in DDPG. this seems to have to be special + Compute the return as in `DDPG `_, + :math:`G_t = r_t + \gamma Q'(s_{t+1}, \mu'(s_{t+1}))`, where :math:`Q'` and :math:`\mu'` are the + target networks. + + :param actor: A :class:`tianshou.core.policy.Deterministic`. A deterministic policy. + :param critic: A :class:`tianshou.core.value_function.ActionValue`. An action value function Q(s, a). + :param use_target_network: Optional. A bool defaulting to ``True``. Whether to use the target networks + in the above equation. + :param discount_factor: Optional. A float in range :math:`[0, 1]` defaulting to 0.99. The discount + factor :math:`\gamma` as in the above equation. """ def __init__(self, actor, critic, use_target_network=True, discount_factor=0.99): self.actor = actor @@ -135,10 +148,13 @@ class ddpg_return: def __call__(self, buffer, indexes=None): """ - :param buffer: buffer with property index and data. index determines the current content in `buffer`. - :param index: (sampled) index to be computed. Defaults to all the data in `buffer`. Not necessarily in order within - each episode. - :return: dict with key 'return' and value the computed returns corresponding to `index`. + :param buffer: A :class:`tianshou.data.data_buffer`. + :param indexes: Optional. Indexes of data points on which the specified return should be computed. + If not set, it defaults to all the data points in ``buffer``. + Note that if it's the index of a sampled minibatch, it doesn't have to be in order within + each episode. + + :return: A dict with key 'return' and value the computed returns corresponding to ``indexes``. """ indexes = indexes or buffer.index episodes = buffer.data @@ -175,7 +191,17 @@ class ddpg_return: class nstep_q_return: """ - compute the n-step return for Q-learning targets + Compute the n-step return for Q-learning targets, + :math:`G_t = r_t + \gamma r_{t+1} + ... + \gamma^{n-1} r_{t+n-1} + \gamma^n \max_a Q'(s_{t+n}, a)`. + + :param n: An int. The number of steps to look ahead, where :math:`n=1` will directly apply :math:`Q'(s, \*)` to + the next observation, i.e., :math:`G_t = r_t + \gamma \max_a Q'(s_{t+1}, a)`. + :param action_value: A :class:`tianshou.core.value_function.DQN`. The :math:`Q'(s, \*)` as in the + above equation. + :param use_target_network: Optional. A bool defaulting to ``True``. Whether to use the target networks + in the above equation. + :param discount_factor: Optional. A float in range :math:`[0, 1]` defaulting to 0.99. The discount + factor :math:`\gamma` as in the above equation. """ def __init__(self, n, action_value, use_target_network=True, discount_factor=0.99): self.n = n @@ -185,10 +211,13 @@ class nstep_q_return: def __call__(self, buffer, indexes=None): """ - :param buffer: buffer with property index and data. index determines the current content in `buffer`. - :param index: (sampled) index to be computed. Defaults to all the data in `buffer`. Not necessarily in order within - each episode. - :return: dict with key 'return' and value the computed returns corresponding to `index`. + :param buffer: A :class:`tianshou.data.data_buffer`. + :param indexes: Optional.
Indexes of data points on which the specified return should be computed. + If not set, it defaults to all the data points in ``buffer``. + Note that if it's the index of a sampled minibatch, it doesn't have to be in order within + each episode. + + :return: A dict with key 'return' and value the computed returns corresponding to ``indexes``. """ indexes = indexes or buffer.index episodes = buffer.data @@ -224,3 +253,28 @@ class nstep_q_return: returns.append([]) return {'return': returns} + + +class gae_lambda: + """ + Generalized Advantage Estimation (Schulman, 15) to compute advantage. To be implemented. + """ + def __init__(self, T, value_function): + self.T = T + self.value_function = value_function + + raise NotImplementedError() + + def __call__(self, buffer, indexes=None): + """ + To be implemented. + + :param buffer: A :class:`tianshou.data.data_buffer`. + :param indexes: Optional. Indexes of data points on which the advantages should be computed. + If not set, it defaults to all the data points in ``buffer``. + Note that if it's the index of a sampled minibatch, it doesn't have to be in order within + each episode. + + :return: A dict with key 'advantage' and value the computed advantages corresponding to ``indexes``. + """ + raise NotImplementedError() diff --git a/tianshou/data/data_buffer/__init__.py b/tianshou/data/data_buffer/__init__.py index 0deb77e..36eccd3 100644 --- a/tianshou/data/data_buffer/__init__.py +++ b/tianshou/data/data_buffer/__init__.py @@ -1,4 +1,4 @@ -from os.path import dirname, basename, isfile -import glob -modules = glob.glob(dirname(__file__)+"/*.py") -__all__ = [ basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py')] +# from os.path import dirname, basename, isfile +# import glob +# modules = glob.glob(dirname(__file__)+"/*.py") +# __all__ = [ basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py')] diff --git a/tianshou/data/data_buffer/base.py b/tianshou/data/data_buffer/base.py index db131f6..49c3ddc 100644 --- a/tianshou/data/data_buffer/base.py +++ b/tianshou/data/data_buffer/base.py @@ -2,25 +2,31 @@ class DataBufferBase(object): """ - base class for data buffer, including replay buffer as in DQN and batched dataset as in on-policy algos + Base class for data buffer, including replay buffer as used by DQN + and batched dataset as used by on-policy algorithms. + Our data buffer adopts a memory-efficient implementation where raw data are always stored in a + sequential manner, and an additional set of indexes is used to indicate the valid data points + in the data buffer. + + The raw data and indexes are both organized in a two-level architecture as lists of lists, where + the high-level lists correspond to episodes and low-level lists correspond to the data within + each episode. + + Mandatory methods for a data buffer class are: + + - :func:`add`. It adds one timestep of data to the data buffer. + + - :func:`clear`. It empties the data buffer. + + - :func:`sample`. It samples one minibatch of data and returns the index of the sampled data\ + points, not the raw data. """ def add(self, frame): raise NotImplementedError() def clear(self): + """Empties the data buffer, usually used in batch set but not in replay buffer.""" raise NotImplementedError() def sample(self, batch_size): - prob_episode = np.array(self.index_lengths) * 1.
/ self.size - num_episodes = len(self.index) - sampled_index = [[] for _ in range(num_episodes)] - - for _ in range(batch_size): - # sample which episode - sampled_episode_i = int(np.random.choice(num_episodes, p=prob_episode)) - - # sample which data point within the sampled episode - sampled_frame_i = int(np.random.randint(self.index_lengths[sampled_episode_i])) - sampled_index[sampled_episode_i].append(sampled_frame_i) - - return sampled_index + raise NotImplementedError() diff --git a/tianshou/data/data_buffer/batch_set.py b/tianshou/data/data_buffer/batch_set.py index a3e43ac..b4d8044 100644 --- a/tianshou/data/data_buffer/batch_set.py +++ b/tianshou/data/data_buffer/batch_set.py @@ -9,12 +9,19 @@ ACTION = 1 REWARD = 2 DONE = 3 + class BatchSet(DataBufferBase): """ - class for batched dataset as used in on-policy algos + Class for batched dataset as used in on-policy algorithms, where a batch of data is first collected + with the current policy, several optimization steps are then conducted on this batch of data, and the + data are then discarded and collected again. + + :param nstep: An int defaulting to 1. The number of timesteps to look ahead for temporal difference computation. + Only continuous data pieces longer than this number or already terminated ones are + considered valid data points. """ - def __init__(self, nstep=None): - self.nstep = nstep or 1 # RL has to look ahead at least one timestep + def __init__(self, nstep=1): - self.nstep = nstep # RL has to look ahead at least one timestep self.data = [[]] self.index = [[]] @@ -25,6 +32,11 @@ class BatchSet(DataBufferBase): self.index_lengths = [0] # for sampling def add(self, frame): + """ + Adds one frame of data to the buffer. + + :param frame: A tuple of (observation, action, reward, done_flag). + """ self.data[-1].append(frame) has_enough_frames = len(self.data[-1]) > self.nstep @@ -48,6 +60,9 @@ class BatchSet(DataBufferBase): self.index_lengths[-1] += 1 def clear(self): + """ + Empties the data buffer and prepares to collect a new batch of data. + """ del self.data del self.index del self.index_lengths @@ -61,7 +76,17 @@ class BatchSet(DataBufferBase): self.index_lengths = [0] def sample(self, batch_size): - # TODO: move unified properties and methods to base. but this depends on how to deal with nstep + """ + Performs uniform random sampling on ``self.index``. For simplicity, we do random sampling with replacement + for now with time O(``batch_size``). The fastest sampling without replacement seems to require time + O(``batch_size`` * log(num_episodes)). + + :param batch_size: An int. The size of the minibatch. + + :return: A list of lists of the sampled indexes. Episodes without sampled data points + correspond to empty sub-lists. + """ + # TODO: move unified properties and methods to base. but this may depend on how to deal with nstep prob_episode = np.array(self.index_lengths) * 1. / self.size num_episodes = len(self.index) @@ -78,6 +103,14 @@ class BatchSet(DataBufferBase): return sampled_index def statistics(self, discount_factor=0.99): + """ + Computes and prints out the statistics (e.g., discounted returns, undiscounted returns) in the batch set. + This is useful when policies are optimized by on-policy algorithms, so the current data in + the batch set directly reflect the performance of the current policy. + + :param discount_factor: Optional. A float in range :math:`[0, 1]` defaulting to 0.99. The discount + factor to compute discounted returns.
+ """ returns = [] undiscounted_returns = [] diff --git a/tianshou/data/data_buffer/replay_buffer_base.py b/tianshou/data/data_buffer/replay_buffer_base.py index dd437ae..c0539be 100644 --- a/tianshou/data/data_buffer/replay_buffer_base.py +++ b/tianshou/data/data_buffer/replay_buffer_base.py @@ -2,11 +2,10 @@ from .base import DataBufferBase class ReplayBufferBase(DataBufferBase): """ - base class for replay buffer. + Base class for replay buffer. + Compared to :class:`DataBufferBase`, it has an additional method :func:`remove`, + which removes extra data points when the size of the data buffer exceeds capacity. + Besides, as the practice of using such replay buffer, it's never :func:`clear` ed. """ def remove(self): - """ - when size exceeds capacity, removes extra data points - :return: - """ raise NotImplementedError() diff --git a/tianshou/data/data_buffer/vanilla.py b/tianshou/data/data_buffer/vanilla.py index 5feb55e..005327a 100644 --- a/tianshou/data/data_buffer/vanilla.py +++ b/tianshou/data/data_buffer/vanilla.py @@ -8,18 +8,21 @@ ACTION = 1 REWARD = 2 DONE = 3 + # TODO: valid data points could be less than `nstep` timesteps. Check priority replay paper! class VanillaReplayBuffer(ReplayBufferBase): """ - vanilla replay buffer as used in (Mnih, et al., 2015). - Frames are always continuous in temporal order. They are only removed from the beginning. This continuity - in `self.data` could be exploited, but only in vanilla replay buffer. + Class for vanilla replay buffer as used in (Mnih, et al., 2015). + Frames are always continuous in temporal order. They are only removed from the beginning + and added at the tail. + This continuity in `self.data` could be exploited only in vanilla replay buffer. + + :param capacity: An int. The capacity of the buffer. + :param nstep: An int defaulting to 1. The number of timesteps to lookahead for temporal difference computation. + Only continuous data pieces longer than this number or already terminated ones are + considered valid data points. """ def __init__(self, capacity, nstep=1): - """ - :param capacity: int. capacity of the buffer. - :param nstep: int. number of timesteps to lookahead for temporal difference - """ assert capacity > 0 self.capacity = int(capacity) self.nstep = nstep @@ -34,8 +37,9 @@ class VanillaReplayBuffer(ReplayBufferBase): def add(self, frame): """ - add one frame to the buffer. - :param frame: tuple, (observation, action, reward, done_flag). + Adds one frame of data to the buffer. + + :param frame: A tuple of (observation, action, reward, done_flag). """ self.data[-1].append(frame) @@ -65,17 +69,17 @@ class VanillaReplayBuffer(ReplayBufferBase): def remove(self): """ - remove data until `self.size` <= `self.capacity` + Removes data from the buffer until ``self.size <= self.capacity``. """ if self.size: while self.size > self.capacity: - self.remove_oldest() + self._remove_oldest() else: logging.warning('Attempting to remove from empty buffer!') - def remove_oldest(self): + def _remove_oldest(self): """ - remove the oldest data point, in this case, just the oldest frame. Empty episodes are also removed + Removes the oldest data point, in this case, just the oldest frame. Empty episodes are also removed if resulted from removal. """ self.index[0].pop() # note that all index of frames in the first episode are shifted forward by 1 @@ -98,12 +102,14 @@ class VanillaReplayBuffer(ReplayBufferBase): def sample(self, batch_size): """ - uniform random sampling on `self.index`. 
For simplicity, we do random sampling with replacement - for now with time O(`batch_size`). Fastest sampling without replacement seems to have to be of time - O(`batch_size` * log(num_episodes)). - :param batch_size: int. - :return: sampled index, same structure as `self.index`. Episodes without sampled data points - correspond to empty sub-lists. + Performs uniform random sampling on ``self.index``. For simplicity, we do random sampling with replacement + for now with time O(``batch_size``). The fastest sampling without replacement seems to require time + O(``batch_size`` * log(num_episodes)). + + :param batch_size: An int. The size of the minibatch. + + :return: A list of lists of the sampled indexes, with the same structure as ``self.index``. Episodes without sampled data points + correspond to empty sub-lists. """ prob_episode = np.array(self.index_lengths) * 1. / self.size num_episodes = len(self.index) diff --git a/tianshou/data/data_collector.py b/tianshou/data/data_collector.py index 6469e6a..c70b3e5 100644 --- a/tianshou/data/data_collector.py +++ b/tianshou/data/data_collector.py @@ -5,18 +5,24 @@ import itertools from .data_buffer.replay_buffer_base import ReplayBufferBase from .data_buffer.batch_set import BatchSet from .utils import internal_key_match +from ..core.policy.deterministic import Deterministic class DataCollector(object): """ A utility class to manage the data flow during the interaction between the policy and the environment. - It stores data into ``data_buffer``, processes the reward signals and returns the feed_dict for tf graph running. + It stores data into ``data_buffer``, processes the reward signals and returns the feed_dict for + tf graph running. - :param env: - :param policy: - :param data_buffer: - :param process_functions: - :param managed_networks: + :param env: An environment. + :param policy: A :class:`tianshou.core.policy`. + :param data_buffer: A :class:`tianshou.data.data_buffer`. + :param process_functions: A list of callables in :mod:`tianshou.data.advantage_estimation` + to process rewards. + :param managed_networks: A list of networks of :class:`tianshou.core.policy` and/or + :class:`tianshou.core.value_function`. The networks you want this class to manage. This class + will automatically generate the feed_dict for all the placeholders in the ``managed_placeholders`` + of all networks in this list. """ def __init__(self, env, policy, data_buffer, process_functions, managed_networks): self.env = env @@ -42,6 +48,23 @@ class DataCollector(object): self.step_count_this_episode = 0 def collect(self, num_timesteps=0, num_episodes=0, my_feed_dict={}, auto_clear=True, episode_cutoff=None): + """ + Collects data in the environment using ``self.policy``. + + :param num_timesteps: An int specifying the number of timesteps to act. It defaults to 0. One and only one + of ``num_timesteps`` and ``num_episodes`` must be set. + :param num_episodes: An int specifying the number of episodes to act. It defaults to 0. One and only one + of ``num_timesteps`` and ``num_episodes`` must be set. + :param my_feed_dict: Optional. A dict defaulting to empty. + Specifies values for placeholders other than observation and action, such as those for dropout and batch_norm. + :param auto_clear: Optional. A bool defaulting to ``True``. If ``True``, this method clears + ``self.data_buffer`` when it is an instance of + :class:`tianshou.data.data_buffer.BatchSet`, and does nothing if it is not. + If set to ``False``, this auto-clearing behavior is disabled.
+ :param episode_cutoff: Optional. An int. The maximum number of timesteps in one episode. This is + useful when the environment has no terminal states or a single episode could be prohibitively long. + If set, all episodes are forced to stop after this number of timesteps. + """ assert sum([num_timesteps > 0, num_episodes > 0]) == 1,\ "One and only one collection number specification permitted!" @@ -89,7 +112,20 @@ class DataCollector(object): for processor in self.process_functions: self.data.update(processor(self.data_buffer)) - def next_batch(self, batch_size, standardize_advantage=None): + return + + def next_batch(self, batch_size, standardize_advantage=True): + """ + Constructs and returns the feed_dict of data to be used with ``sess.run``. + + :param batch_size: An int. The size of one minibatch. + :param standardize_advantage: Optional. A bool defaulting to ``True``. + If ``True``, this method standardizes advantages if advantage is required by the networks. + If ``False``, this method will never standardize advantages. + + :return: A dict in the format of a conventional feed_dict in tf, with placeholders as keys and + numpy arrays as values. + """ sampled_index = self.data_buffer.sample(batch_size) if self.process_mode == 'sample': for processor in self.process_functions: @@ -128,8 +164,7 @@ class DataCollector(object): else: raise TypeError('Placeholder {} has no value to feed!'.format(str(placeholder.name))) - auto_standardize = (standardize_advantage is None) and self.require_advantage - if standardize_advantage or auto_standardize: + if standardize_advantage: if self.require_advantage: advantage_value = feed_dict[self.required_placeholders['advantage']] advantage_mean = np.mean(advantage_value) @@ -140,10 +175,21 @@ class DataCollector(object): return feed_dict - def denoise_action(self, feed_dict): + def denoise_action(self, feed_dict, my_feed_dict={}): + """ + Recomputes the actions of deterministic policies without exploration noise, hence denoising. + It modifies ``feed_dict`` **in place** and has no return value. + This is useful in, e.g., DDPG since the stored action in ``self.data_buffer`` is the sampled + action with additional exploration noise. + :param feed_dict: A dict. It has to be the dict returned by :func:`next_batch` of this class. + :param my_feed_dict: Optional. A dict defaulting to empty. + Specifies values for placeholders other than observation and action, such as those for dropout and batch_norm. + """ + assert isinstance(self.policy, Deterministic), 'denoise_action() could only be called ' \ 'with deterministic policies' observation = feed_dict[self.required_placeholders['observation']] - action_mean = self.policy.eval_action(observation) + action_mean = self.policy.eval_action(observation, my_feed_dict) feed_dict[self.required_placeholders['action']] = action_mean return diff --git a/tianshou/data/tester.py b/tianshou/data/tester.py index 82c7ee2..46fea1b 100644 --- a/tianshou/data/tester.py +++ b/tianshou/data/tester.py @@ -5,10 +5,31 @@ import logging import numpy as np -def test_policy_in_env(policy, env, num_timesteps=0, num_episodes=0, discount_factor=0.99, seed=0, episode_cutoff=None): +def test_policy_in_env(policy, env, num_timesteps=0, num_episodes=0, + discount_factor=0.99, seed=0, episode_cutoff=None): + """ + Tests the policy in the environment, and records and prints out the performance.
This is useful when the policy + is trained with off-policy algorithms and thus the rewards in the data buffer do not reflect the + performance of the current policy. + :param policy: A :class:`tianshou.core.policy`. The current policy being optimized. + :param env: An environment. + :param num_timesteps: An int specifying the number of timesteps to test the policy. + It defaults to 0. One and only one + of ``num_timesteps`` and ``num_episodes`` must be set. + :param num_episodes: An int specifying the number of episodes to test the policy. + It defaults to 0. One and only one + of ``num_timesteps`` and ``num_episodes`` must be set. + :param discount_factor: Optional. A float in range :math:`[0, 1]` defaulting to 0.99. The discount + factor to compute discounted returns. + :param seed: A non-negative int. The seed to seed the environment as ``env.seed(seed)``. + :param episode_cutoff: Optional. An int. The maximum number of timesteps in one episode. This is + useful when the environment has no terminal states or a single episode could be prohibitively long. + If set, all episodes are forced to stop after this number of timesteps. + """ assert sum([num_episodes > 0, num_timesteps > 0]) == 1, \ 'One and only one collection number specification permitted!' + assert seed >= 0 # make another env as the original is for training data collection env_id = env.spec.id @@ -82,4 +103,4 @@ def test_policy_in_env(policy, env, num_timesteps=0, num_episodes=0, discount_fa print('Mean undiscounted return: {}'.format(mean_undiscounted_return)) # clear scene - env_.close() \ No newline at end of file + env_.close()
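Putting the documented pieces together, a DQN-style training loop would look roughly like the sketch below. It only uses the constructors and methods documented in this patch; ``env`` is assumed to be a gym environment, and ``pi``/``dqn`` are assumed to be a pre-built policy and :class:`tianshou.core.value_function.DQN`, whose network definitions are outside the scope of this patch::

    import tensorflow as tf
    import tianshou.core.losses as losses
    import tianshou.data.advantage_estimation as advantage_estimation
    from tianshou.data.data_buffer.vanilla import VanillaReplayBuffer
    from tianshou.data.data_collector import DataCollector
    from tianshou.data.tester import test_policy_in_env

    # replay buffer and reward-processing function as documented above
    buffer = VanillaReplayBuffer(capacity=100000, nstep=1)
    process_functions = [advantage_estimation.nstep_q_return(1, dqn, use_target_network=True)]
    collector = DataCollector(env, pi, buffer, process_functions, managed_networks=[dqn])

    dqn_loss = losses.value_mse(dqn)  # adds the target-value placeholder to dqn.managed_placeholders
    train_op = tf.train.AdamOptimizer(1e-4).minimize(dqn_loss)  # variable scoping omitted for brevity

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        collector.collect(num_timesteps=1000)  # warm up the replay buffer
        for step in range(10000):
            collector.collect(num_timesteps=1, episode_cutoff=200)
            feed_dict = collector.next_batch(batch_size=32)  # feed_dict for all managed placeholders
            sess.run(train_op, feed_dict=feed_dict)
            if step % 1000 == 0:
                test_policy_in_env(pi, env, num_episodes=5, episode_cutoff=200)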