finish all API docs, first version.

This commit is contained in:
haoshengzou 2018-04-15 17:41:43 +08:00
parent 8c108174b6
commit 5f979caf58
15 changed files with 296 additions and 138 deletions

View File

@ -4,9 +4,11 @@ import tensorflow as tf
def ppo_clip(policy, clip_param):
"""
Builds the graph of clipped loss :math:`L^{CLIP}` as in the
`Link PPO paper <https://arxiv.org/pdf/1707.06347.pdf>`_, which is basically
:math:`-\min(r_t(\\theta)Aˆt, clip(r_t(\\theta), 1 - \epsilon, 1 + \epsilon)Aˆt)`.
`PPO paper <https://arxiv.org/pdf/1707.06347.pdf>`_, which is basically
:math:`-\min(r_t(\\theta)A_t, \mathrm{clip}(r_t(\\theta), 1 - \epsilon, 1 + \epsilon)A_t)`.
We minimize the objective instead of maximizing, hence the leading negative sign.
It creates an action placeholder and an advantage placeholder and adds them to the ``managed_placeholders``
of the ``policy``.
:param policy: A :class:`tianshou.core.policy` to be optimized.
:param clip_param: A float or Tensor of type float. The :math:`\epsilon` in the loss equation.
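As a sanity check of the clipped objective above, here is a minimal numpy sketch (not tianshou code; the ratios, advantages and epsilon are made-up numbers):

    import numpy as np

    # hypothetical probability ratios r_t(theta) and advantage estimates A_t for a minibatch
    ratio = np.array([0.8, 1.0, 1.3])
    advantage = np.array([1.0, -0.5, 2.0])
    epsilon = 0.2

    clipped_ratio = np.clip(ratio, 1 - epsilon, 1 + epsilon)
    # leading negative sign: we minimize the loss instead of maximizing the surrogate objective
    loss = -np.mean(np.minimum(ratio * advantage, clipped_ratio * advantage))
    print(loss)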
@ -29,8 +31,10 @@ def ppo_clip(policy, clip_param):
def REINFORCE(policy):
"""
Builds the graph of the loss function as used in vanilla policy gradient algorithms, i.e., REINFORCE.
The loss is basically :math:`\log \pi(a|s) A^t`.
The loss is basically :math:`\log \pi(a|s) A_t`.
We minimize the objective instead of maximizing, hence the leading negative sign.
It creates an action placeholder and an advantage placeholder and adds them to the ``managed_placeholders``
of the ``policy``.
:param policy: A :class:`tianshou.core.policy` to be optimized.
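For intuition only, the same loss written out in numpy for a made-up categorical policy (all numbers hypothetical):

    import numpy as np

    logits = np.array([[1.0, 2.0, 0.5]])                # one state, three discrete actions
    probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
    action, advantage = 1, 1.7                          # sampled action and its advantage estimate

    # REINFORCE loss: negative log-likelihood of the taken action, weighted by the advantage
    loss = -np.log(probs[0, action]) * advantage
    print(loss)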
@ -50,6 +54,8 @@ def REINFORCE(policy):
def value_mse(value_function):
"""
Builds the graph of L2 loss on value functions for, e.g., training critics or DQN.
It creates a placeholder for the target value and adds it to the ``managed_placeholders``
of the ``value_function``.
:param value_function: A :class:`tianshou.core.value_function` to be optimized.
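Conceptually this loss is just the mean squared error between the predicted values and the targets fed through the new placeholder, e.g. (numpy sketch with made-up numbers):

    import numpy as np

    predicted_value = np.array([0.5, 1.2, -0.3])        # V(s) or Q(s, a) from the network
    target_value = np.array([1.0, 1.0, 0.0])            # e.g., bootstrapped returns fed via the placeholder
    loss = np.mean((predicted_value - target_value) ** 2)
    print(loss)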

View File

@ -3,7 +3,7 @@ import tensorflow as tf
def DPG(policy, action_value):
"""
Constructs the gradient Tensor of `Link deterministic policy gradient <https://arxiv.org/pdf/1509.02971.pdf>`_.
Constructs the gradient Tensor of `deterministic policy gradient <https://arxiv.org/pdf/1509.02971.pdf>`_.
:param policy: A :class:`tianshou.core.policy.Deterministic` to be optimized.
:param action_value: A :class:`tianshou.core.value_function.ActionValue` to guide the optimization of `policy`.
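The gradient follows the chain rule :math:`\nabla_\theta J = \mathbb{E}[\nabla_\theta \mu_\theta(s) \, \nabla_a Q(s, a)|_{a=\mu_\theta(s)}]`. A toy numpy sketch (not the tf implementation) with a linear policy and a hand-written critic gradient:

    import numpy as np

    theta = np.zeros(2)                       # parameters of a linear policy mu(s) = theta . s
    states = np.random.randn(64, 2)

    def grad_a_q(s, a):
        # gradient of a toy critic Q(s, a) = -(a - s.sum())**2 with respect to the action a
        return -2.0 * (a - s.sum())

    for _ in range(200):
        # per-sample deterministic policy gradient: dmu/dtheta * dQ/da evaluated at a = mu(s)
        grads = np.array([s * grad_a_q(s, s @ theta) for s in states])
        theta = theta + 0.01 * grads.mean(axis=0)       # gradient ascent on the policy objective

    print(theta)                              # approaches [1., 1.], the maximizer of the toy Q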

View File

@ -146,7 +146,7 @@ class Deterministic(PolicyBase):
"""
sess = tf.get_default_session()
feed_dict = {self.observation_placeholder: observation}
feed_dict = {self.observation_placeholder: observation}
feed_dict.update(my_feed_dict)
action = sess.run(self.action, feed_dict=feed_dict)
return action
@ -164,7 +164,7 @@ class Deterministic(PolicyBase):
"""
sess = tf.get_default_session()
feed_dict = {self.observation_placeholder: observation}
feed_dict = {self.observation_placeholder: observation}
feed_dict.update(my_feed_dict)
action = sess.run(self.action_old, feed_dict=feed_dict)
return action

View File

@ -66,6 +66,11 @@ class GaussianWhiteNoiseProcess(AnnealedGaussianProcess):
self.size = size
def sample(self):
"""
Draws one sample from the random process.
:return: A numpy array. The drawn sample.
"""
sample = np.random.normal(self.mu, self.current_sigma, self.size)
self.n_steps += 1
return sample
@ -102,6 +107,11 @@ class OrnsteinUhlenbeckProcess(AnnealedGaussianProcess):
self.reset_states()
def sample(self):
"""
Draws one sample from the random process.
:return: A numpy array. The drawn sample.
"""
x = self.x_prev + self.theta * (self.mu - self.x_prev) * self.dt + \
self.current_sigma * np.sqrt(self.dt) * np.random.normal(size=self.size)
self.x_prev = x
@ -109,4 +119,7 @@ class OrnsteinUhlenbeckProcess(AnnealedGaussianProcess):
return x
def reset_states(self):
"""
Reset ``self.x_prev`` to be ``self.x0``.
"""
self.x_prev = self.x0 if self.x0 is not None else np.zeros(self.size)
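For reference, the same Ornstein-Uhlenbeck update can be reproduced standalone to inspect the scale of the exploration noise (parameter values below are arbitrary):

    import numpy as np

    mu, theta, sigma, dt, size = 0.0, 0.15, 0.3, 1.0, (2,)
    x = np.zeros(size)                        # x0 defaults to zeros, as in reset_states()
    for _ in range(5):
        x = x + theta * (mu - x) * dt + sigma * np.sqrt(dt) * np.random.normal(size=size)
        print(x)                              # temporally correlated noise, one sample per step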

View File

@ -17,11 +17,11 @@ def identify_dependent_variables(tensor, candidate_variables):
def get_soft_update_op(update_fraction, including_nets, excluding_nets=None):
"""
Builds the graph op to softly update the "old net" of policies and value_functions, as suggested in
`Link DDPG <https://arxiv.org/pdf/1509.02971.pdf>`_. It updates the :class:`tf.Variable` s in the old net,
`DDPG <https://arxiv.org/pdf/1509.02971.pdf>`_. It updates the :class:`tf.Variable` s in the old net,
:math:`\\theta'` with the :class:`tf.Variable` s in the current network, :math:`\\theta` as
:math:`\\theta' = \tau \\theta + (1 - \tau) \\theta'`.
:math:`\\theta' = \\tau \\theta + (1 - \\tau) \\theta'`.
:param update_fraction: A float in range :math:`[0, 1]`. Corresponding to the :math:`\tau` in the update equation.
:param update_fraction: A float in range :math:`[0, 1]`. Corresponding to the :math:`\\tau` in the update equation.
:param including_nets: A list of policies and/or value_functions. All :class:`tf.Variable` s in these networks
are included for update. Shared Variables are updated only once in case of layer sharing among the networks.
:param excluding_nets: Optional. A list of policies and/or value_functions defaulting to ``None``.
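The update itself is an exponential moving average of the parameters; a numpy illustration with made-up parameter vectors and :math:`\tau = 0.01`:

    import numpy as np

    tau = 0.01                                # update_fraction
    theta = np.array([1.0, 2.0, 3.0])         # current network parameters
    theta_old = np.zeros(3)                   # "old net" / target network parameters

    for _ in range(3):
        theta_old = tau * theta + (1 - tau) * theta_old   # applied once after each training step
    print(theta_old)                          # slowly tracks theta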

View File

@ -110,13 +110,13 @@ class ActionValue(ValueFunctionBase):
class DQN(ValueFunctionBase):
"""
Class for the special action value function DQN. Instead of feeding s and a to the network to get a value,
DQN feeds s to the network and gets at the last layer Q(s, *) for all actions under this state. Still, as
DQN feeds s to the network and gets at the last layer Q(s, \*) for all actions under this state. Still, as
:class:`ActionValue`, this class builds the Q(s, a) value Tensor. It can only be used with discrete
(and finite) action spaces.
:param network_callable: A Python callable returning (action head, value head). When called it builds
the tf graph and returns a Tensor of Q(s, *) on the value head.
:param observation_placeholder: A :class:`tf.placeholder`. The observation placeholder for s in Q(s, *)
the tf graph and returns a Tensor of Q(s, \*) on the value head.
:param observation_placeholder: A :class:`tf.placeholder`. The observation placeholder for s in Q(s, \*)
in the network graph.
:param has_old_net: A bool defaulting to ``False``. If true this class will create another graph with another
set of :class:`tf.Variable` s to be the "old net". The "old net" could be the target networks as in DQN
@ -218,12 +218,12 @@ class DQN(ValueFunctionBase):
@property
def value_tensor_all_actions(self):
"""The Tensor for Q(s, *)"""
"""The Tensor for Q(s, \*)"""
return self._value_tensor_all_actions
def eval_value_all_actions(self, observation, my_feed_dict={}):
"""
Evaluate values Q(s, *) in minibatch using the current network.
Evaluate values Q(s, \*) in minibatch using the current network.
:param observation: An array-like, of shape (batch_size,) + observation_shape.
:param my_feed_dict: Optional. A dict defaulting to empty.
@ -236,7 +236,7 @@ class DQN(ValueFunctionBase):
def eval_value_all_actions_old(self, observation, my_feed_dict={}):
"""
Evaluate values Q(s, *) in minibatch using the old net.
Evaluate values Q(s, \*) in minibatch using the old net.
:param observation: An array-like, of shape (batch_size,) + observation_shape.
:param my_feed_dict: Optional. A dict defaulting to empty.
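The relation between Q(s, *) and Q(s, a) is plain indexing of the output head by the action; a numpy sketch with made-up values:

    import numpy as np

    q_all = np.array([[0.1, 0.5, -0.2],       # Q(s, *) for a minibatch of 2 states, 3 actions
                      [1.0, 0.0,  0.3]])
    actions = np.array([1, 2])

    q_sa = q_all[np.arange(len(actions)), actions]   # Q(s, a), what value_tensor evaluates to
    greedy = q_all.argmax(axis=1)                    # greedy actions for a DQN-style policy
    print(q_sa, greedy)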

View File

@ -1,26 +0,0 @@
# TODO:
Notice that we will separate actor and critic, and batch will collect data for optimizing policy while replay will collect data for optimizing critic.
# Batch
YouQiaoben
fix as stated in ppo_example.py
# Replay
ShihongSong
a Replay.py file. must have collect() and next_batch() methods for training.
integrate previous ReplayBuffer codes.
# adv_estimate
YouQiaoben (gae_lambda), ShihongSong(dqn after policy.DQN)
seems to be direct python functions. also may write it in a functional form.

View File

@ -10,11 +10,17 @@ DONE = 3
# TODO: add discount_factor... maybe make it to be a global config?
def full_return(buffer, indexes=None):
"""
naively compute full return
:param buffer: buffer with property index and data. index determines the current content in `buffer`.
:param indexes: (sampled) index to be computed. Defaults to all the data in `buffer`. Not necessarily in order within
each episode.
:return: dict with key 'return' and value the computed returns corresponding to `index`.
Naively compute the full undiscounted return on episodic data, :math:`G = \sum_{t=0}^{T} r_t`.
This function will print a warning when some of the episodes
in ``buffer`` have not yet terminated.
:param buffer: A :class:`tianshou.data.data_buffer`.
:param indexes: Optional. Indexes of data points on which the full return should be computed.
If not set, it defaults to all the data points in ``buffer``.
Note that if it's the index of a sampled minibatch, it doesn't have to be in order within
each episode.
:return: A dict with key 'return' and value the computed returns corresponding to ``indexes``.
"""
indexes = indexes or buffer.index
raw_data = buffer.data
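For intuition, a tiny sketch on the two-level data layout, assuming (per the formula above) that the full return is the plain undiscounted sum of rewards in each terminated episode:

    rewards = [[1.0, 0.0, 2.0], [0.5, 0.5]]            # two terminated episodes, two-level layout
    episode_returns = [sum(episode) for episode in rewards]
    print(episode_returns)                              # [3.0, 1.0]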
@ -46,27 +52,20 @@ def full_return(buffer, indexes=None):
return {'return': returns}
class gae_lambda:
"""
Generalized Advantage Estimation (Schulman, 15) to compute advantage
"""
def __init__(self, T, value_function):
self.T = T
self.value_function = value_function
def __call__(self, buffer, index=None):
"""
:param buffer: buffer with property index and data. index determines the current content in `buffer`.
:param index: (sampled) index to be computed. Defaults to all the data in `buffer`. Not necessarily in order within
each episode.
:return: dict with key 'advantage' and value the computed advantages corresponding to `index`.
"""
pass
class nstep_return:
"""
compute the n-step return from n-step rewards and bootstrapped value function
Compute the n-step return from n-step rewards and a bootstrapped state value function V(s),
:math:`G_t = r_t + \gamma r_{t+1} + \dots + \gamma^{n-1} r_{t+n-1} + \gamma^n V(s_{t+n})`.
:param n: An int. The number of steps to look ahead, where :math:`n=1` will directly apply V(s) to
the next observation, as in the above equation.
:param value_function: A :class:`tianshou.core.value_function.StateValue`. The V(s) as in the
above equation.
:param return_advantage: Optional. A bool defaulting to ``False``. If ``True``, this callable
also returns the advantage function
:math:`A(s_t) = r_t + \gamma r_{t+1} + \dots + \gamma^{n-1} r_{t+n-1} + \gamma^n V(s_{t+n}) - V(s_t)` when called.
:param discount_factor: Optional. A float in range :math:`[0, 1]` defaulting to 0.99. The discount
factor :math:`\gamma` as in the above equation.
"""
def __init__(self, n, value_function, return_advantage=False, discount_factor=0.99):
self.n = n
@ -76,10 +75,15 @@ class nstep_return:
def __call__(self, buffer, indexes=None):
"""
:param buffer: buffer with property index and data. index determines the current content in `buffer`.
:param indexes: (sampled) index to be computed. Defaults to all the data in `buffer`. Not necessarily in order within
each episode.
:return: dict with key 'return' and value the computed returns corresponding to `index`.
:param buffer: A :class:`tianshou.data.data_buffer`.
:param indexes: Optional. Indexes of data points on which the specified return should be computed.
If not set, it defaults to all the data points in ``buffer``.
Note that if it's the index of a sampled minibatch, it doesn't have to be in order within
each episode.
:return: A dict with key 'return' and value the computed returns corresponding to ``indexes``.
If ``return_advantage`` is set to ``True``, the dict also has a key 'advantage' whose value is the
corresponding advantages.
"""
indexes = indexes or buffer.index
episodes = buffer.data
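A sketch of the n-step bootstrapped target for a single timestep, with made-up rewards and value estimates standing in for the value network:

    gamma, n = 0.99, 2
    rewards = [1.0, 0.5, 0.2, 0.0]                     # r_t, r_{t+1}, ...
    value_next = 3.0                                   # V(s_{t+n}) from the (target) value network
    value_now = 2.5                                    # V(s_t), only needed for the advantage

    target = sum(gamma ** i * rewards[i] for i in range(n)) + gamma ** n * value_next
    advantage = target - value_now                     # returned when return_advantage=True
    print(target, advantage)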
@ -125,7 +129,16 @@ class nstep_return:
class ddpg_return:
"""
compute the return as in DDPG. this seems to have to be special
Compute the return as in `DDPG <https://arxiv.org/pdf/1509.02971.pdf>`_,
:math:`G_t = r_t + \gamma Q'(s_{t+1}, \mu'(s_{t+1}))`, where :math:`Q'` and :math:`\mu'` are the
target networks.
:param actor: A :class:`tianshou.core.policy.Deterministic`. A deterministic policy.
:param critic: A :class:`tianshou.core.value_function.ActionValue`. An action value function Q(s, a).
:param use_target_network: Optional. A bool defaulting to ``True``. Whether to use the target networks
in the above equation.
:param discount_factor: Optional. A float in range :math:`[0, 1]` defaulting to 0.99. The discount
factor :math:`\gamma` as in the above equation.
"""
def __init__(self, actor, critic, use_target_network=True, discount_factor=0.99):
self.actor = actor
@ -135,10 +148,13 @@ class ddpg_return:
def __call__(self, buffer, indexes=None):
"""
:param buffer: buffer with property index and data. index determines the current content in `buffer`.
:param index: (sampled) index to be computed. Defaults to all the data in `buffer`. Not necessarily in order within
each episode.
:return: dict with key 'return' and value the computed returns corresponding to `index`.
:param buffer: A :class:`tianshou.data.data_buffer`.
:param indexes: Optional. Indexes of data points on which the specified return should be computed.
If not set, it defaults to all the data points in ``buffer``.
Note that if it's the index of a sampled minibatch, it doesn't have to be in order within
each episode.
:return: A dict with key 'return' and value the computed returns corresponding to ``indexes``.
"""
indexes = indexes or buffer.index
episodes = buffer.data
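A tiny sketch of the one-step DDPG target, with lambdas standing in for the target actor :math:`\mu'` and target critic :math:`Q'` (values are made up):

    gamma = 0.99
    reward, next_observation = 1.0, [0.2, -0.1]

    mu_target = lambda s: 0.5                          # stands in for the target actor mu'(s)
    q_target = lambda s, a: 2.0                        # stands in for the target critic Q'(s, a)

    target = reward + gamma * q_target(next_observation, mu_target(next_observation))
    print(target)                                      # 1.0 + 0.99 * 2.0 = 2.98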
@ -175,7 +191,17 @@ class ddpg_return:
class nstep_q_return:
"""
compute the n-step return for Q-learning targets
Compute the n-step return for Q-learning targets,
:math:`G_t = r_t + \gamma \max_a Q'(s_{t+1}, a)`.
:param n: An int. The number of steps to look ahead, where :math:`n=1` will directly apply :math:`Q'(s, \cdot)` to
the next observation, as in the above equation.
:param action_value: A :class:`tianshou.core.value_function.DQN`. The :math:`Q'(s, \cdot)` as in the
above equation.
:param use_target_network: Optional. A bool defaulting to ``True``. Whether to use the target networks
in the above equation.
:param discount_factor: Optional. A float in range :math:`[0, 1]` defaulting to 0.99. The discount
factor :math:`\gamma` as in the above equation.
"""
def __init__(self, n, action_value, use_target_network=True, discount_factor=0.99):
self.n = n
@ -185,10 +211,13 @@ class nstep_q_return:
def __call__(self, buffer, indexes=None):
"""
:param buffer: buffer with property index and data. index determines the current content in `buffer`.
:param index: (sampled) index to be computed. Defaults to all the data in `buffer`. Not necessarily in order within
each episode.
:return: dict with key 'return' and value the computed returns corresponding to `index`.
:param buffer: A :class:`tianshou.data.data_buffer`.
:param indexes: Optional. Indexes of data points on which the n-step Q-learning target should be computed.
If not set, it defaults to all the data points in ``buffer``.
Note that if it's the index of a sampled minibatch, it doesn't have to be in order within
each episode.
:return: A dict with key 'return' and value the computed returns corresponding to ``indexes``.
"""
indexes = indexes or buffer.index
episodes = buffer.data
@ -224,3 +253,28 @@ class nstep_q_return:
returns.append([])
return {'return': returns}
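The one-step case of this target, sketched with a made-up Q'(s, *) vector:

    import numpy as np

    gamma, reward = 0.99, 1.0
    q_next_all = np.array([0.3, 1.5, -0.2])            # Q'(s_{t+1}, *) over the discrete actions
    target = reward + gamma * q_next_all.max()         # r_t + gamma * max_a Q'(s_{t+1}, a)
    print(target)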
class gae_lambda:
"""
Generalized Advantage Estimation (Schulman et al., 2015) to compute advantage. To be implemented.
"""
def __init__(self, T, value_function):
self.T = T
self.value_function = value_function
raise NotImplementedError()
def __call__(self, buffer, indexes=None):
"""
To be implemented.
:param buffer: A :class:`tianshou.data.data_buffer`.
:param indexes: Optional. Indexes of data points on which the advantage should be computed.
If not set, it defaults to all the data points in ``buffer``.
Note that if it's the index of a sampled minibatch, it doesn't have to be in order within
each episode.
:return: A dict with key 'advantage' and value the computed advantages corresponding to ``indexes``.
"""
raise NotImplementedError()
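Since the class is still a stub, the following is only a reference numpy sketch of the usual GAE(lambda) recursion it is expected to implement (not tianshou code):

    import numpy as np

    gamma, lam = 0.99, 0.95
    rewards = np.array([1.0, 0.0, 0.5])
    values = np.array([0.8, 0.7, 0.6, 0.0])            # V(s_0..s_3); last entry bootstraps (0 if terminal)

    advantages = np.zeros_like(rewards)
    gae = 0.0
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * values[t + 1] - values[t]   # TD residual
        gae = delta + gamma * lam * gae
        advantages[t] = gae
    print(advantages)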

View File

@ -1,4 +1,4 @@
from os.path import dirname, basename, isfile
import glob
modules = glob.glob(dirname(__file__)+"/*.py")
__all__ = [ basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py')]
# from os.path import dirname, basename, isfile
# import glob
# modules = glob.glob(dirname(__file__)+"/*.py")
# __all__ = [ basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py')]

View File

@ -2,25 +2,31 @@
class DataBufferBase(object):
"""
base class for data buffer, including replay buffer as in DQN and batched dataset as in on-policy algos
Base class for data buffer, including replay buffer as used by DQN
and batched dataset as used by on-policy algorithms.
Our data buffer adopts a memory-efficient implementation where raw data are always stored in a
sequential manner, and an additional set of indexes is used to indicate the valid data points
in the data buffer.
The raw data and index are both organized in a two-level architecture as lists of lists, where
the high-level lists correspond to episodes and low-level lists correspond to the data within
each episode.
Mandatory methods for a data buffer class are:
- :func:`add`. It adds one timestep of data to the data buffer.
- :func:`clear`. It empties the data buffer.
- :func:`sample`. It samples one minibatch of data and returns the index of the sampled data\
points, not the raw data.
"""
def add(self, frame):
raise NotImplementedError()
def clear(self):
"""Empties the data buffer, usually used in batch set but not in replay buffer."""
raise NotImplementedError()
def sample(self, batch_size):
prob_episode = np.array(self.index_lengths) * 1. / self.size
num_episodes = len(self.index)
sampled_index = [[] for _ in range(num_episodes)]
for _ in range(batch_size):
# sample which episode
sampled_episode_i = int(np.random.choice(num_episodes, p=prob_episode))
# sample which data point within the sampled episode
sampled_frame_i = int(np.random.randint(self.index_lengths[sampled_episode_i]))
sampled_index[sampled_episode_i].append(sampled_frame_i)
return sampled_index
raise NotImplementedError()
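To make the two-level layout concrete, here is a small standalone sketch of the episode-weighted sampling with replacement performed by the implementation removed from this base class above, with made-up episode lengths:

    import numpy as np

    index = [[0, 1, 2], [0, 1], [0, 1, 2, 3]]          # valid data points, one sub-list per episode
    index_lengths = [len(episode) for episode in index]
    size = sum(index_lengths)

    prob_episode = np.array(index_lengths) / size      # longer episodes are sampled more often
    sampled_index = [[] for _ in index]
    for _ in range(4):                                 # batch_size = 4, sampling with replacement
        episode_i = int(np.random.choice(len(index), p=prob_episode))
        frame_i = int(np.random.randint(index_lengths[episode_i]))
        sampled_index[episode_i].append(frame_i)
    print(sampled_index)                               # e.g. [[2], [], [0, 3, 1]]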

View File

@ -9,12 +9,19 @@ ACTION = 1
REWARD = 2
DONE = 3
class BatchSet(DataBufferBase):
"""
class for batched dataset as used in on-policy algos
Class for batched dataset as used in on-policy algorithms, where a batch of data is first collected
with the current policy, several optimization steps are then conducted on this batch of data and the
data are then discarded and collected again.
:param nstep: An int defaulting to 1. The number of timesteps to look ahead for temporal difference computation.
Only continuous data pieces longer than this number or already terminated ones are
considered valid data points.
"""
def __init__(self, nstep=None):
self.nstep = nstep or 1 # RL has to look ahead at least one timestep
def __init__(self, nstep=1):
self.nstep = nstep # RL has to look ahead at least one timestep
self.data = [[]]
self.index = [[]]
@ -25,6 +32,11 @@ class BatchSet(DataBufferBase):
self.index_lengths = [0] # for sampling
def add(self, frame):
"""
Adds one frame of data to the buffer.
:param frame: A tuple of (observation, action, reward, done_flag).
"""
self.data[-1].append(frame)
has_enough_frames = len(self.data[-1]) > self.nstep
@ -48,6 +60,9 @@ class BatchSet(DataBufferBase):
self.index_lengths[-1] += 1
def clear(self):
"""
Empties the data buffer and prepares to collect a new batch of data.
"""
del self.data
del self.index
del self.index_lengths
@ -61,7 +76,17 @@ class BatchSet(DataBufferBase):
self.index_lengths = [0]
def sample(self, batch_size):
# TODO: move unified properties and methods to base. but this depends on how to deal with nstep
"""
Performs uniform random sampling on ``self.index``. For simplicity, we do random sampling with replacement
for now, which takes O(``batch_size``) time. The fastest sampling without replacement appears to require
O(``batch_size`` * log(num_episodes)) time.
:param batch_size: An int. The size of the minibatch.
:return: A list of lists of the sampled indexes. Episodes without sampled data points
correspond to empty sub-lists.
"""
# TODO: move unified properties and methods to base. but this may depend on how to deal with nstep
prob_episode = np.array(self.index_lengths) * 1. / self.size
num_episodes = len(self.index)
@ -78,6 +103,14 @@ class BatchSet(DataBufferBase):
return sampled_index
def statistics(self, discount_factor=0.99):
"""
Computes and prints out the statistics (e.g., discounted returns, undiscounted returns) in the batch set.
This is useful when policies are optimized by on-policy algorithms, since the current data in
the batch set directly reflect the performance of the current policy.
:param discount_factor: Optional. A float in range :math:`[0, 1]` defaulting to 0.99. The discount
factor to compute discounted returns.
"""
returns = []
undiscounted_returns = []
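The returns it reports are the standard discounted and undiscounted sums over each episode, e.g. (made-up rewards):

    discount_factor = 0.99
    episode_rewards = [1.0, 0.0, 2.0]

    undiscounted_return = sum(episode_rewards)
    discounted_return = sum(discount_factor ** t * r for t, r in enumerate(episode_rewards))
    print(undiscounted_return, discounted_return)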

View File

@ -2,11 +2,10 @@ from .base import DataBufferBase
class ReplayBufferBase(DataBufferBase):
"""
base class for replay buffer.
Base class for replay buffer.
Compared to :class:`DataBufferBase`, it has an additional method :func:`remove`,
which removes extra data points when the size of the data buffer exceeds capacity.
Besides, following the common practice for replay buffers, it is never :func:`clear` ed.
"""
def remove(self):
"""
when size exceeds capacity, removes extra data points
:return:
"""
raise NotImplementedError()

View File

@ -8,18 +8,21 @@ ACTION = 1
REWARD = 2
DONE = 3
# TODO: valid data points could be less than `nstep` timesteps. Check priority replay paper!
class VanillaReplayBuffer(ReplayBufferBase):
"""
vanilla replay buffer as used in (Mnih, et al., 2015).
Frames are always continuous in temporal order. They are only removed from the beginning. This continuity
in `self.data` could be exploited, but only in vanilla replay buffer.
Class for vanilla replay buffer as used in (Mnih et al., 2015).
Frames are always continuous in temporal order. They are only removed from the beginning
and added at the tail.
This continuity in `self.data` could be exploited only in vanilla replay buffer.
:param capacity: An int. The capacity of the buffer.
:param nstep: An int defaulting to 1. The number of timesteps to look ahead for temporal difference computation.
Only continuous data pieces longer than this number or already terminated ones are
considered valid data points.
"""
def __init__(self, capacity, nstep=1):
"""
:param capacity: int. capacity of the buffer.
:param nstep: int. number of timesteps to lookahead for temporal difference
"""
assert capacity > 0
self.capacity = int(capacity)
self.nstep = nstep
@ -34,8 +37,9 @@ class VanillaReplayBuffer(ReplayBufferBase):
def add(self, frame):
"""
add one frame to the buffer.
:param frame: tuple, (observation, action, reward, done_flag).
Adds one frame of data to the buffer.
:param frame: A tuple of (observation, action, reward, done_flag).
"""
self.data[-1].append(frame)
@ -65,17 +69,17 @@ class VanillaReplayBuffer(ReplayBufferBase):
def remove(self):
"""
remove data until `self.size` <= `self.capacity`
Removes data from the buffer until ``self.size <= self.capacity``.
"""
if self.size:
while self.size > self.capacity:
self.remove_oldest()
self._remove_oldest()
else:
logging.warning('Attempting to remove from empty buffer!')
def remove_oldest(self):
def _remove_oldest(self):
"""
remove the oldest data point, in this case, just the oldest frame. Empty episodes are also removed
Removes the oldest data point, in this case, just the oldest frame. Empty episodes are also removed
if they result from the removal.
"""
self.index[0].pop() # note that all index of frames in the first episode are shifted forward by 1
@ -98,12 +102,14 @@ class VanillaReplayBuffer(ReplayBufferBase):
def sample(self, batch_size):
"""
uniform random sampling on `self.index`. For simplicity, we do random sampling with replacement
for now with time O(`batch_size`). Fastest sampling without replacement seems to have to be of time
O(`batch_size` * log(num_episodes)).
:param batch_size: int.
:return: sampled index, same structure as `self.index`. Episodes without sampled data points
correspond to empty sub-lists.
Performs uniform random sampling on ``self.index``. For simplicity, we do random sampling with replacement
for now, which takes O(``batch_size``) time. The fastest sampling without replacement appears to require
O(``batch_size`` * log(num_episodes)) time.
:param batch_size: An int. The size of the minibatch.
:return: A list of lists of the sampled indexes. Episodes without sampled data points
correspond to empty sub-lists.
"""
prob_episode = np.array(self.index_lengths) * 1. / self.size
num_episodes = len(self.index)
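The removal policy can be pictured directly on the two-level layout: frames are dropped from the front of the oldest episode until the total size fits the capacity (standalone sketch, not the class itself):

    data = [[('obs', 'act', 1.0, False)] * 3, [('obs', 'act', 0.0, False)] * 4]   # two episodes, 7 frames
    capacity = 5

    while sum(len(episode) for episode in data) > capacity:
        data[0].pop(0)                        # drop the oldest frame
        if not data[0]:                       # drop an episode once it becomes empty
            data.pop(0)
    print([len(episode) for episode in data]) # [1, 4] -> 5 frames in total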

View File

@ -5,18 +5,24 @@ import itertools
from .data_buffer.replay_buffer_base import ReplayBufferBase
from .data_buffer.batch_set import BatchSet
from .utils import internal_key_match
from ..core.policy.deterministic import Deterministic
class DataCollector(object):
"""
A utility class to manage the data flow during the interaction between the policy and the environment.
It stores data into ``data_buffer``, processes the reward signals and returns the feed_dict for tf graph running.
It stores data into ``data_buffer``, processes the reward signals and returns the feed_dict for
tf graph running.
:param env:
:param policy:
:param data_buffer:
:param process_functions:
:param managed_networks:
:param env: An environment.
:param policy: A :class:`tianshou.core.policy`.
:param data_buffer: A :class:`tianshou.data.data_buffer`.
:param process_functions: A list of callables in :mod:`tianshou.data.advantage_estimation`
to process rewards.
:param managed_networks: A list of networks of :class:`tianshou.core.policy` and/or
:class:`tianshou.core.value_function`. The networks you want this class to manage. This class
will automatically generate the feed_dict for all the placeholders in the ``managed_placeholders``
of all networks in this list.
"""
def __init__(self, env, policy, data_buffer, process_functions, managed_networks):
self.env = env
@ -42,6 +48,23 @@ class DataCollector(object):
self.step_count_this_episode = 0
def collect(self, num_timesteps=0, num_episodes=0, my_feed_dict={}, auto_clear=True, episode_cutoff=None):
"""
Collect data in the environment using ``self.policy``.
:param num_timesteps: An int specifying the number of timesteps to act. It defaults to 0 and either
``num_timesteps`` or ``num_episodes`` could be set but not both.
:param num_episodes: An int specifying the number of episodes to act. It defaults to 0 and either
``num_timesteps`` or ``num_episodes`` could be set but not both.
:param my_feed_dict: Optional. A dict defaulting to empty.
Specifies placeholders such as dropout and batch_norm except observation and action.
:param auto_clear: Optional. A bool defaulting to ``True``. If ``True``, this method clears
``self.data_buffer`` when it is an instance of
:class:`tianshou.data.data_buffer.BatchSet` and does nothing otherwise.
If set to ``False``, this auto-clearing behavior is disabled.
:param episode_cutoff: Optional. An int. The maximum number of timesteps in one episode. This is
useful when the environment has no terminal states or a single episode could be prohibitively long.
If set, all episodes are forced to stop after this number of timesteps.
"""
assert sum([num_timesteps > 0, num_episodes > 0]) == 1,\
"One and only one collection number specification permitted!"
@ -89,7 +112,20 @@ class DataCollector(object):
for processor in self.process_functions:
self.data.update(processor(self.data_buffer))
def next_batch(self, batch_size, standardize_advantage=None):
return
def next_batch(self, batch_size, standardize_advantage=True):
"""
Constructs and returns the feed_dict of data to be used with ``sess.run``.
:param batch_size: An int. The size of one minibatch.
:param standardize_advantage: Optional. A bool defaulting to ``True``.
If ``True``, this method standardizes advantages whenever an advantage is required by the networks.
If ``False``, this method will never standardize advantages.
:return: A dict in the format of conventional feed_dict in tf, with keys the placeholders and
values the numpy arrays.
"""
sampled_index = self.data_buffer.sample(batch_size)
if self.process_mode == 'sample':
for processor in self.process_functions:
@ -128,8 +164,7 @@ class DataCollector(object):
else:
raise TypeError('Placeholder {} has no value to feed!'.format(str(placeholder.name)))
auto_standardize = (standardize_advantage is None) and self.require_advantage
if standardize_advantage or auto_standardize:
if standardize_advantage:
if self.require_advantage:
advantage_value = feed_dict[self.required_placeholders['advantage']]
advantage_mean = np.mean(advantage_value)
@ -140,10 +175,21 @@ class DataCollector(object):
return feed_dict
def denoise_action(self, feed_dict):
def denoise_action(self, feed_dict, my_feed_dict={}):
"""
Recompute the actions of deterministic policies without exploration noise, hence denoising.
It modifies ``feed_dict`` **in place** and has no return value.
This is useful in, e.g., DDPG since the stored action in ``self.data_buffer`` is the sampled
action with additional exploration noise.
:param feed_dict: A dict. It has to be the dict returned by :func:`next_batch` by this class.
:param my_feed_dict: Optional. A dict defaulting to empty.
Specifies placeholders such as dropout and batch_norm except observation and action.
"""
assert isinstance(self.policy, Deterministic), 'denoise_action() could only be called ' \
'with deterministic policies'
observation = feed_dict[self.required_placeholders['observation']]
action_mean = self.policy.eval_action(observation)
action_mean = self.policy.eval_action(observation, my_feed_dict)
feed_dict[self.required_placeholders['action']] = action_mean
return
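Putting the pieces together, a hypothetical training-loop fragment based only on the signatures in this commit (the module paths and the surrounding ``env``, ``pi``, ``critic``, ``sess`` and ``train_op`` are assumed to be constructed elsewhere and are not verified against the repo):

    # hypothetical usage sketch; names follow the docstrings above
    data_buffer = BatchSet(nstep=1)
    data_collector = DataCollector(env=env, policy=pi, data_buffer=data_buffer,
                                   process_functions=[nstep_return(n=1, value_function=critic)],
                                   managed_networks=[pi, critic])

    for iteration in range(100):
        data_collector.collect(num_episodes=8)                    # run the current policy in env
        for _ in range(10):
            feed_dict = data_collector.next_batch(batch_size=64)  # sample minibatch, build feed_dict
            sess.run(train_op, feed_dict=feed_dict)               # one optimization step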

View File

@ -5,10 +5,31 @@ import logging
import numpy as np
def test_policy_in_env(policy, env, num_timesteps=0, num_episodes=0, discount_factor=0.99, seed=0, episode_cutoff=None):
def test_policy_in_env(policy, env, num_timesteps=0, num_episodes=0,
discount_factor=0.99, seed=0, episode_cutoff=None):
"""
Tests the policy in the environment and records and prints out the performance. This is useful when the policy
is trained with off-policy algorithms and thus the rewards in the data buffer do not reflect the
performance of the current policy.
:param policy: A :class:`tianshou.core.policy`. The current policy being optimized.
:param env: An environment.
:param num_timesteps: An int specifying the number of timesteps to test the policy.
It defaults to 0 and either
``num_timesteps`` or ``num_episodes`` could be set but not both.
:param num_episodes: An int specifying the number of episodes to test the policy.
It defaults to 0 and either
``num_timesteps`` or ``num_episodes`` could be set but not both.
:param discount_factor: Optional. A float in range :math:`[0, 1]` defaulting to 0.99. The discount
factor to compute discounted returns.
:param seed: A non-negative int. The seed for the environment, set via ``env.seed(seed)``.
:param episode_cutoff: Optional. An int. The maximum number of timesteps in one episode. This is
useful when the environment has no terminal states or a single episode could be prohibitively long.
If set, all episodes are forced to stop after this number of timesteps.
"""
assert sum([num_episodes > 0, num_timesteps > 0]) == 1, \
'One and only one collection number specification permitted!'
assert seed >= 0
# make another env as the original is for training data collection
env_id = env.spec.id
@ -82,4 +103,4 @@ def test_policy_in_env(policy, env, num_timesteps=0, num_episodes=0, discount_fa
print('Mean undiscounted return: {}'.format(mean_undiscounted_return))
# clear scene
env_.close()
env_.close()
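Typical use is a periodic evaluation call during off-policy training, e.g. (hypothetical fragment, assuming ``policy`` and ``env`` already exist):

    # evaluate the current policy every few training iterations
    test_policy_in_env(policy, env, num_episodes=10, episode_cutoff=1000)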