__call__ -> forward
parent 13086b7f64
commit 3cc22b7c0c
@@ -57,7 +57,7 @@ If no error occurs, you have successfully installed Tianshou.
 
 ## Documentation
 
-The tutorials and API documentation are hosted on [tianshou.readthedocs.io](https://tianshou.readthedocs.io).
+The tutorials and API documentation are hosted on [tianshou.readthedocs.io/en/stable/](https://tianshou.readthedocs.io/en/stable/) (stable version) and [tianshou.readthedocs.io/en/latest/](https://tianshou.readthedocs.io/en/latest/) (develop version).
 
 The example scripts are under [test/](https://github.com/thu-ml/tianshou/blob/master/test) folder and [examples/](https://github.com/thu-ml/tianshou/blob/master/examples) folder.
 
@@ -112,8 +112,8 @@ Check out the [GitHub Actions](https://github.com/thu-ml/tianshou/actions) page
 We decouple all of the algorithms into 4 parts:
 
 - `__init__`: initialize the policy;
+- `forward`: to compute actions over given observations;
 - `process_fn`: to preprocess data from replay buffer (since we have reformulated all algorithms to replay-buffer based algorithms);
-- `__call__`: to compute actions over given observations;
 - `learn`: to learn from a given batch data.
 
 Within these API, we can interact with different policies conveniently.
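Editorial note, not part of the diff: a minimal sketch of a custom policy that fills in these four parts after the rename. The class name and the random-action logic are made up for illustration; only `BasePolicy`, `Batch`, and the four method names come from the repository.

    # Hypothetical example, not from this commit: a policy with the four parts
    # listed above, using the renamed forward().
    import numpy as np
    from tianshou.data import Batch
    from tianshou.policy import BasePolicy

    class RandomPolicy(BasePolicy):                  # illustrative name
        def __init__(self, n_actions):
            super().__init__()                       # 1. initialize the policy
            self.n_actions = n_actions

        def forward(self, batch, state=None, **kwargs):
            # 2. compute actions over the given observations
            act = np.random.randint(self.n_actions, size=len(batch.obs))
            return Batch(act=act, state=None)

        def process_fn(self, batch, buffer, indice):
            # 3. preprocess data sampled from the replay buffer (e.g. add returns)
            return batch

        def learn(self, batch, **kwargs):
            # 4. learn from a given batch of data; nothing to optimize here
            return {}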
@@ -42,6 +42,7 @@ After installation, open your python console and type
 
 If no error occurs, you have successfully installed Tianshou.
 
+Tianshou is still under development, you can also check out the documents in stable version through `tianshou.readthedocs.io/en/stable/ <https://tianshou.readthedocs.io/en/stable/>`_.
 
 .. toctree::
     :maxdepth: 1
@@ -50,7 +51,7 @@ If no error occurs, you have successfully installed Tianshou.
     tutorials/dqn
     tutorials/concepts
-    tutorials/tabular
     tutorials/trick
+    tutorials/cheatsheet
 
 .. toctree::
     :maxdepth: 1
docs/tutorials/cheatsheet.rst (new file)
@@ -0,0 +1,6 @@
+Cheat Sheet
+===========
+
+This page shows some code snippets of how to use Tianshou to develop new algorithms.
+
+TODO
@@ -35,7 +35,7 @@ Tianshou aims to modularizing RL algorithms. It comes into several classes of po
 A policy class typically has four parts:
 
 * :meth:`~tianshou.policy.BasePolicy.__init__`: initialize the policy, including coping the target network and so on;
-* :meth:`~tianshou.policy.BasePolicy.__call__`: compute action with given observation;
+* :meth:`~tianshou.policy.BasePolicy.forward`: compute action with given observation;
 * :meth:`~tianshou.policy.BasePolicy.process_fn`: pre-process data from the replay buffer (this function can interact with replay buffer);
 * :meth:`~tianshou.policy.BasePolicy.learn`: update policy with a given batch of data.
 
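Editorial aside, not part of the diff: since `BasePolicy` also inherits from `torch.nn.Module` (see the `class BasePolicy(ABC, nn.Module)` hunk below), `nn.Module.__call__` dispatches to the renamed `forward`, so user code that calls the policy object directly keeps working. A tiny sketch, reusing the hypothetical `RandomPolicy` from the note above:

    # policy(batch) is unchanged for callers: nn.Module.__call__ -> forward()
    import numpy as np
    from tianshou.data import Batch

    policy = RandomPolicy(n_actions=4)        # hypothetical class sketched earlier
    batch = Batch(obs=np.zeros((16, 8)))      # a fake batch of 16 observations
    out = policy(batch)                       # equivalent to policy.forward(batch)
    assert out.act.shape == (16,)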
@@ -126,18 +126,18 @@ We give a high-level explanation through the pseudocode used in section :ref:`po
     # pseudocode, cannot work                                  # methods in tianshou
     s = env.reset()
     buffer = Buffer(size=10000)                                # buffer = tianshou.data.ReplayBuffer(size=10000)
-    agent = DQN()                                              # done in policy.__init__(...)
+    agent = DQN()                                              # policy.__init__(...)
     for i in range(int(1e6)):                                  # done in trainer
-        a = agent.compute_action(s)                            # done in policy.__call__(batch, ...)
-        s_, r, d, _ = env.step(a)                              # done in collector.collect(...)
-        buffer.store(s, a, s_, r, d)                           # done in collector.collect(...)
-        s = s_                                                 # done in collector.collect(...)
+        a = agent.compute_action(s)                            # policy(batch, ...)
+        s_, r, d, _ = env.step(a)                              # collector.collect(...)
+        buffer.store(s, a, s_, r, d)                           # collector.collect(...)
+        s = s_                                                 # collector.collect(...)
         if i % 1000 == 0:                                      # done in trainer
-            b_s, b_a, b_s_, b_r, b_d = buffer.get(size=64)     # done in collector.sample(batch_size)
+            b_s, b_a, b_s_, b_r, b_d = buffer.get(size=64)     # collector.sample(batch_size)
             # compute 2-step returns. How?
-            b_ret = compute_2_step_return(buffer, b_r, b_d, ...)   # done in policy.process_fn(batch, buffer, indice)
+            b_ret = compute_2_step_return(buffer, b_r, b_d, ...)   # policy.process_fn(batch, buffer, indice)
             # update DQN policy
-            agent.update(b_s, b_a, b_s_, b_r, b_d, b_ret)          # done in policy.learn(batch, ...)
+            agent.update(b_s, b_a, b_s_, b_r, b_d, b_ret)          # policy.learn(batch, ...)
 
 
 Conclusion
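Editorial sketch, not part of the diff: roughly how the right-hand comments above fit together in actual Tianshou code. The environment, loop count, and batch size are placeholders, and the hypothetical `RandomPolicy` from the earlier note stands in for a real policy such as DQN.

    # Rough counterpart of the pseudocode above; all values are placeholders.
    import gym
    from tianshou.data import Collector, ReplayBuffer

    env = gym.make('CartPole-v0')
    policy = RandomPolicy(n_actions=env.action_space.n)   # stand-in for a real policy
    buffer = ReplayBuffer(size=10000)
    collector = Collector(policy, env, buffer)

    for i in range(1000):                         # done in trainer
        collector.collect(n_step=10)              # policy(batch) -> env.step -> buffer
        batch = collector.sample(batch_size=64)   # buffer sampling + policy.process_fn(...)
        policy.learn(batch)                       # update the policy from the processed batch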
@@ -1,11 +0,0 @@
-Tabular Q Learning Implementation
-=================================
-
-This tutorial shows how to use Tianshou to develop new algorithms.
-
-
-Background
-----------
-
-TODO
-
@@ -15,7 +15,7 @@ class MyPolicy(BasePolicy):
     def __init__(self):
         super().__init__()
 
-    def __call__(self, batch, state=None):
+    def forward(self, batch, state=None):
         return Batch(act=np.ones(batch.obs.shape[0]))
 
     def learn(self):
@@ -11,7 +11,7 @@ class BasePolicy(ABC, nn.Module):
 
     * :meth:`~tianshou.policy.BasePolicy.__init__`: initialize the policy, \
        including coping the target network and so on;
-    * :meth:`~tianshou.policy.BasePolicy.__call__`: compute action with given \
+    * :meth:`~tianshou.policy.BasePolicy.forward`: compute action with given \
        observation;
     * :meth:`~tianshou.policy.BasePolicy.process_fn`: pre-process data from \
        the replay buffer (this function can interact with replay buffer);
@@ -48,7 +48,7 @@ class BasePolicy(ABC, nn.Module):
         return batch
 
     @abstractmethod
-    def __call__(self, batch, state=None, **kwargs):
+    def forward(self, batch, state=None, **kwargs):
         """Compute action over the given batch data.
 
         :return: A :class:`~tianshou.data.Batch` which MUST have the following\
@@ -39,7 +39,7 @@ class A2CPolicy(PGPolicy):
         self._w_ent = ent_coef
         self._grad_norm = max_grad_norm
 
-    def __call__(self, batch, state=None, **kwargs):
+    def forward(self, batch, state=None, **kwargs):
         """Compute action over the given batch data.
 
         :return: A :class:`~tianshou.data.Batch` which has 4 keys:
@@ -51,7 +51,7 @@ class A2CPolicy(PGPolicy):
 
         .. seealso::
 
-            Please refer to :meth:`~tianshou.policy.BasePolicy.__call__` for
+            Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
             more detailed explanation.
         """
         logits, h = self.actor(batch.obs, state=state, info=batch.info)
@@ -98,8 +98,8 @@ class DDPGPolicy(BasePolicy):
             batch.done = batch.done * 0.
         return batch
 
-    def __call__(self, batch, state=None,
-                 model='actor', input='obs', eps=None, **kwargs):
+    def forward(self, batch, state=None,
+                model='actor', input='obs', eps=None, **kwargs):
         """Compute action over the given batch data.
 
         :param float eps: in [0, 1], for exploration use.
@@ -111,7 +111,7 @@ class DDPGPolicy(BasePolicy):
 
         .. seealso::
 
-            Please refer to :meth:`~tianshou.policy.BasePolicy.__call__` for
+            Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
             more detailed explanation.
         """
         model = getattr(self, model)
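Editorial aside, not part of the diff: the `model` and `input` keywords on this `forward` let other methods of the policy reuse the same forward pass through a different network attribute or a different batch field (note `model = getattr(self, model)` above and `obs = getattr(batch, input)` in the SAC hunk below). A hedged fragment of what such a call could look like; the argument values and surrounding names are illustrative, not taken from this commit:

    # Hypothetical call: run a target network on the next observations, e.g.
    # while building a TD target inside process_fn()/learn().
    next_action = policy(batch, model='actor_old', input='obs_next').act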
@@ -100,8 +100,8 @@ class DQNPolicy(BasePolicy):
         batch.returns = returns
         return batch
 
-    def __call__(self, batch, state=None,
-                 model='model', input='obs', eps=None, **kwargs):
+    def forward(self, batch, state=None,
+                model='model', input='obs', eps=None, **kwargs):
         """Compute action over the given batch data.
 
         :param float eps: in [0, 1], for epsilon-greedy exploration method.
@@ -114,7 +114,7 @@ class DQNPolicy(BasePolicy):
 
         .. seealso::
 
-            Please refer to :meth:`~tianshou.policy.BasePolicy.__call__` for
+            Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
             more detailed explanation.
         """
         model = getattr(self, model)
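Editorial aside, not part of the diff: `eps` is documented above as a per-call argument, so epsilon-greedy exploration is chosen each time the policy is invoked rather than fixed at construction time. A hedged fragment; the 0.1 value and surrounding names are illustrative:

    # Hypothetical call: epsilon-greedy action selection at collection time.
    result = policy(batch, eps=0.1)   # dispatches to forward(batch, eps=0.1)
    action = result.act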
@@ -43,7 +43,7 @@ class PGPolicy(BasePolicy):
         # batch.returns = self._vectorized_returns(batch)
         return batch
 
-    def __call__(self, batch, state=None, **kwargs):
+    def forward(self, batch, state=None, **kwargs):
         """Compute action over the given batch data.
 
         :return: A :class:`~tianshou.data.Batch` which has 4 keys:
@@ -55,7 +55,7 @@ class PGPolicy(BasePolicy):
 
         .. seealso::
 
-            Please refer to :meth:`~tianshou.policy.BasePolicy.__call__` for
+            Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
             more detailed explanation.
         """
         logits, h = self.model(batch.obs, state=state, info=batch.info)
@@ -65,7 +65,7 @@ class PPOPolicy(PGPolicy):
         self.actor.eval()
         self.critic.eval()
 
-    def __call__(self, batch, state=None, model='actor', **kwargs):
+    def forward(self, batch, state=None, model='actor', **kwargs):
         """Compute action over the given batch data.
 
         :return: A :class:`~tianshou.data.Batch` which has 4 keys:
@@ -77,7 +77,7 @@ class PPOPolicy(PGPolicy):
 
         .. seealso::
 
-            Please refer to :meth:`~tianshou.policy.BasePolicy.__call__` for
+            Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
             more detailed explanation.
         """
         model = getattr(self, model)
@@ -77,7 +77,7 @@ class SACPolicy(DDPGPolicy):
                 self.critic2_old.parameters(), self.critic2.parameters()):
             o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)
 
-    def __call__(self, batch, state=None, input='obs', **kwargs):
+    def forward(self, batch, state=None, input='obs', **kwargs):
         obs = getattr(batch, input)
         logits, h = self.actor(obs, state=state, info=batch.info)
         assert isinstance(logits, tuple)