diff --git a/README.md b/README.md
index 6267588..21f0fb4 100644
--- a/README.md
+++ b/README.md
@@ -57,7 +57,7 @@ If no error occurs, you have successfully installed Tianshou.

 ## Documentation

-The tutorials and API documentation are hosted on [tianshou.readthedocs.io](https://tianshou.readthedocs.io).
+The tutorials and API documentation are hosted on [tianshou.readthedocs.io/en/stable/](https://tianshou.readthedocs.io/en/stable/) (stable version) and [tianshou.readthedocs.io/en/latest/](https://tianshou.readthedocs.io/en/latest/) (development version).

 The example scripts are under [test/](https://github.com/thu-ml/tianshou/blob/master/test) folder and [examples/](https://github.com/thu-ml/tianshou/blob/master/examples) folder.

@@ -112,8 +112,8 @@ Check out the [GitHub Actions](https://github.com/thu-ml/tianshou/actions) page
 We decouple all of the algorithms into 4 parts:

 - `__init__`: initialize the policy;
+- `forward`: to compute actions over given observations;
 - `process_fn`: to preprocess data from replay buffer (since we have reformulated all algorithms to replay-buffer based algorithms);
-- `__call__`: to compute actions over given observations;
 - `learn`: to learn from a given batch data.

 Within these API, we can interact with different policies conveniently.
diff --git a/docs/index.rst b/docs/index.rst
index f7cc3df..a5e333e 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -42,6 +42,7 @@ After installation, open your python console and type

 If no error occurs, you have successfully installed Tianshou.

+Tianshou is still under development. You can also check out the documentation of the stable version at `tianshou.readthedocs.io/en/stable/ <https://tianshou.readthedocs.io/en/stable/>`_.

 .. toctree::
    :maxdepth: 1
@@ -50,7 +51,7 @@ If no error occurs, you have successfully installed Tianshou.
    tutorials/dqn
    tutorials/concepts
    tutorials/tabular
-   tutorials/trick
+   tutorials/cheatsheet

 .. toctree::
    :maxdepth: 1
diff --git a/docs/tutorials/cheatsheet.rst b/docs/tutorials/cheatsheet.rst
new file mode 100644
index 0000000..1a89645
--- /dev/null
+++ b/docs/tutorials/cheatsheet.rst
@@ -0,0 +1,6 @@
+Cheat Sheet
+===========
+
+This page shows some code snippets illustrating how to use Tianshou to develop new algorithms.
+
+TODO
diff --git a/docs/tutorials/concepts.rst b/docs/tutorials/concepts.rst
index f5dca97..6de15cc 100644
--- a/docs/tutorials/concepts.rst
+++ b/docs/tutorials/concepts.rst
@@ -35,7 +35,7 @@ Tianshou aims to modularizing RL algorithms. It comes into several classes of po
 A policy class typically has four parts:

 * :meth:`~tianshou.policy.BasePolicy.__init__`: initialize the policy, including coping the target network and so on;
-* :meth:`~tianshou.policy.BasePolicy.__call__`: compute action with given observation;
+* :meth:`~tianshou.policy.BasePolicy.forward`: compute action with given observation;
 * :meth:`~tianshou.policy.BasePolicy.process_fn`: pre-process data from the replay buffer (this function can interact with replay buffer);
 * :meth:`~tianshou.policy.BasePolicy.learn`: update policy with a given batch of data.

@@ -126,18 +126,18 @@ We give a high-level explanation through the pseudocode used in section :ref:`po
     # pseudocode, cannot work                                # methods in tianshou
     s = env.reset()
     buffer = Buffer(size=10000)                              # buffer = tianshou.data.ReplayBuffer(size=10000)
-    agent = DQN()                                            # done in policy.__init__(...)
+    agent = DQN()                                            # policy.__init__(...)
     for i in range(int(1e6)):                                # done in trainer
-        a = agent.compute_action(s)                          # done in policy.__call__(batch, ...)
-        s_, r, d, _ = env.step(a)                            # done in collector.collect(...)
-        buffer.store(s, a, s_, r, d)                         # done in collector.collect(...)
-        s = s_                                               # done in collector.collect(...)
+        a = agent.compute_action(s)                          # policy(batch, ...)
+        s_, r, d, _ = env.step(a)                            # collector.collect(...)
+        buffer.store(s, a, s_, r, d)                         # collector.collect(...)
+        s = s_                                               # collector.collect(...)
         if i % 1000 == 0:                                    # done in trainer
-            b_s, b_a, b_s_, b_r, b_d = buffer.get(size=64)   # done in collector.sample(batch_size)
+            b_s, b_a, b_s_, b_r, b_d = buffer.get(size=64)   # collector.sample(batch_size)
             # compute 2-step returns. How?
-            b_ret = compute_2_step_return(buffer, b_r, b_d, ...)  # done in policy.process_fn(batch, buffer, indice)
+            b_ret = compute_2_step_return(buffer, b_r, b_d, ...)  # policy.process_fn(batch, buffer, indice)
             # update DQN policy
-            agent.update(b_s, b_a, b_s_, b_r, b_d, b_ret)    # done in policy.learn(batch, ...)
+            agent.update(b_s, b_a, b_s_, b_r, b_d, b_ret)    # policy.learn(batch, ...)


 Conclusion
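To make the four-part structure described in the README and concepts changes above concrete, here is a minimal sketch of a custom policy under the renamed API. The `RandomPolicy` name and its trivial method bodies are illustrative only, modeled on the `MyPolicy` test stub later in this patch; they are not part of the patch itself:

    import numpy as np

    from tianshou.data import Batch
    from tianshou.policy import BasePolicy


    class RandomPolicy(BasePolicy):
        """A toy policy exercising the four parts described above."""

        def __init__(self):
            # a real policy would store networks, optimizers, and
            # hyper-parameters here (e.g. copy a target network)
            super().__init__()

        def forward(self, batch, state=None, **kwargs):
            # compute actions over the given observations; the returned
            # Batch must contain at least an ``act`` key (integer actions
            # here so the sketch also suits discrete-action environments)
            return Batch(act=np.ones(batch.obs.shape[0], dtype=np.int64))

        def process_fn(self, batch, buffer, indice):
            # pre-process data sampled from the replay buffer, e.g. attach
            # n-step returns; nothing to do for this toy policy
            return batch

        def learn(self, batch, **kwargs):
            # update the policy from a batch; a real implementation would
            # compute a loss and step the optimizer here
            return {}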
diff --git a/docs/tutorials/tabular.rst b/docs/tutorials/tabular.rst
deleted file mode 100644
index fba4b2e..0000000
--- a/docs/tutorials/tabular.rst
+++ /dev/null
@@ -1,11 +0,0 @@
-Tabular Q Learning Implementation
-=================================
-
-This tutorial shows how to use Tianshou to develop new algorithms.
-
-
-Background
-----------
-
-TODO
-
diff --git a/test/base/test_collector.py b/test/base/test_collector.py
index 5328d1d..ae5e7b6 100644
--- a/test/base/test_collector.py
+++ b/test/base/test_collector.py
@@ -15,7 +15,7 @@ class MyPolicy(BasePolicy):
     def __init__(self):
         super().__init__()

-    def __call__(self, batch, state=None):
+    def forward(self, batch, state=None):
         return Batch(act=np.ones(batch.obs.shape[0]))

     def learn(self):
diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py
index ef34655..8fbcb78 100644
--- a/tianshou/policy/base.py
+++ b/tianshou/policy/base.py
@@ -11,7 +11,7 @@ class BasePolicy(ABC, nn.Module):

     * :meth:`~tianshou.policy.BasePolicy.__init__`: initialize the policy, \
         including coping the target network and so on;
-    * :meth:`~tianshou.policy.BasePolicy.__call__`: compute action with given \
+    * :meth:`~tianshou.policy.BasePolicy.forward`: compute action with given \
         observation;
     * :meth:`~tianshou.policy.BasePolicy.process_fn`: pre-process data from \
         the replay buffer (this function can interact with replay buffer);
@@ -48,7 +48,7 @@ class BasePolicy(ABC, nn.Module):
         return batch

     @abstractmethod
-    def __call__(self, batch, state=None, **kwargs):
+    def forward(self, batch, state=None, **kwargs):
         """Compute action over the given batch data.

         :return: A :class:`~tianshou.data.Batch` which MUST have the following\
diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py
index 534c055..09937b4 100644
--- a/tianshou/policy/modelfree/a2c.py
+++ b/tianshou/policy/modelfree/a2c.py
@@ -39,7 +39,7 @@ class A2CPolicy(PGPolicy):
         self._w_ent = ent_coef
         self._grad_norm = max_grad_norm

-    def __call__(self, batch, state=None, **kwargs):
+    def forward(self, batch, state=None, **kwargs):
         """Compute action over the given batch data.

         :return: A :class:`~tianshou.data.Batch` which has 4 keys:
@@ -51,7 +51,7 @@ class A2CPolicy(PGPolicy):

         .. seealso::

-            Please refer to :meth:`~tianshou.policy.BasePolicy.__call__` for
+            Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
             more detailed explanation.
         """
         logits, h = self.actor(batch.obs, state=state, info=batch.info)
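Because `BasePolicy` subclasses `nn.Module` (visible in the base.py hunk above), renaming `__call__` to `forward` keeps the call syntax unchanged: `nn.Module.__call__` dispatches to `forward`, so `policy(batch, ...)` still works and standard module hooks now apply to action computation as well. A quick sketch, reusing the toy `RandomPolicy` defined earlier (not part of the patch):

    import numpy as np

    from tianshou.data import Batch

    policy = RandomPolicy()
    batch = Batch(obs=np.zeros((4, 3)))

    # nn.Module.__call__ routes to forward(), so both spellings agree
    assert (policy(batch).act == policy.forward(batch).act).all()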
""" logits, h = self.actor(batch.obs, state=state, info=batch.info) diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index 40bc3c1..503b39d 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -98,8 +98,8 @@ class DDPGPolicy(BasePolicy): batch.done = batch.done * 0. return batch - def __call__(self, batch, state=None, - model='actor', input='obs', eps=None, **kwargs): + def forward(self, batch, state=None, + model='actor', input='obs', eps=None, **kwargs): """Compute action over the given batch data. :param float eps: in [0, 1], for exploration use. @@ -111,7 +111,7 @@ class DDPGPolicy(BasePolicy): .. seealso:: - Please refer to :meth:`~tianshou.policy.BasePolicy.__call__` for + Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for more detailed explanation. """ model = getattr(self, model) diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index 08e2341..3dd0687 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -100,8 +100,8 @@ class DQNPolicy(BasePolicy): batch.returns = returns return batch - def __call__(self, batch, state=None, - model='model', input='obs', eps=None, **kwargs): + def forward(self, batch, state=None, + model='model', input='obs', eps=None, **kwargs): """Compute action over the given batch data. :param float eps: in [0, 1], for epsilon-greedy exploration method. @@ -114,7 +114,7 @@ class DQNPolicy(BasePolicy): .. seealso:: - Please refer to :meth:`~tianshou.policy.BasePolicy.__call__` for + Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for more detailed explanation. """ model = getattr(self, model) diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index 75f3023..1966bcc 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -43,7 +43,7 @@ class PGPolicy(BasePolicy): # batch.returns = self._vectorized_returns(batch) return batch - def __call__(self, batch, state=None, **kwargs): + def forward(self, batch, state=None, **kwargs): """Compute action over the given batch data. :return: A :class:`~tianshou.data.Batch` which has 4 keys: @@ -55,7 +55,7 @@ class PGPolicy(BasePolicy): .. seealso:: - Please refer to :meth:`~tianshou.policy.BasePolicy.__call__` for + Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for more detailed explanation. """ logits, h = self.model(batch.obs, state=state, info=batch.info) diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index 32200d2..ba64a6b 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -65,7 +65,7 @@ class PPOPolicy(PGPolicy): self.actor.eval() self.critic.eval() - def __call__(self, batch, state=None, model='actor', **kwargs): + def forward(self, batch, state=None, model='actor', **kwargs): """Compute action over the given batch data. :return: A :class:`~tianshou.data.Batch` which has 4 keys: @@ -77,7 +77,7 @@ class PPOPolicy(PGPolicy): .. seealso:: - Please refer to :meth:`~tianshou.policy.BasePolicy.__call__` for + Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for more detailed explanation. 
""" model = getattr(self, model) diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index 8046382..8b27fde 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -77,7 +77,7 @@ class SACPolicy(DDPGPolicy): self.critic2_old.parameters(), self.critic2.parameters()): o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau) - def __call__(self, batch, state=None, input='obs', **kwargs): + def forward(self, batch, state=None, input='obs', **kwargs): obs = getattr(batch, input) logits, h = self.actor(obs, state=state, info=batch.info) assert isinstance(logits, tuple)