seealso and change policy dir structure
parent 6da80e045a
commit 19f2cce294
@@ -15,10 +15,12 @@ class Batch(object):
>>> data.b
[5, 5]
>>> data.b = np.array([3, 4, 5])
>>> len(data.b)
3
>>> data.b[-1]
5
>>> print(data)
Batch(
a: 4,
b: [3 4 5],
c: 2312312,
)

In short, you can define a :class:`Batch` with any key-value pair. The
current implementation of Tianshou typically use 6 keys in
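A minimal sketch (not part of the commit) of how the Batch in the doctest above can be built; the exact value types (for example whether ``c`` is a string or an int) are assumptions inferred from the printed output::

    import numpy as np
    from tianshou.data import Batch

    data = Batch(a=4, b=[5, 5], c=2312312)   # any key-value pairs are allowed
    data.b = np.array([3, 4, 5])             # values can be reassigned later
    print(data)                              # prints the a/b/c fields as shown above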
@@ -9,6 +9,7 @@ class ReplayBuffer(object):
``numpy.ndarray``. Here is the usage:
::

>>> import numpy as np
>>> from tianshou.data import ReplayBuffer
>>> buf = ReplayBuffer(size=20)
>>> for i in range(3):
@@ -48,10 +49,15 @@ class ReplayBuffer(object):
>>> for i in range(16):
...     done = i % 5 == 0
...     buf.add(obs=i, act=i, rew=i, done=done, obs_next=0, info={})
>>> print(buf.obs)
[ 9. 10. 11. 12. 13. 14. 15. 7. 8.]
>>> print(buf.done)
[0. 1. 0. 0. 0. 0. 1. 0. 0.]
>>> print(buf)
ReplayBuffer(
obs: [ 9. 10. 11. 12. 13. 14. 15. 7. 8.],
act: [ 9. 10. 11. 12. 13. 14. 15. 7. 8.],
rew: [ 9. 10. 11. 12. 13. 14. 15. 7. 8.],
done: [0. 1. 0. 0. 0. 0. 1. 0. 0.],
obs_next: [0. 0. 0. 0. 0. 0. 0. 0. 0.],
info: [{} {} {} {} {} {} {} {} {}],
)
>>> index = np.arange(len(buf))
>>> print(buf.get_stack(index, 'obs'))
[[ 7. 7. 8. 9.]
@@ -65,7 +71,7 @@ class ReplayBuffer(object):
[ 7. 7. 7. 8.]]
>>> # here is another way to get the stacked data
>>> # (stack only for obs and obs_next)
- >>> sum(sum(buf.get_stack(index, 'obs') - buf[index].obs))
+ >>> abs(buf.get_stack(index, 'obs') - buf[index].obs).sum().sum()
0.0
"""
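A hedged sketch (not part of the diff) that reproduces the frame-stacking doctest above; the ``size=9`` and ``stack_num=4`` constructor arguments are assumptions inferred from the printed arrays, not stated in this hunk::

    import numpy as np
    from tianshou.data import ReplayBuffer

    buf = ReplayBuffer(size=9, stack_num=4)   # assumed arguments, see note above
    for i in range(16):
        done = i % 5 == 0
        buf.add(obs=i, act=i, rew=i, done=done, obs_next=0, info={})

    index = np.arange(len(buf))
    stacked = buf.get_stack(index, 'obs')     # one row of stacked frames per index
    # indexing the buffer stacks obs/obs_next the same way, hence the 0.0 above
    assert abs(stacked - buf[index].obs).sum() == 0.0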
@@ -200,6 +206,11 @@ class ListReplayBuffer(ReplayBuffer):
"""The function of :class:`~tianshou.data.ListReplayBuffer` is almost the
same as :class:`~tianshou.data.ReplayBuffer`. The only difference is that
:class:`~tianshou.data.ListReplayBuffer` is based on ``list``.

.. seealso::

Please refer to :class:`~tianshou.data.ListReplayBuffer` for more
detailed explanation.
"""

def __init__(self):
tianshou/env/vecenv.py (24 changed lines)
@@ -94,8 +94,12 @@ class BaseVectorEnv(ABC, gym.Wrapper):


class VectorEnv(BaseVectorEnv):
- """Dummy vectorized environment wrapper, implemented in for-loop. The usage
- is in :class:`~tianshou.env.BaseVectorEnv`.
+ """Dummy vectorized environment wrapper, implemented in for-loop.
+
+ .. seealso::
+
+ Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
+ explanation.
"""

def __init__(self, env_fns):
@@ -170,8 +174,12 @@ def worker(parent, p, env_fn_wrapper):


class SubprocVectorEnv(BaseVectorEnv):
- """Vectorized environment wrapper based on subprocess. The usage is in
- :class:`~tianshou.env.BaseVectorEnv`.
+ """Vectorized environment wrapper based on subprocess.
+
+ .. seealso::
+
+ Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
+ explanation.
"""

def __init__(self, env_fns):
@@ -247,8 +255,12 @@ class RayVectorEnv(BaseVectorEnv):
"""Vectorized environment wrapper based on
`ray <https://github.com/ray-project/ray>`_. However, according to our
test, it is about two times slower than
- :class:`~tianshou.env.SubprocVectorEnv`. The usage is in
- :class:`~tianshou.env.BaseVectorEnv`.
+ :class:`~tianshou.env.SubprocVectorEnv`.
+
+ .. seealso::
+
+ Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
+ explanation.
"""

def __init__(self, env_fns):
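A rough usage sketch (not part of the commit); it assumes a gym-style environment and that the wrappers follow the batched reset/step interface of :class:`~tianshou.env.BaseVectorEnv`, with ``env_fns`` being a list of environment factories as the constructors above suggest::

    import gym
    import numpy as np
    from tianshou.env import VectorEnv, SubprocVectorEnv

    env_fns = [lambda: gym.make('CartPole-v0') for _ in range(4)]
    envs = VectorEnv(env_fns)           # for-loop based dummy wrapper
    # envs = SubprocVectorEnv(env_fns)  # subprocess based, same interface
    obs = envs.reset()                  # batched observations, one per sub-env
    acts = np.zeros(4, dtype=int)       # a dummy action for each sub-env
    obs, rew, done, info = envs.step(acts)
    envs.close()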
@@ -1,11 +1,11 @@
from tianshou.policy.base import BasePolicy
- from tianshou.policy.dqn import DQNPolicy
- from tianshou.policy.pg import PGPolicy
- from tianshou.policy.a2c import A2CPolicy
- from tianshou.policy.ddpg import DDPGPolicy
- from tianshou.policy.ppo import PPOPolicy
- from tianshou.policy.td3 import TD3Policy
- from tianshou.policy.sac import SACPolicy
+ from tianshou.policy.modelfree.dqn import DQNPolicy
+ from tianshou.policy.modelfree.pg import PGPolicy
+ from tianshou.policy.modelfree.a2c import A2CPolicy
+ from tianshou.policy.modelfree.ddpg import DDPGPolicy
+ from tianshou.policy.modelfree.ppo import PPOPolicy
+ from tianshou.policy.modelfree.td3 import TD3Policy
+ from tianshou.policy.modelfree.sac import SACPolicy

__all__ = [
'BasePolicy',
tianshou/policy/modelfree/__init__.py (new empty file)
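The re-exports above mean the public import path is unchanged by the new directory structure; a short sketch (not part of the diff)::

    # user-facing imports keep working exactly as before
    from tianshou.policy import BasePolicy, DQNPolicy, PPOPolicy

    # the concrete implementations now live under tianshou/policy/modelfree/
    from tianshou.policy.modelfree.dqn import DQNPolicy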
@@ -21,6 +21,11 @@ class A2CPolicy(PGPolicy):
:param float ent_coef: weight for entropy loss, defaults to 0.01.
:param float max_grad_norm: clipping gradients in back propagation,
defaults to ``None``.

.. seealso::

Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
explanation.
"""

def __init__(self, actor, critic, optim,
@@ -44,8 +49,10 @@ class A2CPolicy(PGPolicy):
* ``dist`` the action distribution.
* ``state`` the hidden state.

- More information can be found at
- :meth:`~tianshou.policy.BasePolicy.__call__`.
+ .. seealso::
+
+ Please refer to :meth:`~tianshou.policy.BasePolicy.__call__` for
+ more detailed explanation.
"""
logits, h = self.actor(batch.obs, state=state, info=batch.info)
if isinstance(logits, tuple):
@@ -28,6 +28,11 @@ class DDPGPolicy(BasePolicy):
defaults to ``False``.
:param bool ignore_done: ignore the done flag while training the policy,
defaults to ``False``.

.. seealso::

Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
explanation.
"""

def __init__(self, actor, actor_optim, critic, critic_optim,
@@ -104,8 +109,10 @@ class DDPGPolicy(BasePolicy):
* ``act`` the action.
* ``state`` the hidden state.

- More information can be found at
- :meth:`~tianshou.policy.BasePolicy.__call__`.
+ .. seealso::
+
+ Please refer to :meth:`~tianshou.policy.BasePolicy.__call__` for
+ more detailed explanation.
"""
model = getattr(self, model)
obs = getattr(batch, input)
@@ -18,6 +18,11 @@ class DQNPolicy(BasePolicy):
ahead.
:param int target_update_freq: the target network update frequency (``0``
if you do not use the target network).

.. seealso::

Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
explanation.
"""

def __init__(self, model, optim, discount_factor=0.99,
@@ -106,8 +111,10 @@ class DQNPolicy(BasePolicy):
* ``logits`` the network's raw output.
* ``state`` the hidden state.

- More information can be found at
- :meth:`~tianshou.policy.BasePolicy.__call__`.
+ .. seealso::
+
+ Please refer to :meth:`~tianshou.policy.BasePolicy.__call__` for
+ more detailed explanation.
"""
model = getattr(self, model)
obs = getattr(batch, input)
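A hedged end-to-end sketch of the construct-then-call pattern these docstrings describe (not part of the diff): the small Q-network, its ``(obs, state, info) -> (logits, state)`` interface, and the ``set_eps`` call are assumptions made only for illustration::

    import numpy as np
    import torch
    from torch import nn
    from tianshou.data import Batch
    from tianshou.policy import DQNPolicy

    class Net(nn.Module):
        """Hypothetical Q-network: 4-dim observation -> 2 action values."""
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(4, 2)

        def forward(self, obs, state=None, info={}):
            obs = torch.as_tensor(obs, dtype=torch.float32)
            return self.fc(obs), state          # raw logits and hidden state

    net = Net()
    optim = torch.optim.Adam(net.parameters(), lr=1e-3)
    policy = DQNPolicy(net, optim, discount_factor=0.99,
                       target_update_freq=0)    # 0: no target network (see above)
    policy.set_eps(0.0)                         # assumed: disable exploration
    batch = Batch(obs=np.zeros((3, 4)), info={})
    result = policy(batch)                      # a Batch with the fields listed above
    print(result.logits, result.state)          # ``result.act`` assumed as well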
@@ -13,6 +13,11 @@ class PGPolicy(BasePolicy):
:param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
:param torch.distributions.Distribution dist_fn: for computing the action.
:param float discount_factor: in [0, 1].

.. seealso::

Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
explanation.
"""

def __init__(self, model, optim, dist_fn=torch.distributions.Categorical,
@@ -50,8 +55,10 @@ class PGPolicy(BasePolicy):
* ``dist`` the action distribution.
* ``state`` the hidden state.

- More information can be found at
- :meth:`~tianshou.policy.BasePolicy.__call__`.
+ .. seealso::
+
+ Please refer to :meth:`~tianshou.policy.BasePolicy.__call__` for
+ more detailed explanation.
"""
logits, h = self.model(batch.obs, state=state, info=batch.info)
if isinstance(logits, tuple):
@@ -26,6 +26,11 @@ class PPOPolicy(PGPolicy):
:param float ent_coef: weight for entropy loss, defaults to 0.01.
:param action_range: the action range (minimum, maximum).
:type action_range: [float, float]

.. seealso::

Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
explanation.
"""

def __init__(self, actor, critic, optim, dist_fn,
@@ -70,8 +75,10 @@ class PPOPolicy(PGPolicy):
* ``dist`` the action distribution.
* ``state`` the hidden state.

- More information can be found at
- :meth:`~tianshou.policy.BasePolicy.__call__`.
+ .. seealso::
+
+ Please refer to :meth:`~tianshou.policy.BasePolicy.__call__` for
+ more detailed explanation.
"""
model = getattr(self, model)
logits, h = model(batch.obs, state=state, info=batch.info)
@@ -33,6 +33,11 @@ class SACPolicy(DDPGPolicy):
defaults to ``False``.
:param bool ignore_done: ignore the done flag while training the policy,
defaults to ``False``.

.. seealso::

Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
explanation.
"""

def __init__(self, actor, actor_optim, critic1, critic1_optim,
@@ -37,6 +37,11 @@ class TD3Policy(DDPGPolicy):
defaults to ``False``.
:param bool ignore_done: ignore the done flag while training the policy,
defaults to ``False``.

.. seealso::

Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
explanation.
"""

def __init__(self, actor, actor_optim, critic1, critic1_optim,