finish all API docs, first version.
This commit is contained in:
parent 8c108174b6
commit 5f979caf58
@@ -4,9 +4,11 @@ import tensorflow as tf
def ppo_clip(policy, clip_param):
"""
Builds the graph of clipped loss :math:`L^{CLIP}` as in the
- `Link PPO paper <https://arxiv.org/pdf/1707.06347.pdf>`_, which is basically
- :math:`-\min(r_t(\\theta)Aˆt, clip(r_t(\\theta), 1 - \epsilon, 1 + \epsilon)Aˆt)`.
+ `PPO paper <https://arxiv.org/pdf/1707.06347.pdf>`_, which is basically
+ :math:`-\min(r_t(\\theta)A_t, \mathrm{clip}(r_t(\\theta), 1 - \epsilon, 1 + \epsilon)A_t)`.
We minimize the objective instead of maximizing, hence the leading negative sign.
It creates an action placeholder and an advantage placeholder and adds them to the ``managed_placeholders``
of the ``policy``.

:param policy: A :class:`tianshou.core.policy` to be optimized.
:param clip_param: A float or Tensor of type float. The :math:`\epsilon` in the loss equation.
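For orientation, a minimal TF-1.x sketch of how such a clipped surrogate could be assembled follows. The placeholder wiring and the ``log_prob`` / ``log_prob_old`` / ``managed_placeholders`` attribute names are illustrative assumptions, not the confirmed tianshou API.

import tensorflow as tf

def ppo_clip_sketch(policy, clip_param):
    # Hypothetical placeholders; the real shapes and names in tianshou may differ.
    action_ph = tf.placeholder(tf.float32, shape=(None,), name='action')
    advantage_ph = tf.placeholder(tf.float32, shape=(None,), name='advantage')
    policy.managed_placeholders['action'] = action_ph
    policy.managed_placeholders['advantage'] = advantage_ph

    # r_t(theta): probability ratio between the current and the old policy.
    ratio = tf.exp(policy.log_prob(action_ph) - policy.log_prob_old(action_ph))

    # Clipped surrogate, negated because we minimize.
    surrogate = tf.minimum(ratio * advantage_ph,
                           tf.clip_by_value(ratio, 1. - clip_param, 1. + clip_param) * advantage_ph)
    return -tf.reduce_mean(surrogate)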
@@ -29,8 +31,10 @@ def ppo_clip(policy, clip_param):
def REINFORCE(policy):
"""
Builds the graph of the loss function as used in vanilla policy gradient algorithms, i.e., REINFORCE.
- The loss is basically :math:`\log \pi(a|s) A^t`.
+ The loss is basically :math:`\log \pi(a|s) A_t`.
We minimize the objective instead of maximizing, hence the leading negative sign.
It creates an action placeholder and an advantage placeholder and adds them to the ``managed_placeholders``
of the ``policy``.

:param policy: A :class:`tianshou.core.policy` to be optimized.
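The REINFORCE surrogate is simpler still; a hedged sketch under the same assumed attribute names:

import tensorflow as tf

def REINFORCE_sketch(policy):
    # Hypothetical placeholders; the real names in tianshou may differ.
    action_ph = tf.placeholder(tf.float32, shape=(None,), name='action')
    advantage_ph = tf.placeholder(tf.float32, shape=(None,), name='advantage')
    policy.managed_placeholders['action'] = action_ph
    policy.managed_placeholders['advantage'] = advantage_ph

    # -E[log pi(a|s) * A_t]; the advantage is treated as a constant w.r.t. the policy parameters.
    return -tf.reduce_mean(policy.log_prob(action_ph) * advantage_ph)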
@@ -50,6 +54,8 @@ def REINFORCE(policy):
def value_mse(value_function):
"""
Builds the graph of L2 loss on value functions for, e.g., training critics or DQN.
It creates a placeholder for the target value and adds it to the ``managed_placeholders``
of the ``value_function``.

:param value_function: A :class:`tianshou.core.value_function` to be optimized.
@@ -3,7 +3,7 @@ import tensorflow as tf

def DPG(policy, action_value):
"""
- Constructs the gradient Tensor of `Link deterministic policy gradient <https://arxiv.org/pdf/1509.02971.pdf>`_.
+ Constructs the gradient Tensor of `deterministic policy gradient <https://arxiv.org/pdf/1509.02971.pdf>`_.

:param policy: A :class:`tianshou.core.policy.Deterministic` to be optimized.
:param action_value: A :class:`tianshou.core.value_function.ActionValue` to guide the optimization of `policy`.
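The deterministic policy gradient chains the critic's gradient with respect to the action through the policy network. A hedged TF-1.x sketch, where ``policy.action``, ``policy.trainable_variables`` and ``action_value.value_tensor`` are assumed attribute names rather than the confirmed API:

import tensorflow as tf

def DPG_sketch(policy, action_value):
    # grad_a Q(s, a), evaluated at a = mu(s), i.e. at the policy's action Tensor.
    dq_da = tf.gradients(action_value.value_tensor, policy.action)[0]
    # Chain rule through mu(s): accumulate grad_theta mu(s) * grad_a Q(s, a).
    # The sign is flipped so the result is a descent direction for a minimizer.
    return tf.gradients(policy.action, policy.trainable_variables, grad_ys=-dq_da)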
@@ -146,7 +146,7 @@ class Deterministic(PolicyBase):
"""
sess = tf.get_default_session()

- feed_dict = {self.observation_placeholder: observation}
+ feed_dict = {self.observation_placeholder: observation}
+ feed_dict.update(my_feed_dict)
action = sess.run(self.action, feed_dict=feed_dict)

return action
@@ -164,7 +164,7 @@ class Deterministic(PolicyBase):
"""
sess = tf.get_default_session()

- feed_dict = {self.observation_placeholder: observation}
+ feed_dict = {self.observation_placeholder: observation}
+ feed_dict.update(my_feed_dict)
action = sess.run(self.action_old, feed_dict=feed_dict)

return action
@@ -66,6 +66,11 @@ class GaussianWhiteNoiseProcess(AnnealedGaussianProcess):
self.size = size

def sample(self):
+ """
+ Draws one sample from the random process.
+
+ :return: A numpy array. The drawn sample.
+ """
sample = np.random.normal(self.mu, self.current_sigma, self.size)
self.n_steps += 1
return sample
@@ -102,6 +107,11 @@ class OrnsteinUhlenbeckProcess(AnnealedGaussianProcess):
self.reset_states()

def sample(self):
+ """
+ Draws one sample from the random process.
+
+ :return: A numpy array. The drawn sample.
+ """
x = self.x_prev + self.theta * (self.mu - self.x_prev) * self.dt + \
self.current_sigma * np.sqrt(self.dt) * np.random.normal(size=self.size)
self.x_prev = x
@@ -109,4 +119,7 @@ class OrnsteinUhlenbeckProcess(AnnealedGaussianProcess):
return x

def reset_states(self):
+ """
+ Reset ``self.x_prev`` to be ``self.x0``.
+ """
self.x_prev = self.x0 if self.x0 is not None else np.zeros(self.size)
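As context for the update rule above, here is a small self-contained NumPy sketch of an Ornstein-Uhlenbeck noise generator; the constructor arguments mirror the attributes referenced by the code (``mu``, ``theta``, ``sigma``, ``dt``, ``x0``), but the exact tianshou signature is not shown in this diff and may differ.

import numpy as np

class OUNoiseSketch:
    """Minimal Ornstein-Uhlenbeck process: x += theta*(mu - x)*dt + sigma*sqrt(dt)*N(0, 1)."""
    def __init__(self, size, mu=0., theta=0.15, sigma=0.2, dt=1e-2, x0=None):
        self.size, self.mu, self.theta, self.sigma, self.dt, self.x0 = size, mu, theta, sigma, dt, x0
        self.reset_states()

    def sample(self):
        x = self.x_prev + self.theta * (self.mu - self.x_prev) * self.dt + \
            self.sigma * np.sqrt(self.dt) * np.random.normal(size=self.size)
        self.x_prev = x
        return x

    def reset_states(self):
        self.x_prev = self.x0 if self.x0 is not None else np.zeros(self.size)

# Typical use in DDPG-style exploration: add temporally correlated noise to a deterministic action.
noise = OUNoiseSketch(size=2)
noisy_action = np.array([0.1, -0.3]) + noise.sample()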
@@ -17,11 +17,11 @@ def identify_dependent_variables(tensor, candidate_variables):
def get_soft_update_op(update_fraction, including_nets, excluding_nets=None):
"""
Builds the graph op to softly update the "old net" of policies and value_functions, as suggested in
- `Link DDPG <https://arxiv.org/pdf/1509.02971.pdf>`_. It updates the :class:`tf.Variable` s in the old net,
+ `DDPG <https://arxiv.org/pdf/1509.02971.pdf>`_. It updates the :class:`tf.Variable` s in the old net,
:math:`\\theta'` with the :class:`tf.Variable` s in the current network, :math:`\\theta` as
- :math:`\\theta' = \tau \\theta + (1 - \tau) \\theta'`.
+ :math:`\\theta' = \\tau \\theta + (1 - \\tau) \\theta'`.

- :param update_fraction: A float in range :math:`[0, 1]`. Corresponding to the :math:`\tau` in the update equation.
+ :param update_fraction: A float in range :math:`[0, 1]`. Corresponding to the :math:`\\tau` in the update equation.
:param including_nets: A list of policies and/or value_functions. All :class:`tf.Variable` s in these networks
are included for update. Shared Variables are updated only once in case of layer sharing among the networks.
:param excluding_nets: Optional. A list of policies and/or value_functions defaulting to ``None``.
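A hedged sketch of what such a soft-update op amounts to in TF 1.x, assuming each network exposes matching lists of current and old variables under these hypothetical attribute names:

import tensorflow as tf

def soft_update_sketch(update_fraction, including_nets):
    # theta' <- tau * theta + (1 - tau) * theta'
    ops = []
    seen = set()  # update shared variables only once
    for net in including_nets:
        # `trainable_variables` / `trainable_variables_old` are assumed attribute names.
        for var, var_old in zip(net.trainable_variables, net.trainable_variables_old):
            if var_old.name in seen:
                continue
            seen.add(var_old.name)
            ops.append(tf.assign(var_old, update_fraction * var + (1 - update_fraction) * var_old))
    return tf.group(*ops)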
@@ -110,13 +110,13 @@ class ActionValue(ValueFunctionBase):
class DQN(ValueFunctionBase):
"""
Class for the special action value function DQN. Instead of feeding s and a to the network to get a value,
- DQN feeds s to the network and gets at the last layer Q(s, *) for all actions under this state. Still, as
+ DQN feeds s to the network and gets at the last layer Q(s, \*) for all actions under this state. Still, as
an :class:`ActionValue`, this class builds the Q(s, a) value Tensor. It can only be used with discrete
(and finite) action spaces.

:param network_callable: A Python callable returning (action head, value head). When called it builds
- the tf graph and returns a Tensor of Q(s, *) on the value head.
- :param observation_placeholder: A :class:`tf.placeholder`. The observation placeholder for s in Q(s, *)
+ the tf graph and returns a Tensor of Q(s, \*) on the value head.
+ :param observation_placeholder: A :class:`tf.placeholder`. The observation placeholder for s in Q(s, \*)
in the network graph.
:param has_old_net: A bool defaulting to ``False``. If ``True``, this class will create another graph with another
set of :class:`tf.Variable` s to be the "old net". The "old net" could be the target networks as in DQN
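To make the relationship between Q(s, \*) and Q(s, a) concrete, here is a hedged TF-1.x sketch of how the per-action value can be gathered from the all-actions head; the placeholder names and shapes are illustrative only:

import tensorflow as tf

# value_all_actions: Tensor of shape (batch_size, num_actions), standing in for the Q(s, *) head.
# action_ph: int placeholder of shape (batch_size,), the chosen discrete actions.
value_all_actions = tf.placeholder(tf.float32, shape=(None, 4))
action_ph = tf.placeholder(tf.int32, shape=(None,))

# Pick Q(s, a) out of Q(s, *) row by row.
batch_indices = tf.range(tf.shape(action_ph)[0])
q_sa = tf.gather_nd(value_all_actions, tf.stack([batch_indices, action_ph], axis=1))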
@@ -218,12 +218,12 @@ class DQN(ValueFunctionBase):

@property
def value_tensor_all_actions(self):
- """The Tensor for Q(s, *)"""
+ """The Tensor for Q(s, \*)"""
return self._value_tensor_all_actions

def eval_value_all_actions(self, observation, my_feed_dict={}):
"""
- Evaluate values Q(s, *) in minibatch using the current network.
+ Evaluate values Q(s, \*) in minibatch using the current network.

:param observation: An array-like, of shape (batch_size,) + observation_shape.
:param my_feed_dict: Optional. A dict defaulting to empty.
@@ -236,7 +236,7 @@ class DQN(ValueFunctionBase):

def eval_value_all_actions_old(self, observation, my_feed_dict={}):
"""
- Evaluate values Q(s, *) in minibatch using the old net.
+ Evaluate values Q(s, \*) in minibatch using the old net.

:param observation: An array-like, of shape (batch_size,) + observation_shape.
:param my_feed_dict: Optional. A dict defaulting to empty.
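A brief NumPy illustration of what is typically done with the Q(s, \*) minibatch returned by these evaluation methods, for example when forming one-step Q-learning targets; the array below simply stands in for the method's return value:

import numpy as np

# Stand-in for the Q'(s, *) minibatch returned by eval_value_all_actions_old: shape (batch_size, num_actions).
q_all_old = np.array([[1.0, 2.5, 0.3],
                      [0.2, 0.1, 0.9]])

# Greedy backup value max_a Q'(s, a) per state, as used in targets r + gamma * max_a Q'(s', a).
backup = q_all_old.max(axis=1)             # -> [2.5, 0.9]
greedy_actions = q_all_old.argmax(axis=1)  # -> [1, 2]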
@@ -1,26 +0,0 @@
- # TODO:
-
- Notice that we will separate actor and critic, and batch will collect data for optimizing policy while replay will collect data for optimizing critic.
-
- # Batch
-
- YouQiaoben
-
- fix as stated in ppo_example.py
-
-
-
- # Replay
-
- ShihongSong
-
- a Replay.py file. must have collect() and next_batch() methods for training.
-
- integrate previous ReplayBuffer codes.
-
-
- # adv_estimate
-
- YouQiaoben (gae_lambda), ShihongSong(dqn after policy.DQN)
-
- seems to be direct python functions. also may write it in a functional form.
@@ -10,11 +10,17 @@ DONE = 3
# TODO: add discount_factor... maybe make it to be a global config?
def full_return(buffer, indexes=None):
"""
- naively compute full return
- :param buffer: buffer with property index and data. index determines the current content in `buffer`.
- :param indexes: (sampled) index to be computed. Defaults to all the data in `buffer`. Not necessarily in order within
- each episode.
- :return: dict with key 'return' and value the computed returns corresponding to `index`.
+ Naively compute the full undiscounted return on episodic data, :math:`G_t = \sum_{t'=t}^{T} r_{t'}`.
+ This function will print a warning when some of the episodes
+ in ``buffer`` have not yet terminated.
+
+ :param buffer: A :class:`tianshou.data.data_buffer`.
+ :param indexes: Optional. Indexes of data points on which the full return should be computed.
+ If not set, it defaults to all the data points in ``buffer``.
+ Note that if it's the index of a sampled minibatch, it doesn't have to be in order within
+ each episode.
+
+ :return: A dict with key 'return' and value the computed returns corresponding to ``indexes``.
"""
indexes = indexes or buffer.index
raw_data = buffer.data
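As a concrete illustration of the quantity being computed (independent of the buffer layout), the per-episode undiscounted returns-to-go can be obtained with a reversed cumulative sum; this is a hedged sketch, not the function's actual implementation:

import numpy as np

# Rewards of one terminated episode.
rewards = np.array([1.0, 0.0, 2.0, 1.0])

# G_t = sum_{t'=t}^{T} r_{t'}: reversed cumulative sum of the rewards.
returns = np.cumsum(rewards[::-1])[::-1]   # -> [4., 3., 3., 1.]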
@@ -46,27 +52,20 @@ def full_return(buffer, indexes=None):
return {'return': returns}


- class gae_lambda:
- """
- Generalized Advantage Estimation (Schulman, 15) to compute advantage
- """
- def __init__(self, T, value_function):
- self.T = T
- self.value_function = value_function
-
- def __call__(self, buffer, index=None):
- """
- :param buffer: buffer with property index and data. index determines the current content in `buffer`.
- :param index: (sampled) index to be computed. Defaults to all the data in `buffer`. Not necessarily in order within
- each episode.
- :return: dict with key 'advantage' and value the computed advantages corresponding to `index`.
- """
- pass
-
-
class nstep_return:
"""
- compute the n-step return from n-step rewards and bootstrapped value function
+ Compute the n-step return from n-step rewards and a bootstrapped state value function V(s),
+ :math:`G_t = r_t + \gamma r_{t+1} + ... + \gamma^{n-1} r_{t+n-1} + \gamma^n V(s_{t+n})`.
+
+ :param n: An int. The number of steps to look ahead, where :math:`n=1` will directly apply V(s) to
+ the next observation, as in the above equation.
+ :param value_function: A :class:`tianshou.core.value_function.StateValue`. The V(s) as in the
+ above equation.
+ :param return_advantage: Optional. A bool defaulting to ``False``. If ``True`` then this callable
+ also returns the advantage function
+ :math:`A(s_t) = r_t + \gamma r_{t+1} + ... + \gamma^{n-1} r_{t+n-1} + \gamma^n V(s_{t+n}) - V(s_t)` when called.
+ :param discount_factor: Optional. A float in range :math:`[0, 1]` defaulting to 0.99. The discount
+ factor :math:`\gamma` as in the above equation.
"""
def __init__(self, n, value_function, return_advantage=False, discount_factor=0.99):
self.n = n
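To make the n-step return concrete, a small NumPy sketch of the computation for a single terminated episode, assuming the bootstrap values V(s) have already been evaluated for each observation:

import numpy as np

def nstep_return_sketch(rewards, state_values, n, gamma=0.99):
    # rewards[t] is r_t; state_values[t] is V(s_t); both have length T (episode length).
    # Returns G_t = r_t + gamma*r_{t+1} + ... + gamma^(n-1)*r_{t+n-1} + gamma^n * V(s_{t+n}),
    # bootstrapping with 0 past the end of the terminated episode.
    T = len(rewards)
    returns = np.zeros(T)
    for t in range(T):
        g, discount = 0.0, 1.0
        for k in range(t, min(t + n, T)):
            g += discount * rewards[k]
            discount *= gamma
        if t + n < T:
            g += discount * state_values[t + n]
        returns[t] = g
    return returns

# Example: 3-step returns for a 5-step terminated episode.
print(nstep_return_sketch(np.ones(5), np.full(5, 0.5), n=3))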
@@ -76,10 +75,15 @@ class nstep_return:

def __call__(self, buffer, indexes=None):
"""
- :param buffer: buffer with property index and data. index determines the current content in `buffer`.
- :param indexes: (sampled) index to be computed. Defaults to all the data in `buffer`. Not necessarily in order within
- each episode.
- :return: dict with key 'return' and value the computed returns corresponding to `index`.
+ :param buffer: A :class:`tianshou.data.data_buffer`.
+ :param indexes: Optional. Indexes of data points on which the specified return should be computed.
+ If not set, it defaults to all the data points in ``buffer``.
+ Note that if it's the index of a sampled minibatch, it doesn't have to be in order within
+ each episode.
+
+ :return: A dict with key 'return' and value the computed returns corresponding to ``indexes``.
+ If ``return_advantage`` is set to ``True``, there is also a key 'advantage' with the corresponding
+ advantages.
"""
indexes = indexes or buffer.index
episodes = buffer.data
@@ -125,7 +129,16 @@ class nstep_return:

class ddpg_return:
"""
- compute the return as in DDPG. this seems to have to be special
+ Compute the return as in `DDPG <https://arxiv.org/pdf/1509.02971.pdf>`_,
+ :math:`G_t = r_t + \gamma Q'(s_{t+1}, \mu'(s_{t+1}))`, where :math:`Q'` and :math:`\mu'` are the
+ target networks.
+
+ :param actor: A :class:`tianshou.core.policy.Deterministic`. A deterministic policy.
+ :param critic: A :class:`tianshou.core.value_function.ActionValue`. An action value function Q(s, a).
+ :param use_target_network: Optional. A bool defaulting to ``True``. Whether to use the target networks
+ in the above equation.
+ :param discount_factor: Optional. A float in range :math:`[0, 1]` defaulting to 0.99. The discount
+ factor :math:`\gamma` as in the above equation.
"""
def __init__(self, actor, critic, use_target_network=True, discount_factor=0.99):
self.actor = actor
@@ -135,10 +148,13 @@ class ddpg_return:

def __call__(self, buffer, indexes=None):
"""
- :param buffer: buffer with property index and data. index determines the current content in `buffer`.
- :param index: (sampled) index to be computed. Defaults to all the data in `buffer`. Not necessarily in order within
- each episode.
- :return: dict with key 'return' and value the computed returns corresponding to `index`.
+ :param buffer: A :class:`tianshou.data.data_buffer`.
+ :param indexes: Optional. Indexes of data points on which the specified return should be computed.
+ If not set, it defaults to all the data points in ``buffer``.
+ Note that if it's the index of a sampled minibatch, it doesn't have to be in order within
+ each episode.
+
+ :return: A dict with key 'return' and value the computed returns corresponding to ``indexes``.
"""
indexes = indexes or buffer.index
episodes = buffer.data
@@ -175,7 +191,17 @@ class ddpg_return:

class nstep_q_return:
"""
- compute the n-step return for Q-learning targets
+ Compute the n-step return for Q-learning targets,
+ :math:`G_t = r_t + \gamma \max_a Q'(s_{t+1}, a)`.
+
+ :param n: An int. The number of steps to look ahead, where :math:`n=1` will directly apply :math:`Q'(s, \*)` to
+ the next observation, as in the above equation.
+ :param action_value: A :class:`tianshou.core.value_function.DQN`. The :math:`Q'(s, \*)` as in the
+ above equation.
+ :param use_target_network: Optional. A bool defaulting to ``True``. Whether to use the target networks
+ in the above equation.
+ :param discount_factor: Optional. A float in range :math:`[0, 1]` defaulting to 0.99. The discount
+ factor :math:`\gamma` as in the above equation.
"""
def __init__(self, n, action_value, use_target_network=True, discount_factor=0.99):
self.n = n
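The one-step (n=1) form of this target is easy to write out with plain NumPy; a hedged sketch, where the Q'(s', \*) minibatch is assumed to come from something like ``eval_value_all_actions_old``:

import numpy as np

def one_step_q_target_sketch(rewards, dones, q_next_all, gamma=0.99):
    # rewards, dones: arrays of shape (batch_size,); q_next_all: shape (batch_size, num_actions),
    # the Q'(s_{t+1}, *) values from the target network.
    # Target: r_t + gamma * max_a Q'(s_{t+1}, a), with no bootstrap on terminal transitions.
    return rewards + gamma * (1.0 - dones) * q_next_all.max(axis=1)

# Example minibatch of two transitions, the second one terminal.
targets = one_step_q_target_sketch(np.array([1.0, 0.5]),
                                   np.array([0.0, 1.0]),
                                   np.array([[0.2, 0.7], [0.1, 0.3]]))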
@@ -185,10 +211,13 @@ class nstep_q_return:

def __call__(self, buffer, indexes=None):
"""
- :param buffer: buffer with property index and data. index determines the current content in `buffer`.
- :param index: (sampled) index to be computed. Defaults to all the data in `buffer`. Not necessarily in order within
- each episode.
- :return: dict with key 'return' and value the computed returns corresponding to `index`.
+ :param buffer: A :class:`tianshou.data.data_buffer`.
+ :param indexes: Optional. Indexes of data points on which the specified return should be computed.
+ If not set, it defaults to all the data points in ``buffer``.
+ Note that if it's the index of a sampled minibatch, it doesn't have to be in order within
+ each episode.
+
+ :return: A dict with key 'return' and value the computed returns corresponding to ``indexes``.
"""
indexes = indexes or buffer.index
episodes = buffer.data
@@ -224,3 +253,28 @@ class nstep_q_return:
returns.append([])

return {'return': returns}
+
+
+ class gae_lambda:
+ """
+ Generalized Advantage Estimation (Schulman, 15) to compute advantage. To be implemented.
+ """
+ def __init__(self, T, value_function):
+ self.T = T
+ self.value_function = value_function
+
+ raise NotImplementedError()
+
+ def __call__(self, buffer, indexes=None):
+ """
+ To be implemented.
+
+ :param buffer: A :class:`tianshou.data.data_buffer`.
+ :param indexes: Optional. Indexes of data points on which the advantage should be computed.
+ If not set, it defaults to all the data points in ``buffer``.
+ Note that if it's the index of a sampled minibatch, it doesn't have to be in order within
+ each episode.
+
+ :return: A dict with key 'advantage' and value the computed advantages corresponding to ``indexes``.
+ """
+ raise NotImplementedError()

@@ -1,4 +1,4 @@
- from os.path import dirname, basename, isfile
- import glob
- modules = glob.glob(dirname(__file__)+"/*.py")
- __all__ = [ basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py')]
+ # from os.path import dirname, basename, isfile
+ # import glob
+ # modules = glob.glob(dirname(__file__)+"/*.py")
+ # __all__ = [ basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py')]
@@ -2,25 +2,31 @@

class DataBufferBase(object):
"""
- base class for data buffer, including replay buffer as in DQN and batched dataset as in on-policy algos
+ Base class for data buffer, including replay buffer as used by DQN
+ and batched dataset as used by on-policy algorithms.
+ Our data buffer adopts a memory-efficient implementation where raw data are always stored in a
+ sequential manner, and an additional set of indexes is used to indicate the valid data points
+ in the data buffer.
+
+ The raw data and index are both organized in a two-level architecture as lists of lists, where
+ the high-level lists correspond to episodes and the low-level lists correspond to the data within
+ each episode.
+
+ Mandatory methods for a data buffer class are:
+
+ - :func:`add`. It adds one timestep of data to the data buffer.
+
+ - :func:`clear`. It empties the data buffer.
+
+ - :func:`sample`. It samples one minibatch of data and returns the index of the sampled data\
+ points, not the raw data.
"""
def add(self, frame):
raise NotImplementedError()

def clear(self):
+ """Empties the data buffer, usually used in batch set but not in replay buffer."""
raise NotImplementedError()

def sample(self, batch_size):
- prob_episode = np.array(self.index_lengths) * 1. / self.size
- num_episodes = len(self.index)
- sampled_index = [[] for _ in range(num_episodes)]
-
- for _ in range(batch_size):
- # sample which episode
- sampled_episode_i = int(np.random.choice(num_episodes, p=prob_episode))
-
- # sample which data point within the sampled episode
- sampled_frame_i = int(np.random.randint(self.index_lengths[sampled_episode_i]))
- sampled_index[sampled_episode_i].append(sampled_frame_i)
-
- return sampled_index
+ raise NotImplementedError()
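To illustrate the two-level layout described above, here is a toy example of what the raw data, the index, and a sampled minibatch index might look like after two episodes; the field order (observation, action, reward, done_flag) matches the constants used later in this diff, and the concrete values are made up for illustration.

# Two episodes of raw data; each frame is (observation, action, reward, done_flag).
data = [
    [(0.0, 1, 1.0, False), (0.1, 0, 0.0, True)],   # episode 0, terminated
    [(0.2, 1, 2.0, False)],                         # episode 1, still running
]

# Valid data points, one index list per episode; a minibatch sampled from this
# buffer is likewise returned as a list of per-episode index lists.
index = [[0, 1], [0]]
sampled_index = [[1], []]   # e.g. one sampled frame from episode 0, none from episode 1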
@@ -9,12 +9,19 @@ ACTION = 1
REWARD = 2
DONE = 3


class BatchSet(DataBufferBase):
"""
- class for batched dataset as used in on-policy algos
+ Class for batched dataset as used in on-policy algorithms, where a batch of data is first collected
+ with the current policy, several optimization steps are then conducted on this batch of data,
+ and the data are then discarded and collected again.
+
+ :param nstep: An int defaulting to 1. The number of timesteps to look ahead for temporal difference computation.
+ Only continuous data pieces longer than this number or already terminated ones are
+ considered valid data points.
"""
- def __init__(self, nstep=None):
- self.nstep = nstep or 1 # RL has to look ahead at least one timestep
+ def __init__(self, nstep=1):
+ self.nstep = nstep # RL has to look ahead at least one timestep
+
self.data = [[]]
self.index = [[]]
@@ -25,6 +32,11 @@ class BatchSet(DataBufferBase):
self.index_lengths = [0] # for sampling

def add(self, frame):
+ """
+ Adds one frame of data to the buffer.
+
+ :param frame: A tuple of (observation, action, reward, done_flag).
+ """
self.data[-1].append(frame)

has_enough_frames = len(self.data[-1]) > self.nstep
@@ -48,6 +60,9 @@ class BatchSet(DataBufferBase):
self.index_lengths[-1] += 1

def clear(self):
+ """
+ Empties the data buffer and prepares to collect a new batch of data.
+ """
del self.data
del self.index
del self.index_lengths
@@ -61,7 +76,17 @@ class BatchSet(DataBufferBase):
self.index_lengths = [0]

def sample(self, batch_size):
- # TODO: move unified properties and methods to base. but this depends on how to deal with nstep
+ """
+ Performs uniform random sampling on ``self.index``. For simplicity, we do random sampling with replacement
+ for now with time O(``batch_size``). Fastest sampling without replacement seems to have to be of time
+ O(``batch_size`` * log(num_episodes)).
+
+ :param batch_size: An int. The size of the minibatch.
+
+ :return: A list of lists of the sampled indexes. Episodes without sampled data points
+ correspond to empty sub-lists.
+ """
+ # TODO: move unified properties and methods to base. but this may depend on how to deal with nstep

prob_episode = np.array(self.index_lengths) * 1. / self.size
num_episodes = len(self.index)
@@ -78,6 +103,14 @@ class BatchSet(DataBufferBase):
return sampled_index

def statistics(self, discount_factor=0.99):
+ """
+ Computes and prints out the statistics (e.g., discounted returns, undiscounted returns) in the batch set.
+ This is useful when policies are optimized by on-policy algorithms, so the current data in
+ the batch set directly reflect the performance of the current policy.
+
+ :param discount_factor: Optional. A float in range :math:`[0, 1]` defaulting to 0.99. The discount
+ factor to compute discounted returns.
+ """
returns = []
undiscounted_returns = []


@@ -2,11 +2,10 @@ from .base import DataBufferBase

class ReplayBufferBase(DataBufferBase):
"""
- base class for replay buffer.
+ Base class for replay buffer.
+ Compared to :class:`DataBufferBase`, it has an additional method :func:`remove`,
+ which removes extra data points when the size of the data buffer exceeds capacity.
+ Besides, following the usual practice for replay buffers, it is never :func:`clear` ed.
"""
def remove(self):
- """
- when size exceeds capacity, removes extra data points
- :return:
- """
raise NotImplementedError()
@@ -8,18 +8,21 @@ ACTION = 1
REWARD = 2
DONE = 3


+ # TODO: valid data points could be less than `nstep` timesteps. Check priority replay paper!
class VanillaReplayBuffer(ReplayBufferBase):
"""
- vanilla replay buffer as used in (Mnih, et al., 2015).
- Frames are always continuous in temporal order. They are only removed from the beginning. This continuity
- in `self.data` could be exploited, but only in vanilla replay buffer.
+ Class for vanilla replay buffer as used in (Mnih, et al., 2015).
+ Frames are always continuous in temporal order. They are only removed from the beginning
+ and added at the tail.
+ This continuity in `self.data` could be exploited only in vanilla replay buffer.
+
+ :param capacity: An int. The capacity of the buffer.
+ :param nstep: An int defaulting to 1. The number of timesteps to look ahead for temporal difference computation.
+ Only continuous data pieces longer than this number or already terminated ones are
+ considered valid data points.
"""
def __init__(self, capacity, nstep=1):
- """
- :param capacity: int. capacity of the buffer.
- :param nstep: int. number of timesteps to lookahead for temporal difference
- """
assert capacity > 0
self.capacity = int(capacity)
self.nstep = nstep
@@ -34,8 +37,9 @@ class VanillaReplayBuffer(ReplayBufferBase):

def add(self, frame):
"""
- add one frame to the buffer.
- :param frame: tuple, (observation, action, reward, done_flag).
+ Adds one frame of data to the buffer.
+
+ :param frame: A tuple of (observation, action, reward, done_flag).
"""
self.data[-1].append(frame)

@@ -65,17 +69,17 @@ class VanillaReplayBuffer(ReplayBufferBase):

def remove(self):
"""
- remove data until `self.size` <= `self.capacity`
+ Removes data from the buffer until ``self.size <= self.capacity``.
"""
if self.size:
while self.size > self.capacity:
- self.remove_oldest()
+ self._remove_oldest()
else:
logging.warning('Attempting to remove from empty buffer!')

- def remove_oldest(self):
+ def _remove_oldest(self):
"""
- remove the oldest data point, in this case, just the oldest frame. Empty episodes are also removed
+ Removes the oldest data point, in this case, just the oldest frame. Empty episodes are also removed
if they result from the removal.
"""
self.index[0].pop() # note that all index of frames in the first episode are shifted forward by 1
@@ -98,12 +102,14 @@ class VanillaReplayBuffer(ReplayBufferBase):

def sample(self, batch_size):
"""
- uniform random sampling on `self.index`. For simplicity, we do random sampling with replacement
- for now with time O(`batch_size`). Fastest sampling without replacement seems to have to be of time
- O(`batch_size` * log(num_episodes)).
- :param batch_size: int.
- :return: sampled index, same structure as `self.index`. Episodes without sampled data points
- correspond to empty sub-lists.
+ Performs uniform random sampling on ``self.index``. For simplicity, we do random sampling with replacement
+ for now with time O(``batch_size``). Fastest sampling without replacement seems to have to be of time
+ O(``batch_size`` * log(num_episodes)).
+
+ :param batch_size: An int. The size of the minibatch.
+
+ :return: A list of lists of the sampled indexes. Episodes without sampled data points
+ correspond to empty sub-lists.
"""
prob_episode = np.array(self.index_lengths) * 1. / self.size
num_episodes = len(self.index)
@@ -5,18 +5,24 @@ import itertools
from .data_buffer.replay_buffer_base import ReplayBufferBase
from .data_buffer.batch_set import BatchSet
from .utils import internal_key_match
+ from ..core.policy.deterministic import Deterministic


class DataCollector(object):
"""
A utility class to manage the data flow during the interaction between the policy and the environment.
- It stores data into ``data_buffer``, processes the reward signals and returns the feed_dict for tf graph running.
+ It stores data into ``data_buffer``, processes the reward signals and returns the feed_dict for
+ tf graph running.

- :param env:
- :param policy:
- :param data_buffer:
- :param process_functions:
- :param managed_networks:
+ :param env: An environment.
+ :param policy: A :class:`tianshou.core.policy`.
+ :param data_buffer: A :class:`tianshou.data.data_buffer`.
+ :param process_functions: A list of callables in :mod:`tianshou.data.advantage_estimation`
+ to process rewards.
+ :param managed_networks: A list of networks of :class:`tianshou.core.policy` and/or
+ :class:`tianshou.core.value_function`. The networks you want this class to manage. This class
+ will automatically generate the feed_dict for all the placeholders in the ``managed_placeholders``
+ of all networks in this list.
"""
def __init__(self, env, policy, data_buffer, process_functions, managed_networks):
self.env = env
@@ -42,6 +48,23 @@ class DataCollector(object):
self.step_count_this_episode = 0

def collect(self, num_timesteps=0, num_episodes=0, my_feed_dict={}, auto_clear=True, episode_cutoff=None):
+ """
+ Collect data in the environment using ``self.policy``.
+
+ :param num_timesteps: An int specifying the number of timesteps to act. It defaults to 0 and either
+ ``num_timesteps`` or ``num_episodes`` could be set but not both.
+ :param num_episodes: An int specifying the number of episodes to act. It defaults to 0 and either
+ ``num_timesteps`` or ``num_episodes`` could be set but not both.
+ :param my_feed_dict: Optional. A dict defaulting to empty.
+ Specifies placeholders such as dropout and batch_norm except observation and action.
+ :param auto_clear: Optional. A bool defaulting to ``True``. If ``True`` then this method clears
+ ``self.data_buffer`` if ``self.data_buffer`` is an instance of
+ :class:`tianshou.data.data_buffer.BatchSet` and does nothing if it's not such an instance.
+ If set to ``False`` then the aforementioned auto-clearing behavior is disabled.
+ :param episode_cutoff: Optional. An int. The maximum number of timesteps in one episode. This is
+ useful when the environment has no terminal states or a single episode could be prohibitively long.
+ If set, all episodes are forced to stop beyond this number of timesteps.
+ """
assert sum([num_timesteps > 0, num_episodes > 0]) == 1,\
"One and only one collection number specification permitted!"
@@ -89,7 +112,20 @@ class DataCollector(object):
for processor in self.process_functions:
self.data.update(processor(self.data_buffer))

- def next_batch(self, batch_size, standardize_advantage=None):
+ return
+
+ def next_batch(self, batch_size, standardize_advantage=True):
+ """
+ Constructs and returns the feed_dict of data to be used with ``sess.run``.
+
+ :param batch_size: An int. The size of one minibatch.
+ :param standardize_advantage: Optional. A bool defaulting to ``True``.
+ If ``True``, this method standardizes advantages if advantage is required by the networks.
+ If ``False``, this method will never standardize advantages.
+
+ :return: A dict in the format of a conventional feed_dict in tf, with keys the placeholders and
+ values the numpy arrays.
+ """
sampled_index = self.data_buffer.sample(batch_size)
if self.process_mode == 'sample':
for processor in self.process_functions:
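The advantage standardization mentioned in the docstring is a simple normalization of the sampled advantages before they are fed to the graph; a NumPy sketch of that step, where the array stands in for the advantage values collected in the feed_dict:

import numpy as np

advantage_value = np.array([2.0, -1.0, 0.5, 3.5])

# Standardize to zero mean and unit standard deviation; the small epsilon
# guards against division by zero when all advantages are equal.
standardized = (advantage_value - advantage_value.mean()) / (advantage_value.std() + 1e-8)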
@@ -128,8 +164,7 @@ class DataCollector(object):
else:
raise TypeError('Placeholder {} has no value to feed!'.format(str(placeholder.name)))

- auto_standardize = (standardize_advantage is None) and self.require_advantage
- if standardize_advantage or auto_standardize:
+ if standardize_advantage:
if self.require_advantage:
advantage_value = feed_dict[self.required_placeholders['advantage']]
advantage_mean = np.mean(advantage_value)
@@ -140,10 +175,21 @@ class DataCollector(object):

return feed_dict

- def denoise_action(self, feed_dict):
+ def denoise_action(self, feed_dict, my_feed_dict={}):
+ """
+ Recompute the actions of deterministic policies without exploration noise, hence denoising.
+ It modifies ``feed_dict`` **in place** and has no return value.
+ This is useful in, e.g., DDPG since the stored action in ``self.data_buffer`` is the sampled
+ action with additional exploration noise.
+
+ :param feed_dict: A dict. It has to be the dict returned by :func:`next_batch` of this class.
+ :param my_feed_dict: Optional. A dict defaulting to empty.
+ Specifies placeholders such as dropout and batch_norm except observation and action.
+ """
+ assert isinstance(self.policy, Deterministic), 'denoise_action() could only be called ' \
+ 'with deterministic policies'
observation = feed_dict[self.required_placeholders['observation']]
- action_mean = self.policy.eval_action(observation)
+ action_mean = self.policy.eval_action(observation, my_feed_dict)
feed_dict[self.required_placeholders['action']] = action_mean

return
@@ -5,10 +5,31 @@ import logging
import numpy as np


- def test_policy_in_env(policy, env, num_timesteps=0, num_episodes=0, discount_factor=0.99, seed=0, episode_cutoff=None):
+ def test_policy_in_env(policy, env, num_timesteps=0, num_episodes=0,
+ discount_factor=0.99, seed=0, episode_cutoff=None):
+ """
+ Tests the policy in the environment and records and prints out the performance. This is useful when the policy
+ is trained with off-policy algorithms and thus the rewards in the data buffer do not reflect the
+ performance of the current policy.
+
+ :param policy: A :class:`tianshou.core.policy`. The current policy being optimized.
+ :param env: An environment.
+ :param num_timesteps: An int specifying the number of timesteps to test the policy.
+ It defaults to 0 and either
+ ``num_timesteps`` or ``num_episodes`` could be set but not both.
+ :param num_episodes: An int specifying the number of episodes to test the policy.
+ It defaults to 0 and either
+ ``num_timesteps`` or ``num_episodes`` could be set but not both.
+ :param discount_factor: Optional. A float in range :math:`[0, 1]` defaulting to 0.99. The discount
+ factor to compute discounted returns.
+ :param seed: A non-negative int. The seed to seed the environment as ``env.seed(seed)``.
+ :param episode_cutoff: Optional. An int. The maximum number of timesteps in one episode. This is
+ useful when the environment has no terminal states or a single episode could be prohibitively long.
+ If set, all episodes are forced to stop beyond this number of timesteps.
+ """
assert sum([num_episodes > 0, num_timesteps > 0]) == 1, \
'One and only one collection number specification permitted!'
assert seed >= 0

# make another env as the original is for training data collection
env_id = env.spec.id
@@ -82,4 +103,4 @@ def test_policy_in_env(policy, env, num_timesteps=0, num_episodes=0, discount_fa
print('Mean undiscounted return: {}'.format(mean_undiscounted_return))

# clear scene
- env_.close()
+ env_.close()