seealso and change policy dir structure

Trinkle23897 2020-04-09 21:36:53 +08:00
parent 6da80e045a
commit 19f2cce294
12 changed files with 102 additions and 32 deletions

View File

@@ -15,10 +15,12 @@ class Batch(object):
>>> data.b
[5, 5]
>>> data.b = np.array([3, 4, 5])
>>> len(data.b)
3
>>> data.b[-1]
5
>>> print(data)
Batch(
a: 4,
b: [3 4 5],
c: 2312312,
)
In short, you can define a :class:`Batch` with any key-value pair. The
current implementation of Tianshou typically uses 6 keys in

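For reference, the key-value behaviour the doctest above exercises can be reproduced as a plain script; a minimal sketch using the same keys (nothing here goes beyond what the doctest already shows)::

    import numpy as np
    from tianshou.data import Batch

    # keys a, b and c mirror the doctest; Batch accepts arbitrary key-value pairs
    data = Batch(a=4, b=[5, 5], c='2312312')
    data.b = np.array([3, 4, 5])    # attributes can be reassigned in place
    assert len(data.b) == 3 and data.b[-1] == 5
    print(data)                     # pretty-printed, as in the doctest output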
View File

@@ -9,6 +9,7 @@ class ReplayBuffer(object):
``numpy.ndarray``. Here is the usage:
::
>>> import numpy as np
>>> from tianshou.data import ReplayBuffer
>>> buf = ReplayBuffer(size=20)
>>> for i in range(3):
@@ -48,10 +49,15 @@ class ReplayBuffer(object):
>>> for i in range(16):
... done = i % 5 == 0
... buf.add(obs=i, act=i, rew=i, done=done, obs_next=0, info={})
>>> print(buf.obs)
[ 9. 10. 11. 12. 13. 14. 15. 7. 8.]
>>> print(buf.done)
[0. 1. 0. 0. 0. 0. 1. 0. 0.]
>>> print(buf)
ReplayBuffer(
obs: [ 9. 10. 11. 12. 13. 14. 15. 7. 8.],
act: [ 9. 10. 11. 12. 13. 14. 15. 7. 8.],
rew: [ 9. 10. 11. 12. 13. 14. 15. 7. 8.],
done: [0. 1. 0. 0. 0. 0. 1. 0. 0.],
obs_next: [0. 0. 0. 0. 0. 0. 0. 0. 0.],
info: [{} {} {} {} {} {} {} {} {}],
)
>>> index = np.arange(len(buf))
>>> print(buf.get_stack(index, 'obs'))
[[ 7. 7. 8. 9.]
@@ -65,7 +71,7 @@ class ReplayBuffer(object):
[ 7. 7. 7. 8.]]
>>> # here is another way to get the stacked data
>>> # (stack only for obs and obs_next)
>>> sum(sum(buf.get_stack(index, 'obs') - buf[index].obs))
>>> abs(buf.get_stack(index, 'obs') - buf[index].obs).sum().sum()
0.0
"""
@@ -200,6 +206,11 @@ class ListReplayBuffer(ReplayBuffer):
"""The function of :class:`~tianshou.data.ListReplayBuffer` is almost the
same as :class:`~tianshou.data.ReplayBuffer`. The only difference is that
:class:`~tianshou.data.ListReplayBuffer` is based on ``list``.
.. seealso::
Please refer to :class:`~tianshou.data.ReplayBuffer` for more
detailed explanation.
"""
def __init__(self):

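The frame-stacking doctest above, including the rewritten equivalence check, can also be run as a script. A hedged sketch, assuming the buffer in that example was created with ``size=9`` and ``stack_num=4`` (consistent with the 9-element, 4-frame output shown, but not visible in the hunk)::

    import numpy as np
    from tianshou.data import ReplayBuffer

    buf = ReplayBuffer(size=9, stack_num=4)   # assumed constructor arguments
    for i in range(16):
        done = i % 5 == 0
        buf.add(obs=i, act=i, rew=i, done=done, obs_next=0, info={})

    index = np.arange(len(buf))
    stacked = buf.get_stack(index, 'obs')     # frame-stacked observations
    # the equivalence the rewritten doctest line asserts
    assert abs(stacked - buf[index].obs).sum() == 0.0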
View File

@@ -94,8 +94,12 @@ class BaseVectorEnv(ABC, gym.Wrapper):
class VectorEnv(BaseVectorEnv):
"""Dummy vectorized environment wrapper, implemented in for-loop. The usage
is in :class:`~tianshou.env.BaseVectorEnv`.
"""Dummy vectorized environment wrapper, implemented in for-loop.
.. seealso::
Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
explanation.
"""
def __init__(self, env_fns):
@@ -170,8 +174,12 @@ def worker(parent, p, env_fn_wrapper):
class SubprocVectorEnv(BaseVectorEnv):
"""Vectorized environment wrapper based on subprocess. The usage is in
:class:`~tianshou.env.BaseVectorEnv`.
"""Vectorized environment wrapper based on subprocess.
.. seealso::
Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
explanation.
"""
def __init__(self, env_fns):
@@ -247,8 +255,12 @@ class RayVectorEnv(BaseVectorEnv):
"""Vectorized environment wrapper based on
`ray <https://github.com/ray-project/ray>`_. However, according to our
test, it is about two times slower than
:class:`~tianshou.env.SubprocVectorEnv`. The usage is in
:class:`~tianshou.env.BaseVectorEnv`.
:class:`~tianshou.env.SubprocVectorEnv`.
.. seealso::
Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
explanation.
"""
def __init__(self, env_fns):

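All three wrappers expose the interface documented on :class:`~tianshou.env.BaseVectorEnv`, so they are interchangeable. A hedged sketch of that shared usage (the CartPole environment and the worker count are illustrative, not taken from the commit)::

    import gym
    import numpy as np
    from tianshou.env import SubprocVectorEnv

    # eight independent copies of the same environment
    env_fns = [lambda: gym.make('CartPole-v0') for _ in range(8)]

    venv = SubprocVectorEnv(env_fns)   # or VectorEnv / RayVectorEnv
    obs = venv.reset()                 # one observation row per worker
    obs, rew, done, info = venv.step(np.zeros(8, dtype=int))  # one action per env
    venv.close()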
View File

@@ -1,11 +1,11 @@
from tianshou.policy.base import BasePolicy
from tianshou.policy.dqn import DQNPolicy
from tianshou.policy.pg import PGPolicy
from tianshou.policy.a2c import A2CPolicy
from tianshou.policy.ddpg import DDPGPolicy
from tianshou.policy.ppo import PPOPolicy
from tianshou.policy.td3 import TD3Policy
from tianshou.policy.sac import SACPolicy
from tianshou.policy.modelfree.dqn import DQNPolicy
from tianshou.policy.modelfree.pg import PGPolicy
from tianshou.policy.modelfree.a2c import A2CPolicy
from tianshou.policy.modelfree.ddpg import DDPGPolicy
from tianshou.policy.modelfree.ppo import PPOPolicy
from tianshou.policy.modelfree.td3 import TD3Policy
from tianshou.policy.modelfree.sac import SACPolicy
__all__ = [
'BasePolicy',

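The move into a ``modelfree`` subdirectory is internal: because this ``__init__`` re-exports every class, the public import path is unchanged. A quick sanity check (not part of the commit)::

    # unchanged for downstream code
    from tianshou.policy import (DQNPolicy, PGPolicy, A2CPolicy, DDPGPolicy,
                                 PPOPolicy, TD3Policy, SACPolicy)

    # only imports that reached into the old flat layout need updating:
    #   from tianshou.policy.dqn import DQNPolicy            (old)
    #   from tianshou.policy.modelfree.dqn import DQNPolicy  (new)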
View File

View File

@@ -21,6 +21,11 @@ class A2CPolicy(PGPolicy):
:param float ent_coef: weight for entropy loss, defaults to 0.01.
:param float max_grad_norm: clipping gradients in back propagation,
defaults to ``None``.
.. seealso::
Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
explanation.
"""
def __init__(self, actor, critic, optim,
@@ -44,8 +49,10 @@ class A2CPolicy(PGPolicy):
* ``dist`` the action distribution.
* ``state`` the hidden state.
More information can be found at
:meth:`~tianshou.policy.BasePolicy.__call__`.
.. seealso::
Please refer to :meth:`~tianshou.policy.BasePolicy.__call__` for
more detailed explanation.
"""
logits, h = self.actor(batch.obs, state=state, info=batch.info)
if isinstance(logits, tuple):

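The on-policy ``__call__`` methods touched in this commit (A2C here; the PG and PPO hunks below are analogous) return a :class:`~tianshou.data.Batch` carrying the keys their docstrings list. A hedged sketch of consuming that result, where ``policy`` and ``batch`` are placeholders for a constructed policy and collected data::

    result = policy(batch)   # invokes forward / __call__

    dist = result.dist       # the action distribution
    act = result.act         # action sampled from that distribution
    state = result.state     # hidden state (None for feed-forward actors)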
View File

@@ -28,6 +28,11 @@ class DDPGPolicy(BasePolicy):
defaults to ``False``.
:param bool ignore_done: ignore the done flag while training the policy,
defaults to ``False``.
.. seealso::
Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
explanation.
"""
def __init__(self, actor, actor_optim, critic, critic_optim,
@@ -104,8 +109,10 @@ class DDPGPolicy(BasePolicy):
* ``act`` the action.
* ``state`` the hidden state.
More information can be found at
:meth:`~tianshou.policy.BasePolicy.__call__`.
.. seealso::
Please refer to :meth:`~tianshou.policy.BasePolicy.__call__` for
more detailed explanation.
"""
model = getattr(self, model)
obs = getattr(batch, input)

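The ``getattr``-based lines at the end of this hunk let a single ``__call__`` evaluate either the online or the target actor, on either ``obs`` or ``obs_next``. A sketch of how that is typically used; the keyword names ``model`` and ``input`` follow the snippet, while the ``actor_old`` attribute name for the target network is an assumption::

    # default: online actor evaluated on batch.obs
    result = policy(batch)

    # target-network variant, e.g. when building the TD target
    target = policy(batch, model='actor_old', input='obs_next')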
View File

@@ -18,6 +18,11 @@ class DQNPolicy(BasePolicy):
ahead.
:param int target_update_freq: the target network update frequency (``0``
if you do not use the target network).
.. seealso::
Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
explanation.
"""
def __init__(self, model, optim, discount_factor=0.99,
@@ -106,8 +111,10 @@ class DQNPolicy(BasePolicy):
* ``logits`` the network's raw output.
* ``state`` the hidden state.
More information can be found at
:meth:`~tianshou.policy.BasePolicy.__call__`.
.. seealso::
Please refer to :meth:`~tianshou.policy.BasePolicy.__call__` for
more detailed explanation.
"""
model = getattr(self, model)
obs = getattr(batch, input)

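For the value-based case the same pattern yields Q-values plus a chosen action. A hedged sketch, where ``model``, ``optim`` and ``batch`` stand for a user-supplied Q-network, its optimizer and collected data, and the hyper-parameter values are illustrative::

    from tianshou.policy import DQNPolicy

    policy = DQNPolicy(model, optim, discount_factor=0.99,
                       target_update_freq=320)

    result = policy(batch)
    q_values = result.logits   # the network's raw output
    action = result.act        # greedy (or epsilon-greedy) action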
View File

@@ -13,6 +13,11 @@ class PGPolicy(BasePolicy):
:param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
:param torch.distributions.Distribution dist_fn: for computing the action.
:param float discount_factor: in [0, 1].
.. seealso::
Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
explanation.
"""
def __init__(self, model, optim, dist_fn=torch.distributions.Categorical,
@@ -50,8 +55,10 @@ class PGPolicy(BasePolicy):
* ``dist`` the action distribution.
* ``state`` the hidden state.
More information can be found at
:meth:`~tianshou.policy.BasePolicy.__call__`.
.. seealso::
Please refer to :meth:`~tianshou.policy.BasePolicy.__call__` for
more detailed explanation.
"""
logits, h = self.model(batch.obs, state=state, info=batch.info)
if isinstance(logits, tuple):

View File

@@ -26,6 +26,11 @@ class PPOPolicy(PGPolicy):
:param float ent_coef: weight for entropy loss, defaults to 0.01.
:param action_range: the action range (minimum, maximum).
:type action_range: [float, float]
.. seealso::
Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
explanation.
"""
def __init__(self, actor, critic, optim, dist_fn,
@@ -70,8 +75,10 @@ class PPOPolicy(PGPolicy):
* ``dist`` the action distribution.
* ``state`` the hidden state.
More information can be found at
:meth:`~tianshou.policy.BasePolicy.__call__`.
.. seealso::
Please refer to :meth:`~tianshou.policy.BasePolicy.__call__` for
more detailed explanation.
"""
model = getattr(self, model)
logits, h = model(batch.obs, state=state, info=batch.info)

View File

@@ -33,6 +33,11 @@ class SACPolicy(DDPGPolicy):
defaults to ``False``.
:param bool ignore_done: ignore the done flag while training the policy,
defaults to ``False``.
.. seealso::
Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
explanation.
"""
def __init__(self, actor, actor_optim, critic1, critic1_optim,

View File

@@ -37,6 +37,11 @@ class TD3Policy(DDPGPolicy):
defaults to ``False``.
:param bool ignore_done: ignore the done flag while training the policy,
defaults to ``False``.
.. seealso::
Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
explanation.
"""
def __init__(self, actor, actor_optim, critic1, critic1_optim,
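The DDPG-family constructors touched here (DDPG, TD3, SAC) all follow the same actor/critic wiring. A hedged sketch, where the networks are user-defined ``torch.nn.Module`` instances, the optimizer choice and hyper-parameter values are illustrative, and the continuation of the signature past the truncation shown above (``critic2``, ``critic2_optim``, ``action_range``) is an assumption::

    import torch
    from tianshou.policy import TD3Policy

    actor_optim = torch.optim.Adam(actor.parameters(), lr=1e-3)
    critic1_optim = torch.optim.Adam(critic1.parameters(), lr=1e-3)
    critic2_optim = torch.optim.Adam(critic2.parameters(), lr=1e-3)

    policy = TD3Policy(actor, actor_optim, critic1, critic1_optim,
                       critic2, critic2_optim, action_range=[-1.0, 1.0])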