__call__ -> forward

Trinkle23897 2020-04-10 10:47:16 +08:00
parent 13086b7f64
commit 3cc22b7c0c
13 changed files with 35 additions and 39 deletions

View File

@@ -57,7 +57,7 @@ If no error occurs, you have successfully installed Tianshou.
## Documentation
The tutorials and API documentation are hosted on [tianshou.readthedocs.io](https://tianshou.readthedocs.io).
The tutorials and API documentation are hosted on [tianshou.readthedocs.io/en/stable/](https://tianshou.readthedocs.io/en/stable/) (stable version) and [tianshou.readthedocs.io/en/latest/](https://tianshou.readthedocs.io/en/latest/) (develop version).
The example scripts are under the [test/](https://github.com/thu-ml/tianshou/blob/master/test) and [examples/](https://github.com/thu-ml/tianshou/blob/master/examples) folders.
@@ -112,8 +112,8 @@ Check out the [GitHub Actions](https://github.com/thu-ml/tianshou/actions) page
We decouple all of the algorithms into 4 parts:
- `__init__`: initialize the policy;
- `forward`: to compute actions over given observations;
- `process_fn`: to preprocess data from the replay buffer (since we have reformulated all algorithms as replay-buffer based algorithms);
- `__call__`: to compute actions over given observations;
- `learn`: to learn from a given batch of data.
With this API, we can interact with different policies conveniently.
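A minimal sketch of how these four parts fit together in a user-defined policy (the class name, comments, and return values below are illustrative placeholders, not the actual implementation of any built-in policy):

import numpy as np
from tianshou.data import Batch
from tianshou.policy import BasePolicy

class MyCustomPolicy(BasePolicy):
    def __init__(self):
        super().__init__()  # build networks, optimizers, target networks, ...

    def forward(self, batch, state=None, **kwargs):
        # compute one action per observation in the batch
        return Batch(act=np.zeros(batch.obs.shape[0]))

    def process_fn(self, batch, buffer, indice):
        # e.g. compute n-step returns with extra data taken from the replay buffer
        return batch

    def learn(self, batch, **kwargs):
        # update the networks with the (pre-processed) batch and report statistics
        return {'loss': 0.0}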

View File

@@ -42,6 +42,7 @@ After installation, open your python console and type
If no error occurs, you have successfully installed Tianshou.
Tianshou is still under development; you can also check out the stable-version documentation at `tianshou.readthedocs.io/en/stable/ <https://tianshou.readthedocs.io/en/stable/>`_.
.. toctree::
:maxdepth: 1
@@ -50,7 +51,7 @@ If no error occurs, you have successfully installed Tianshou.
tutorials/dqn
tutorials/concepts
tutorials/tabular
tutorials/trick
tutorials/cheatsheet
.. toctree::
:maxdepth: 1

View File

@@ -0,0 +1,6 @@
Cheat Sheet
===========
This page shows some code snippets of how to use Tianshou to develop new algorithms.
TODO

View File

@@ -35,7 +35,7 @@ Tianshou aims to modularize RL algorithms. It comes with several classes of po
A policy class typically has four parts:
* :meth:`~tianshou.policy.BasePolicy.__init__`: initialize the policy, including copying the target network and so on;
* :meth:`~tianshou.policy.BasePolicy.__call__`: compute action with given observation;
* :meth:`~tianshou.policy.BasePolicy.forward`: compute action with given observation;
* :meth:`~tianshou.policy.BasePolicy.process_fn`: pre-process data from the replay buffer (this function can interact with the replay buffer);
* :meth:`~tianshou.policy.BasePolicy.learn`: update policy with a given batch of data.
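Note that :class:`~tianshou.policy.BasePolicy` also inherits from ``torch.nn.Module``, so renaming ``__call__`` to ``forward`` does not change how a policy is used: actions are still computed through the ordinary call syntax, which PyTorch dispatches to ``forward``. A minimal illustration (``policy`` and ``observations`` are placeholders):

batch = Batch(obs=observations, info={})  # wrap a batch of observations
result = policy(batch)                    # nn.Module.__call__ dispatches to policy.forward
actions = result.act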
@@ -126,18 +126,18 @@ We give a high-level explanation through the pseudocode used in section :ref:`po
# pseudocode, cannot work # methods in tianshou
s = env.reset()
buffer = Buffer(size=10000) # buffer = tianshou.data.ReplayBuffer(size=10000)
agent = DQN() # done in policy.__init__(...)
agent = DQN() # policy.__init__(...)
for i in range(int(1e6)): # done in trainer
a = agent.compute_action(s) # done in policy.__call__(batch, ...)
s_, r, d, _ = env.step(a) # done in collector.collect(...)
buffer.store(s, a, s_, r, d) # done in collector.collect(...)
s = s_ # done in collector.collect(...)
a = agent.compute_action(s) # policy(batch, ...)
s_, r, d, _ = env.step(a) # collector.collect(...)
buffer.store(s, a, s_, r, d) # collector.collect(...)
s = s_ # collector.collect(...)
if i % 1000 == 0: # done in trainer
b_s, b_a, b_s_, b_r, b_d = buffer.get(size=64) # done in collector.sample(batch_size)
b_s, b_a, b_s_, b_r, b_d = buffer.get(size=64) # collector.sample(batch_size)
# compute 2-step returns. How?
b_ret = compute_2_step_return(buffer, b_r, b_d, ...) # done in policy.process_fn(batch, buffer, indice)
b_ret = compute_2_step_return(buffer, b_r, b_d, ...) # policy.process_fn(batch, buffer, indice)
# update DQN policy
agent.update(b_s, b_a, b_s_, b_r, b_d, b_ret) # done in policy.learn(batch, ...)
agent.update(b_s, b_a, b_s_, b_r, b_d, b_ret) # policy.learn(batch, ...)
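For reference, a rough sketch of the same loop written with the Tianshou objects named in the comments above; network, optimizer, and environment construction are omitted, so treat this as an outline rather than a complete script:

from tianshou.data import Collector, ReplayBuffer
from tianshou.policy import DQNPolicy

buffer = ReplayBuffer(size=10000)
policy = DQNPolicy(net, optim, ...)            # policy.__init__(...)
collector = Collector(policy, env, buffer)

for i in range(int(1e6)):                      # normally driven by a trainer
    collector.collect(n_step=1)                # steps the env with policy.forward and stores to the buffer
    if i % 1000 == 0:
        batch = collector.sample(batch_size=64)  # samples from the buffer and applies policy.process_fn
        policy.learn(batch)                    # gradient update on the pre-processed batch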
Conclusion

View File

@@ -1,11 +0,0 @@
Tabular Q Learning Implementation
=================================
This tutorial shows how to use Tianshou to develop new algorithms.
Background
----------
TODO

View File

@@ -15,7 +15,7 @@ class MyPolicy(BasePolicy):
def __init__(self):
super().__init__()
def __call__(self, batch, state=None):
def forward(self, batch, state=None):
return Batch(act=np.ones(batch.obs.shape[0]))
def learn(self):
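# Illustration only (hypothetical usage, not part of the test file):
#
#     policy = MyPolicy()
#     batch = Batch(obs=np.zeros((4, 3)))   # dummy observations
#     acts = policy(batch).act              # acts.shape == (4,), one action per observation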

View File

@@ -11,7 +11,7 @@ class BasePolicy(ABC, nn.Module):
* :meth:`~tianshou.policy.BasePolicy.__init__`: initialize the policy, \
including copying the target network and so on;
* :meth:`~tianshou.policy.BasePolicy.__call__`: compute action with given \
* :meth:`~tianshou.policy.BasePolicy.forward`: compute action with given \
observation;
* :meth:`~tianshou.policy.BasePolicy.process_fn`: pre-process data from \
the replay buffer (this function can interact with the replay buffer);
@@ -48,7 +48,7 @@ class BasePolicy(ABC, nn.Module):
return batch
@abstractmethod
def __call__(self, batch, state=None, **kwargs):
def forward(self, batch, state=None, **kwargs):
"""Compute action over the given batch data.
:return: A :class:`~tianshou.data.Batch` which MUST have the following\

View File

@@ -39,7 +39,7 @@ class A2CPolicy(PGPolicy):
self._w_ent = ent_coef
self._grad_norm = max_grad_norm
def __call__(self, batch, state=None, **kwargs):
def forward(self, batch, state=None, **kwargs):
"""Compute action over the given batch data.
:return: A :class:`~tianshou.data.Batch` which has 4 keys:
@@ -51,7 +51,7 @@ class A2CPolicy(PGPolicy):
.. seealso::
Please refer to :meth:`~tianshou.policy.BasePolicy.__call__` for
Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
more detailed explanation.
"""
logits, h = self.actor(batch.obs, state=state, info=batch.info)

View File

@@ -98,8 +98,8 @@ class DDPGPolicy(BasePolicy):
batch.done = batch.done * 0.
return batch
def __call__(self, batch, state=None,
model='actor', input='obs', eps=None, **kwargs):
def forward(self, batch, state=None,
model='actor', input='obs', eps=None, **kwargs):
"""Compute action over the given batch data.
:param float eps: in [0, 1], for exploration use.
@@ -111,7 +111,7 @@ class DDPGPolicy(BasePolicy):
.. seealso::
Please refer to :meth:`~tianshou.policy.BasePolicy.__call__` for
Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
more detailed explanation.
"""
model = getattr(self, model)

View File

@@ -100,8 +100,8 @@ class DQNPolicy(BasePolicy):
batch.returns = returns
return batch
def __call__(self, batch, state=None,
model='model', input='obs', eps=None, **kwargs):
def forward(self, batch, state=None,
model='model', input='obs', eps=None, **kwargs):
"""Compute action over the given batch data.
:param float eps: in [0, 1], for epsilon-greedy exploration method.
@@ -114,7 +114,7 @@ class DQNPolicy(BasePolicy):
.. seealso::
Please refer to :meth:`~tianshou.policy.BasePolicy.__call__` for
Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
more detailed explanation.
"""
model = getattr(self, model)

View File

@@ -43,7 +43,7 @@ class PGPolicy(BasePolicy):
# batch.returns = self._vectorized_returns(batch)
return batch
def __call__(self, batch, state=None, **kwargs):
def forward(self, batch, state=None, **kwargs):
"""Compute action over the given batch data.
:return: A :class:`~tianshou.data.Batch` which has 4 keys:
@@ -55,7 +55,7 @@ class PGPolicy(BasePolicy):
.. seealso::
Please refer to :meth:`~tianshou.policy.BasePolicy.__call__` for
Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
more detailed explanation.
"""
logits, h = self.model(batch.obs, state=state, info=batch.info)

View File

@@ -65,7 +65,7 @@ class PPOPolicy(PGPolicy):
self.actor.eval()
self.critic.eval()
def __call__(self, batch, state=None, model='actor', **kwargs):
def forward(self, batch, state=None, model='actor', **kwargs):
"""Compute action over the given batch data.
:return: A :class:`~tianshou.data.Batch` which has 4 keys:
@@ -77,7 +77,7 @@ class PPOPolicy(PGPolicy):
.. seealso::
Please refer to :meth:`~tianshou.policy.BasePolicy.__call__` for
Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
more detailed explanation.
"""
model = getattr(self, model)

View File

@@ -77,7 +77,7 @@ class SACPolicy(DDPGPolicy):
self.critic2_old.parameters(), self.critic2.parameters()):
o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)
def __call__(self, batch, state=None, input='obs', **kwargs):
def forward(self, batch, state=None, input='obs', **kwargs):
obs = getattr(batch, input)
logits, h = self.actor(obs, state=state, info=batch.info)
assert isinstance(logits, tuple)