seealso and change policy dir structure

Trinkle23897 2020-04-09 21:36:53 +08:00
parent 6da80e045a
commit 19f2cce294
12 changed files with 102 additions and 32 deletions


@@ -15,10 +15,12 @@ class Batch(object):
         >>> data.b
         [5, 5]
         >>> data.b = np.array([3, 4, 5])
-        >>> len(data.b)
-        3
-        >>> data.b[-1]
-        5
+        >>> print(data)
+        Batch(
+            a: 4,
+            b: [3 4 5],
+            c: 2312312,
+        )

     In short, you can define a :class:`Batch` with any key-value pair. The
     current implementation of Tianshou typically use 6 keys in
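For reference, a minimal sketch of the pretty-printed representation that the updated doctest exercises; the key-value pairs mirror the docstring example, and the printed layout is taken from the hunk above:

    import numpy as np
    from tianshou.data import Batch

    data = Batch(a=4, b=[5, 5], c='2312312')   # any key-value pairs are allowed
    data.b = np.array([3, 4, 5])               # values can be reassigned freely
    print(data)
    # Batch(
    #     a: 4,
    #     b: [3 4 5],
    #     c: 2312312,
    # )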


@@ -9,6 +9,7 @@ class ReplayBuffer(object):
     ``numpy.ndarray``. Here is the usage:
     ::

+        >>> import numpy as np
         >>> from tianshou.data import ReplayBuffer
         >>> buf = ReplayBuffer(size=20)
         >>> for i in range(3):
@@ -48,10 +49,15 @@ class ReplayBuffer(object):
         >>> for i in range(16):
         ...     done = i % 5 == 0
         ...     buf.add(obs=i, act=i, rew=i, done=done, obs_next=0, info={})
-        >>> print(buf.obs)
-        [ 9. 10. 11. 12. 13. 14. 15. 7. 8.]
-        >>> print(buf.done)
-        [0. 1. 0. 0. 0. 0. 1. 0. 0.]
+        >>> print(buf)
+        ReplayBuffer(
+            obs: [ 9. 10. 11. 12. 13. 14. 15. 7. 8.],
+            act: [ 9. 10. 11. 12. 13. 14. 15. 7. 8.],
+            rew: [ 9. 10. 11. 12. 13. 14. 15. 7. 8.],
+            done: [0. 1. 0. 0. 0. 0. 1. 0. 0.],
+            obs_next: [0. 0. 0. 0. 0. 0. 0. 0. 0.],
+            info: [{} {} {} {} {} {} {} {} {}],
+        )
         >>> index = np.arange(len(buf))
         >>> print(buf.get_stack(index, 'obs'))
         [[ 7. 7. 8. 9.]
@@ -65,7 +71,7 @@ class ReplayBuffer(object):
         [ 7. 7. 7. 8.]]
         >>> # here is another way to get the stacked data
         >>> # (stack only for obs and obs_next)
-        >>> sum(sum(buf.get_stack(index, 'obs') - buf[index].obs))
+        >>> abs(buf.get_stack(index, 'obs') - buf[index].obs).sum().sum()
         0.0
     """
@@ -200,6 +206,11 @@ class ListReplayBuffer(ReplayBuffer):
     """The function of :class:`~tianshou.data.ListReplayBuffer` is almost the
     same as :class:`~tianshou.data.ReplayBuffer`. The only difference is that
     :class:`~tianshou.data.ListReplayBuffer` is based on ``list``.
+
+    .. seealso::
+
+        Please refer to :class:`~tianshou.data.ListReplayBuffer` for more
+        detailed explanation.
     """

     def __init__(self):
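For reference, a minimal standalone sketch of the ring-buffer behaviour shown in the doctest above; it relies only on the ``add``/``len``/indexing API visible in this hunk:

    import numpy as np
    from tianshou.data import ReplayBuffer

    buf = ReplayBuffer(size=9)               # fixed capacity; old entries get overwritten
    for i in range(16):
        buf.add(obs=i, act=i, rew=i, done=i % 5 == 0, obs_next=0, info={})

    print(len(buf))                          # 9 -- never exceeds the buffer size
    print(buf.obs)                           # the 9 most recent observations
    batch = buf[np.arange(len(buf))]         # integer-array indexing returns a Batch
    print(batch.rew)                         # same data, accessed through the Batch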


@@ -94,8 +94,12 @@ class BaseVectorEnv(ABC, gym.Wrapper):
 class VectorEnv(BaseVectorEnv):
-    """Dummy vectorized environment wrapper, implemented in for-loop. The usage
-    is in :class:`~tianshou.env.BaseVectorEnv`.
+    """Dummy vectorized environment wrapper, implemented in for-loop.
+
+    .. seealso::
+
+        Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
+        explanation.
     """

     def __init__(self, env_fns):
@@ -170,8 +174,12 @@ def worker(parent, p, env_fn_wrapper):
 class SubprocVectorEnv(BaseVectorEnv):
-    """Vectorized environment wrapper based on subprocess. The usage is in
-    :class:`~tianshou.env.BaseVectorEnv`.
+    """Vectorized environment wrapper based on subprocess.
+
+    .. seealso::
+
+        Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
+        explanation.
     """

     def __init__(self, env_fns):
@@ -247,8 +255,12 @@ class RayVectorEnv(BaseVectorEnv):
     """Vectorized environment wrapper based on
     `ray <https://github.com/ray-project/ray>`_. However, according to our
     test, it is about two times slower than
-    :class:`~tianshou.env.SubprocVectorEnv`. The usage is in
-    :class:`~tianshou.env.BaseVectorEnv`.
+    :class:`~tianshou.env.SubprocVectorEnv`.
+
+    .. seealso::
+
+        Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
+        explanation.
     """

     def __init__(self, env_fns):
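For reference, a minimal usage sketch shared by all three wrappers; it assumes the ``reset``/``step`` interface documented in :class:`~tianshou.env.BaseVectorEnv`, the ``env_fns`` constructor argument shown above, and the standard gym ``CartPole-v0`` environment:

    import gym
    import numpy as np
    from tianshou.env import SubprocVectorEnv   # VectorEnv / RayVectorEnv work the same way

    env_num = 4
    envs = SubprocVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(env_num)])
    obs = envs.reset()                                # one row of observations per worker
    acts = np.random.randint(0, 2, size=env_num)      # CartPole has two discrete actions
    obs, rew, done, info = envs.step(acts)            # each result has one entry per env
    envs.close()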


@@ -1,11 +1,11 @@
 from tianshou.policy.base import BasePolicy
-from tianshou.policy.dqn import DQNPolicy
-from tianshou.policy.pg import PGPolicy
-from tianshou.policy.a2c import A2CPolicy
-from tianshou.policy.ddpg import DDPGPolicy
-from tianshou.policy.ppo import PPOPolicy
-from tianshou.policy.td3 import TD3Policy
-from tianshou.policy.sac import SACPolicy
+from tianshou.policy.modelfree.dqn import DQNPolicy
+from tianshou.policy.modelfree.pg import PGPolicy
+from tianshou.policy.modelfree.a2c import A2CPolicy
+from tianshou.policy.modelfree.ddpg import DDPGPolicy
+from tianshou.policy.modelfree.ppo import PPOPolicy
+from tianshou.policy.modelfree.td3 import TD3Policy
+from tianshou.policy.modelfree.sac import SACPolicy

 __all__ = [
     'BasePolicy',
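Because the package ``__init__`` keeps re-exporting every class, user-facing imports are unaffected by the move into ``tianshou/policy/modelfree/``; only code that reached into the internal modules needs updating:

    # unchanged -- the public entry point still works:
    from tianshou.policy import DQNPolicy, PPOPolicy

    # only deep imports of the internal modules change:
    # before: from tianshou.policy.dqn import DQNPolicy
    # after:  from tianshou.policy.modelfree.dqn import DQNPolicy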


@@ -21,6 +21,11 @@ class A2CPolicy(PGPolicy):
     :param float ent_coef: weight for entropy loss, defaults to 0.01.
     :param float max_grad_norm: clipping gradients in back propagation,
         defaults to ``None``.
+
+    .. seealso::
+
+        Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
+        explanation.
     """

     def __init__(self, actor, critic, optim,
@@ -44,8 +49,10 @@ class A2CPolicy(PGPolicy):
             * ``dist`` the action distribution.
             * ``state`` the hidden state.

-        More information can be found at
-        :meth:`~tianshou.policy.BasePolicy.__call__`.
+        .. seealso::
+
+            Please refer to :meth:`~tianshou.policy.BasePolicy.__call__` for
+            more detailed explanation.
         """
         logits, h = self.actor(batch.obs, state=state, info=batch.info)
         if isinstance(logits, tuple):
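The ``isinstance(logits, tuple)`` branch above exists because continuous actors return a ``(mu, sigma)`` tuple while discrete actors return a single tensor. A small illustrative helper (not tianshou code) showing the same dispatch:

    import torch

    def make_dist(logits, dist_fn):
        # tuples are unpacked into dist_fn, single outputs are passed as-is
        return dist_fn(*logits) if isinstance(logits, tuple) else dist_fn(logits)

    # continuous actor output: (mu, sigma) -> Normal(mu, sigma)
    dist = make_dist((torch.zeros(8, 2), torch.ones(8, 2)), torch.distributions.Normal)
    # discrete actor output: a probability vector -> Categorical(probs)
    dist = make_dist(torch.softmax(torch.randn(8, 3), dim=-1), torch.distributions.Categorical)
    print(dist.sample().shape)   # torch.Size([8])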


@@ -28,6 +28,11 @@ class DDPGPolicy(BasePolicy):
         defaults to ``False``.
     :param bool ignore_done: ignore the done flag while training the policy,
         defaults to ``False``.
+
+    .. seealso::
+
+        Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
+        explanation.
     """

     def __init__(self, actor, actor_optim, critic, critic_optim,
@@ -104,8 +109,10 @@ class DDPGPolicy(BasePolicy):
             * ``act`` the action.
             * ``state`` the hidden state.

-        More information can be found at
-        :meth:`~tianshou.policy.BasePolicy.__call__`.
+        .. seealso::
+
+            Please refer to :meth:`~tianshou.policy.BasePolicy.__call__` for
+            more detailed explanation.
         """
         model = getattr(self, model)
         obs = getattr(batch, input)
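The two ``getattr`` calls above let the same ``__call__`` run any of the policy's networks on any field of the batch, selected by name, for instance a target actor on ``obs_next`` when building TD targets. A toy stand-in (not tianshou code; the ``actor_old`` name is illustrative) showing the pattern:

    import numpy as np
    import torch
    import torch.nn as nn
    from tianshou.data import Batch

    class TinyPolicy(nn.Module):
        def __init__(self):
            super().__init__()
            self.actor = nn.Linear(4, 2)       # online network
            self.actor_old = nn.Linear(4, 2)   # hypothetical frozen target twin

        def forward(self, batch, model='actor', input='obs'):
            net = getattr(self, model)         # pick a network by attribute name
            obs = getattr(batch, input)        # pick an observation field by name
            return net(torch.as_tensor(obs, dtype=torch.float32))

    batch = Batch(obs=np.random.randn(8, 4), obs_next=np.random.randn(8, 4))
    policy = TinyPolicy()
    online_act = policy(batch)                                       # actor on obs
    target_act = policy(batch, model='actor_old', input='obs_next')  # target on obs_next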


@@ -18,6 +18,11 @@ class DQNPolicy(BasePolicy):
         ahead.
     :param int target_update_freq: the target network update frequency (``0``
         if you do not use the target network).
+
+    .. seealso::
+
+        Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
+        explanation.
     """

     def __init__(self, model, optim, discount_factor=0.99,
@@ -106,8 +111,10 @@ class DQNPolicy(BasePolicy):
             * ``logits`` the network's raw output.
             * ``state`` the hidden state.

-        More information can be found at
-        :meth:`~tianshou.policy.BasePolicy.__call__`.
+        .. seealso::
+
+            Please refer to :meth:`~tianshou.policy.BasePolicy.__call__` for
+            more detailed explanation.
         """
         model = getattr(self, model)
         obs = getattr(batch, input)
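To make ``target_update_freq`` concrete, here is an illustrative sketch of the usual hard-update scheme (not the tianshou implementation): every ``target_update_freq`` gradient steps, the frozen target network is overwritten with the online network's weights.

    import torch.nn as nn

    online = nn.Linear(4, 2)
    target = nn.Linear(4, 2)
    target.load_state_dict(online.state_dict())   # start in sync
    target_update_freq = 320

    for step in range(1, 1001):
        # ... one gradient update on `online` would happen here ...
        if step % target_update_freq == 0:
            target.load_state_dict(online.state_dict())   # hard sync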


@@ -13,6 +13,11 @@ class PGPolicy(BasePolicy):
     :param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
     :param torch.distributions.Distribution dist_fn: for computing the action.
     :param float discount_factor: in [0, 1].
+
+    .. seealso::
+
+        Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
+        explanation.
     """

     def __init__(self, model, optim, dist_fn=torch.distributions.Categorical,
@@ -50,8 +55,10 @@ class PGPolicy(BasePolicy):
             * ``dist`` the action distribution.
             * ``state`` the hidden state.

-        More information can be found at
-        :meth:`~tianshou.policy.BasePolicy.__call__`.
+        .. seealso::
+
+            Please refer to :meth:`~tianshou.policy.BasePolicy.__call__` for
+            more detailed explanation.
         """
         logits, h = self.model(batch.obs, state=state, info=batch.info)
         if isinstance(logits, tuple):
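To make ``discount_factor`` concrete, here is an illustrative sketch (not the PGPolicy source) of the discounted returns-to-go that a policy-gradient update weights its log-probabilities with; episodes are cut at ``done`` flags:

    import numpy as np

    def discounted_returns(rew, done, gamma=0.99):
        returns = np.zeros_like(rew, dtype=float)
        running = 0.0
        for t in reversed(range(len(rew))):
            running = rew[t] + gamma * running * (1.0 - done[t])
            returns[t] = running
        return returns

    print(discounted_returns(np.array([1., 1., 1.]), np.array([0., 0., 1.])))
    # [2.9701 1.99   1.    ]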


@@ -26,6 +26,11 @@ class PPOPolicy(PGPolicy):
     :param float ent_coef: weight for entropy loss, defaults to 0.01.
     :param action_range: the action range (minimum, maximum).
     :type action_range: [float, float]
+
+    .. seealso::
+
+        Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
+        explanation.
     """

     def __init__(self, actor, critic, optim, dist_fn,
@@ -70,8 +75,10 @@ class PPOPolicy(PGPolicy):
             * ``dist`` the action distribution.
             * ``state`` the hidden state.

-        More information can be found at
-        :meth:`~tianshou.policy.BasePolicy.__call__`.
+        .. seealso::
+
+            Please refer to :meth:`~tianshou.policy.BasePolicy.__call__` for
+            more detailed explanation.
         """
         model = getattr(self, model)
         logits, h = model(batch.obs, state=state, info=batch.info)
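For context, PPO's defining piece is the clipped surrogate objective; the sketch below states its standard form (illustrative only, not a copy of ``PPOPolicy.learn``; the 0.2 clipping radius is the conventional default, not taken from this diff):

    import torch

    def ppo_clip_loss(log_prob_new, log_prob_old, adv, eps_clip=0.2):
        ratio = (log_prob_new - log_prob_old).exp()           # importance ratio
        surr1 = ratio * adv
        surr2 = ratio.clamp(1.0 - eps_clip, 1.0 + eps_clip) * adv
        return -torch.min(surr1, surr2).mean()                # maximize the clipped surrogate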


@@ -33,6 +33,11 @@ class SACPolicy(DDPGPolicy):
         defaults to ``False``.
     :param bool ignore_done: ignore the done flag while training the policy,
         defaults to ``False``.
+
+    .. seealso::
+
+        Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
+        explanation.
     """

     def __init__(self, actor, actor_optim, critic1, critic1_optim,
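Standard SAC trains two critics (the constructor line above is truncated after ``critic1_optim``) and bootstraps from a soft value target. An illustrative sketch of that target in its textbook form (not the tianshou source; ``alpha`` is the usual entropy temperature):

    import torch

    def soft_q_target(rew, done, q1_next, q2_next, log_prob_next, gamma=0.99, alpha=0.2):
        # pessimistic minimum of both target critics, minus the entropy term,
        # then a one-step Bellman backup
        q_next = torch.min(q1_next, q2_next) - alpha * log_prob_next
        return rew + gamma * (1.0 - done) * q_next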


@@ -37,6 +37,11 @@ class TD3Policy(DDPGPolicy):
         defaults to ``False``.
     :param bool ignore_done: ignore the done flag while training the policy,
         defaults to ``False``.
+
+    .. seealso::
+
+        Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
+        explanation.
     """

     def __init__(self, actor, actor_optim, critic1, critic1_optim,
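For context, TD3 differs from DDPG mainly through twin critics, delayed actor updates, and target policy smoothing; the sketch below shows the smoothing step in its standard form (illustrative only, not TD3Policy code; parameter names and defaults are the conventional ones, not taken from this diff):

    import torch

    def smoothed_target_action(actor_target, obs_next, policy_noise=0.2,
                               noise_clip=0.5, max_action=1.0):
        act = actor_target(obs_next)
        noise = (torch.randn_like(act) * policy_noise).clamp(-noise_clip, noise_clip)
        return (act + noise).clamp(-max_action, max_action)   # fed to the target critics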