diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py
index 376e7d2..eca521c 100644
--- a/tianshou/data/batch.py
+++ b/tianshou/data/batch.py
@@ -15,10 +15,12 @@ class Batch(object):
         >>> data.b
         [5, 5]
         >>> data.b = np.array([3, 4, 5])
-        >>> len(data.b)
-        3
-        >>> data.b[-1]
-        5
+        >>> print(data)
+        Batch(
+            a: 4,
+            b: [3 4 5],
+            c: 2312312,
+        )

     In short, you can define a :class:`Batch` with any key-value pair. The
     current implementation of Tianshou typically use 6 keys in
diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py
index 1146cd8..c5edbfe 100644
--- a/tianshou/data/buffer.py
+++ b/tianshou/data/buffer.py
@@ -9,6 +9,7 @@ class ReplayBuffer(object):
     ``numpy.ndarray``. Here is the usage:
     ::

+        >>> import numpy as np
         >>> from tianshou.data import ReplayBuffer
         >>> buf = ReplayBuffer(size=20)
         >>> for i in range(3):
@@ -48,10 +49,15 @@ class ReplayBuffer(object):
         >>> for i in range(16):
         ...     done = i % 5 == 0
         ...     buf.add(obs=i, act=i, rew=i, done=done, obs_next=0, info={})
-        >>> print(buf.obs)
-        [ 9. 10. 11. 12. 13. 14. 15.  7.  8.]
-        >>> print(buf.done)
-        [0. 1. 0. 0. 0. 0. 1. 0. 0.]
+        >>> print(buf)
+        ReplayBuffer(
+            obs: [ 9. 10. 11. 12. 13. 14. 15.  7.  8.],
+            act: [ 9. 10. 11. 12. 13. 14. 15.  7.  8.],
+            rew: [ 9. 10. 11. 12. 13. 14. 15.  7.  8.],
+            done: [0. 1. 0. 0. 0. 0. 1. 0. 0.],
+            obs_next: [0. 0. 0. 0. 0. 0. 0. 0. 0.],
+            info: [{} {} {} {} {} {} {} {} {}],
+        )
         >>> index = np.arange(len(buf))
         >>> print(buf.get_stack(index, 'obs'))
         [[ 7.  7.  8.  9.]
@@ -65,7 +71,7 @@ class ReplayBuffer(object):
         [ 7.  7.  7.  8.]]
         >>> # here is another way to get the stacked data
         >>> # (stack only for obs and obs_next)
-        >>> sum(sum(buf.get_stack(index, 'obs') - buf[index].obs))
+        >>> abs(buf.get_stack(index, 'obs') - buf[index].obs).sum().sum()
         0.0
     """

@@ -200,6 +206,11 @@ class ListReplayBuffer(ReplayBuffer):
     """The function of :class:`~tianshou.data.ListReplayBuffer` is almost the
     same as :class:`~tianshou.data.ReplayBuffer`. The only difference is that
     :class:`~tianshou.data.ListReplayBuffer` is based on ``list``.
+
+    .. seealso::
+
+        Please refer to :class:`~tianshou.data.ReplayBuffer` for more
+        detailed explanation.
     """

     def __init__(self):
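Note: the doctest change above, from ``sum(sum(...))`` to ``abs(...).sum().sum()``, is more than cosmetic. A signed double sum can cancel to 0.0 even when the stacked data and the indexed data disagree, while the absolute sum is zero only for an exact element-wise match. A minimal illustration in plain numpy (the array here is made up for the example)::

    >>> import numpy as np
    >>> a = np.array([[1., -1.], [0., 0.]])
    >>> a.sum()        # signed entries cancel out
    0.0
    >>> abs(a).sum()   # zero only if every entry is zero
    2.0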
diff --git a/tianshou/env/vecenv.py b/tianshou/env/vecenv.py
index 79abbc3..5f7225a 100644
--- a/tianshou/env/vecenv.py
+++ b/tianshou/env/vecenv.py
@@ -94,8 +94,12 @@ class BaseVectorEnv(ABC, gym.Wrapper):


 class VectorEnv(BaseVectorEnv):
-    """Dummy vectorized environment wrapper, implemented in for-loop. The usage
-    is in :class:`~tianshou.env.BaseVectorEnv`.
+    """Dummy vectorized environment wrapper, implemented in for-loop.
+
+    .. seealso::
+
+        Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
+        explanation.
     """

     def __init__(self, env_fns):
@@ -170,8 +174,12 @@ def worker(parent, p, env_fn_wrapper):


 class SubprocVectorEnv(BaseVectorEnv):
-    """Vectorized environment wrapper based on subprocess. The usage is in
-    :class:`~tianshou.env.BaseVectorEnv`.
+    """Vectorized environment wrapper based on subprocess.
+
+    .. seealso::
+
+        Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
+        explanation.
     """

     def __init__(self, env_fns):
@@ -247,8 +255,12 @@ class RayVectorEnv(BaseVectorEnv):
     """Vectorized environment wrapper based on
     `ray <https://github.com/ray-project/ray>`_. However, according to our
     test, it is about two times slower than
-    :class:`~tianshou.env.SubprocVectorEnv`. The usage is in
-    :class:`~tianshou.env.BaseVectorEnv`.
+    :class:`~tianshou.env.SubprocVectorEnv`.
+
+    .. seealso::
+
+        Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
+        explanation.
     """

     def __init__(self, env_fns):
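For readers coming from the class docstrings alone, a minimal usage sketch for the three wrappers above (a sketch, assuming gym and the standard ``CartPole-v0`` task; per the ``BaseVectorEnv`` docs, ``VectorEnv``, ``SubprocVectorEnv`` and ``RayVectorEnv`` all take the same list of environment factories)::

    import gym
    from tianshou.env import VectorEnv

    # eight environments, stepped sequentially in a for-loop
    envs = VectorEnv([lambda: gym.make('CartPole-v0') for _ in range(8)])
    obs = envs.reset()                         # stacked: one observation per env
    obs, rew, done, info = envs.step([0] * 8)  # one action per env
    envs.close()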
diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py
index c48d96e..1bf4bb6 100644
--- a/tianshou/policy/__init__.py
+++ b/tianshou/policy/__init__.py
@@ -1,11 +1,11 @@
 from tianshou.policy.base import BasePolicy
-from tianshou.policy.dqn import DQNPolicy
-from tianshou.policy.pg import PGPolicy
-from tianshou.policy.a2c import A2CPolicy
-from tianshou.policy.ddpg import DDPGPolicy
-from tianshou.policy.ppo import PPOPolicy
-from tianshou.policy.td3 import TD3Policy
-from tianshou.policy.sac import SACPolicy
+from tianshou.policy.modelfree.dqn import DQNPolicy
+from tianshou.policy.modelfree.pg import PGPolicy
+from tianshou.policy.modelfree.a2c import A2CPolicy
+from tianshou.policy.modelfree.ddpg import DDPGPolicy
+from tianshou.policy.modelfree.ppo import PPOPolicy
+from tianshou.policy.modelfree.td3 import TD3Policy
+from tianshou.policy.modelfree.sac import SACPolicy

 __all__ = [
     'BasePolicy',
diff --git a/tianshou/policy/modelfree/__init__.py b/tianshou/policy/modelfree/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tianshou/policy/a2c.py b/tianshou/policy/modelfree/a2c.py
similarity index 93%
rename from tianshou/policy/a2c.py
rename to tianshou/policy/modelfree/a2c.py
index 9ee807f..534c055 100644
--- a/tianshou/policy/a2c.py
+++ b/tianshou/policy/modelfree/a2c.py
@@ -21,6 +21,11 @@ class A2CPolicy(PGPolicy):
     :param float ent_coef: weight for entropy loss, defaults to 0.01.
     :param float max_grad_norm: clipping gradients in back propagation,
         defaults to ``None``.
+
+    .. seealso::
+
+        Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
+        explanation.
     """

     def __init__(self, actor, critic, optim,
@@ -44,8 +49,10 @@ class A2CPolicy(PGPolicy):
         * ``dist`` the action distribution.
         * ``state`` the hidden state.

-        More information can be found at
-        :meth:`~tianshou.policy.BasePolicy.__call__`.
+        .. seealso::
+
+            Please refer to :meth:`~tianshou.policy.BasePolicy.__call__` for
+            more detailed explanation.
         """
         logits, h = self.actor(batch.obs, state=state, info=batch.info)
         if isinstance(logits, tuple):
diff --git a/tianshou/policy/ddpg.py b/tianshou/policy/modelfree/ddpg.py
similarity index 95%
rename from tianshou/policy/ddpg.py
rename to tianshou/policy/modelfree/ddpg.py
index b07f64d..40bc3c1 100644
--- a/tianshou/policy/ddpg.py
+++ b/tianshou/policy/modelfree/ddpg.py
@@ -28,6 +28,11 @@ class DDPGPolicy(BasePolicy):
         defaults to ``False``.
     :param bool ignore_done: ignore the done flag while training the policy,
         defaults to ``False``.
+
+    .. seealso::
+
+        Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
+        explanation.
     """

     def __init__(self, actor, actor_optim, critic, critic_optim,
@@ -104,8 +109,10 @@ class DDPGPolicy(BasePolicy):
         * ``act`` the action.
         * ``state`` the hidden state.

-        More information can be found at
-        :meth:`~tianshou.policy.BasePolicy.__call__`.
+        .. seealso::
+
+            Please refer to :meth:`~tianshou.policy.BasePolicy.__call__` for
+            more detailed explanation.
         """
         model = getattr(self, model)
         obs = getattr(batch, input)
diff --git a/tianshou/policy/dqn.py b/tianshou/policy/modelfree/dqn.py
similarity index 95%
rename from tianshou/policy/dqn.py
rename to tianshou/policy/modelfree/dqn.py
index bd370bb..b6e7bbf 100644
--- a/tianshou/policy/dqn.py
+++ b/tianshou/policy/modelfree/dqn.py
@@ -18,6 +18,11 @@ class DQNPolicy(BasePolicy):
         ahead.
     :param int target_update_freq: the target network update frequency (``0``
         if you do not use the target network).
+
+    .. seealso::
+
+        Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
+        explanation.
     """

     def __init__(self, model, optim, discount_factor=0.99,
@@ -106,8 +111,10 @@ class DQNPolicy(BasePolicy):
         * ``logits`` the network's raw output.
         * ``state`` the hidden state.

-        More information can be found at
-        :meth:`~tianshou.policy.BasePolicy.__call__`.
+        .. seealso::
+
+            Please refer to :meth:`~tianshou.policy.BasePolicy.__call__` for
+            more detailed explanation.
         """
         model = getattr(self, model)
         obs = getattr(batch, input)
diff --git a/tianshou/policy/pg.py b/tianshou/policy/modelfree/pg.py
similarity index 93%
rename from tianshou/policy/pg.py
rename to tianshou/policy/modelfree/pg.py
index 63c05c7..9f82d5a 100644
--- a/tianshou/policy/pg.py
+++ b/tianshou/policy/modelfree/pg.py
@@ -13,6 +13,11 @@ class PGPolicy(BasePolicy):
     :param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
     :param torch.distributions.Distribution dist_fn: for computing the action.
     :param float discount_factor: in [0, 1].
+
+    .. seealso::
+
+        Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
+        explanation.
     """

     def __init__(self, model, optim, dist_fn=torch.distributions.Categorical,
@@ -50,8 +55,10 @@ class PGPolicy(BasePolicy):
         * ``dist`` the action distribution.
         * ``state`` the hidden state.

-        More information can be found at
-        :meth:`~tianshou.policy.BasePolicy.__call__`.
+        .. seealso::
+
+            Please refer to :meth:`~tianshou.policy.BasePolicy.__call__` for
+            more detailed explanation.
         """
         logits, h = self.model(batch.obs, state=state, info=batch.info)
         if isinstance(logits, tuple):
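The ``dist_fn`` parameter documented above (default ``torch.distributions.Categorical``) is what turns the model's logits into the ``dist`` returned by the forward pass. A standalone sketch of that step, outside any policy class and with a made-up logits tensor::

    import torch

    dist_fn = torch.distributions.Categorical
    logits = torch.tensor([[0.1, 0.2, 3.0]])  # stand-in for the model output
    dist = dist_fn(logits=logits)             # the ``dist`` in the forward output
    act = dist.sample()                       # e.g. tensor([2])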
""" def __init__(self, actor, actor_optim, critic1, critic1_optim, diff --git a/tianshou/policy/td3.py b/tianshou/policy/modelfree/td3.py similarity index 97% rename from tianshou/policy/td3.py rename to tianshou/policy/modelfree/td3.py index 4d593d2..349940d 100644 --- a/tianshou/policy/td3.py +++ b/tianshou/policy/modelfree/td3.py @@ -37,6 +37,11 @@ class TD3Policy(DDPGPolicy): defaults to ``False``. :param bool ignore_done: ignore the done flag while training the policy, defaults to ``False``. + + .. seealso:: + + Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed + explanation. """ def __init__(self, actor, actor_optim, critic1, critic1_optim,