Cheat Sheet
===========

This page shows some code snippets of how to use Tianshou to develop new algorithms / apply algorithms to new scenarios.

TODO

By the way, some of these issues can be resolved by using a ``gym.Wrapper``, which is a general way to modify the policy-environment interaction.
.. _network_api:

Build Policy Network
--------------------

See :ref:`build_the_network`.
.. _new_policy:

Build New Policy
----------------

See :class:`~tianshou.policy.BasePolicy`.
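
For a quick orientation, a skeleton subclass might look like the sketch below, similar in spirit to the toy ``MyPolicy`` used in Tianshou's tests. The body is only a placeholder (it always outputs action 0), and the exact abstract-method signatures may differ slightly between versions:
::

    import numpy as np

    from tianshou.data import Batch
    from tianshou.policy import BasePolicy


    class MyPolicy(BasePolicy):
        """Illustrative policy skeleton, not a real algorithm."""

        def forward(self, batch, state=None, **kwargs):
            # map batch.obs to actions; a Batch with an ``act`` key is expected
            return Batch(act=np.zeros(len(batch.obs)))

        def learn(self, batch, **kwargs):
            # update the parameters with a sampled batch of data
            # and return a dict of training statistics
            return {}
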
Customize Training Process
--------------------------

See :ref:`customized_trainer`.
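
If the built-in trainers in ``tianshou.trainer`` do not fit your needs, you can also drive the training yourself. The loop below is only a rough sketch for an off-policy algorithm; it assumes a ``policy`` and a ``train_collector`` have already been constructed (e.g. as in the DQN tutorial), and the epoch/step numbers are arbitrary:
::

    for epoch in range(10):
        # gather transitions into the replay buffer with the current policy
        result = train_collector.collect(n_step=1000)
        # perform several gradient steps on batches sampled from the buffer
        for _ in range(100):
            losses = policy.learn(train_collector.sample(64))
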
.. _parallel_sampling:

Parallel Sampling
-----------------

Use :class:`~tianshou.env.VectorEnv` or :class:`~tianshou.env.SubprocVectorEnv`.
::

    env_fns = [
        lambda: MyTestEnv(size=2),
        lambda: MyTestEnv(size=3),
        lambda: MyTestEnv(size=4),
        lambda: MyTestEnv(size=5),
    ]
    venv = SubprocVectorEnv(env_fns)
where ``env_fns`` is a list of callables, each of which constructs one environment instance. The above code can also be written more compactly:
::

    env_fns = [lambda x=i: MyTestEnv(size=x) for i in [2, 3, 4, 5]]
    venv = SubprocVectorEnv(env_fns)
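
A vectorized environment behaves like a single Gym environment, except that it consumes and produces batched data. The snippet below is only a sketch; in most cases the :class:`~tianshou.data.Collector` drives the vectorized environment for you:
::

    obs = venv.reset()           # one observation per environment
    actions = [0, 0, 0, 0]       # one action per environment
    obs, rew, done, info = venv.step(actions)
    venv.close()
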
.. _rnn_training:

RNN-style Training
------------------

This is related to `Issue 19 <https://github.com/thu-ml/tianshou/issues/19>`_.

First, add an argument ``stack_num`` to :class:`~tianshou.data.ReplayBuffer`:
::

    buf = ReplayBuffer(size=size, stack_num=stack_num)

Then, change the network to a recurrent style, for example, class ``Recurrent`` in `code snippet 1 <https://github.com/thu-ml/tianshou/blob/master/test/discrete/net.py>`_, or ``RecurrentActor`` and ``RecurrentCritic`` in `code snippet 2 <https://github.com/thu-ml/tianshou/blob/master/test/continuous/net.py>`_.
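
For orientation only, a heavily stripped-down recurrent network could look like the sketch below. It merely illustrates the ``(logits, state)`` convention from :ref:`build_the_network`; the hidden-state bookkeeping in the linked snippets is more involved, so treat this as an assumption-laden sketch rather than the reference implementation:
::

    import torch
    from torch import nn


    class SimpleRecurrent(nn.Module):
        """Hypothetical minimal recurrent network."""

        def __init__(self, state_dim, action_dim, hidden_dim=128):
            super().__init__()
            self.lstm = nn.LSTM(state_dim, hidden_dim, batch_first=True)
            self.fc = nn.Linear(hidden_dim, action_dim)

        def forward(self, s, state=None, info={}):
            s = torch.as_tensor(s, dtype=torch.float32)
            if s.dim() == 2:  # [bsz, dim] -> [bsz, 1, dim] when no frame-stack
                s = s.unsqueeze(1)
            if state is None:
                out, (h, c) = self.lstm(s)
            else:
                out, (h, c) = self.lstm(s, (state['h'], state['c']))
            logits = self.fc(out[:, -1])
            # hand the new hidden state back so it can be fed in next time
            return logits, {'h': h.detach(), 'c': c.detach()}
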
The above code supports only stacked observations. If you want stacked actions (for Q(stacked-s, stacked-a)), stacked rewards, or other stacked variables, you can add a ``gym.Wrapper`` to modify the state representation. For example, suppose we add a wrapper that maps the [s, a] pair to a new state (a sketch of such a wrapper follows the list):

- Before: (s, a, s', r, d) is stored in the replay buffer, and we can get stacked s;
- After applying the wrapper: ([s, a], a, [s', a'], r, d) is stored in the replay buffer, and we can get both stacked s and a.
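
A minimal wrapper of this kind might look as follows; the class name and the concatenation scheme are assumptions for illustration only, and a real implementation should also adjust ``observation_space`` accordingly:
::

    import gym
    import numpy as np


    class StateActionWrapper(gym.Wrapper):
        """Hypothetical wrapper: append the previous action to the observation."""

        def reset(self, **kwargs):
            obs = self.env.reset(**kwargs)
            self._last_act = np.zeros(np.shape(self.env.action_space.sample()) or (1,))
            return np.concatenate([np.ravel(obs), np.ravel(self._last_act)])

        def step(self, action):
            obs, rew, done, info = self.env.step(action)
            self._last_act = np.ravel(np.asarray(action, dtype=float))
            return np.concatenate([np.ravel(obs), self._last_act]), rew, done, info
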
.. _self_defined_env:

User-defined Environment and Different State Representation
------------------------------------------------------------

This is related to `Issue 38 <https://github.com/thu-ml/tianshou/issues/38>`_ and `Issue 69 <https://github.com/thu-ml/tianshou/issues/69>`_.

First of all, your self-defined environment must follow Gym's API; the most important methods are listed below:

- reset() -> state

- step(action) -> state, reward, done, info

- seed(s) -> None

- render(mode) -> None

- close() -> None
The state can be a ``numpy.ndarray`` or a Python dictionary. Take ``FetchReach-v1`` as an example:
::

    >>> e = gym.make('FetchReach-v1')
    >>> e.reset()
    {'observation': array([ 1.34183265e+00,  7.49100387e-01,  5.34722720e-01,  1.97805133e-04,
             7.15193042e-05,  7.73933014e-06,  5.51992816e-08, -2.42927453e-06,
             4.73325650e-06, -2.28455228e-06]),
     'achieved_goal': array([1.34183265, 0.74910039, 0.53472272]),
     'desired_goal': array([1.24073906, 0.77753463, 0.63457791])}
It shows that the state is a dictionary with 3 keys. It will be stored in :class:`~tianshou.data.ReplayBuffer` as:
::

    >>> from tianshou.data import ReplayBuffer
    >>> b = ReplayBuffer(size=3)
    >>> b.add(obs=e.reset(), act=0, rew=0, done=0)
    >>> print(b)
    ReplayBuffer(
        act: array([0, 0, 0]),
        done: array([0, 0, 0]),
        info: Batch(),
        obs: Batch(
                 achieved_goal: array([[1.34183265, 0.74910039, 0.53472272],
                                       [0.        , 0.        , 0.        ],
                                       [0.        , 0.        , 0.        ]]),
                 desired_goal: array([[1.42154265, 0.62505137, 0.62929863],
                                      [0.        , 0.        , 0.        ],
                                      [0.        , 0.        , 0.        ]]),
                 observation: array([[ 1.34183265e+00,  7.49100387e-01,  5.34722720e-01,
                                       1.97805133e-04,  7.15193042e-05,  7.73933014e-06,
                                       5.51992816e-08, -2.42927453e-06,  4.73325650e-06,
                                      -2.28455228e-06],
                                     [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
                                       0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
                                       0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
                                       0.00000000e+00],
                                     [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
                                       0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
                                       0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
                                       0.00000000e+00]]),
             ),
        policy: Batch(),
        rew: array([0, 0, 0]),
    )
    >>> print(b.obs.achieved_goal)
    [[1.34183265 0.74910039 0.53472272]
     [0.         0.         0.        ]
     [0.         0.         0.        ]]
And the data batch sampled from this replay buffer:
::

    >>> batch, indice = b.sample(2)
    >>> batch.keys()
    ['act', 'done', 'info', 'obs', 'obs_next', 'policy', 'rew']
    >>> batch.obs[-1]
    Batch(
        achieved_goal: array([1.34183265, 0.74910039, 0.53472272]),
        desired_goal: array([1.42154265, 0.62505137, 0.62929863]),
        observation: array([ 1.34183265e+00,  7.49100387e-01,  5.34722720e-01,  1.97805133e-04,
                             7.15193042e-05,  7.73933014e-06,  5.51992816e-08, -2.42927453e-06,
                             4.73325650e-06, -2.28455228e-06]),
    )
    >>> batch.obs.desired_goal[-1]  # recommended
    array([1.42154265, 0.62505137, 0.62929863])
    >>> batch.obs[-1].desired_goal  # not recommended
    array([1.42154265, 0.62505137, 0.62929863])
    >>> batch[-1].obs.desired_goal  # not recommended
    array([1.42154265, 0.62505137, 0.62929863])
Thus, in your self-defined network, just change the ``forward`` function to unpack these keys:
::

    def forward(self, s, state=None, info={}):
        # s is a Batch with the same keys as the observation dict
        observation = s.observation
        achieved_goal = s.achieved_goal
        desired_goal = s.desired_goal
        ...
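
For instance, if the policy should condition on both the raw observation and the desired goal, one possible pattern is to concatenate them before the first layer. The class below is a hypothetical illustration, with ``obs_dim=10`` and ``goal_dim=3`` matching the ``FetchReach-v1`` output above:
::

    import numpy as np
    import torch
    from torch import nn


    class GoalNet(nn.Module):
        """Hypothetical network for dict observations with a goal."""

        def __init__(self, obs_dim=10, goal_dim=3, action_dim=4, hidden_dim=128):
            super().__init__()
            self.model = nn.Sequential(
                nn.Linear(obs_dim + goal_dim, hidden_dim), nn.ReLU(),
                nn.Linear(hidden_dim, action_dim),
            )

        def forward(self, s, state=None, info={}):
            # s.observation and s.desired_goal are numpy arrays of shape [bsz, ...]
            x = np.concatenate([s.observation, s.desired_goal], axis=-1)
            x = torch.as_tensor(x, dtype=torch.float32)
            return self.model(x), state
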
For a self-defined class, the replay buffer will store the object references in a ``numpy.ndarray`` with ``dtype=object``, e.g.:
::

    >>> import networkx as nx
    >>> b = ReplayBuffer(size=3)
    >>> b.add(obs=nx.Graph(), act=0, rew=0, done=0)
    >>> print(b)
    ReplayBuffer(
        act: array([0, 0, 0]),
        done: array([0, 0, 0]),
        info: Batch(),
        obs: array([<networkx.classes.graph.Graph object at 0x7f5c607826a0>, None,
                    None], dtype=object),
        policy: Batch(),
        rew: array([0, 0, 0]),
    )
Because only the reference is stored, the state kept in the buffer may be a shallow copy. To make sure each state stored in the buffer is distinct, please return a deep copy of the state from your env:
::

    import copy

    # inside your environment class
    def reset(self):
        return copy.deepcopy(self.graph)

    def step(self, a):
        ...
        return copy.deepcopy(self.graph), reward, done, {}
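
Putting it together, a toy environment whose state is a graph might look like the sketch below; the class and attribute names are made up for illustration:
::

    import copy

    import gym
    import networkx as nx


    class GraphEnv(gym.Env):
        """Hypothetical environment with a graph-valued observation."""

        def reset(self):
            self.graph = nx.Graph()
            return copy.deepcopy(self.graph)

        def step(self, action):
            # mutate self.graph according to ``action`` here
            reward, done = 0.0, False
            return copy.deepcopy(self.graph), reward, done, {}
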
@@ -42,6 +42,7 @@ Here, we set up 8 environments in ``train_envs`` and 100 environments in ``test_
 
 For the demonstration, here we use the second block of codes.
 
+.. _build_the_network:
 
 Build the Network
 -----------------
@@ -75,8 +76,8 @@ Tianshou supports any user-defined PyTorch networks and optimizers but with the
 
 The rules of self-defined networks are:
 
-1. Input: observation ``obs`` (may be a ``numpy.ndarray`` or ``torch.Tensor``), hidden state ``state`` (for RNN usage), and other information ``info`` provided by the environment.
-2. Output: some ``logits`` and the next hidden state ``state``. The logits could be a tuple instead of a ``torch.Tensor``. It depends on how the policy process the network output. For example, in PPO :cite:`PPO`, the return of the network might be ``(mu, sigma), state`` for Gaussian policy.
+1. Input: observation ``obs`` (may be a ``numpy.ndarray``, ``torch.Tensor``, dict, or self-defined class), hidden state ``state`` (for RNN usage), and other information ``info`` provided by the environment.
+2. Output: some ``logits``, the next hidden state ``state``, and the intermediate result of the policy forwarding procedure ``policy``. The logits could be a tuple instead of a ``torch.Tensor``; it depends on how the policy processes the network output. For example, in PPO :cite:`PPO`, the return of the network might be ``(mu, sigma), state`` for a Gaussian policy. The ``policy`` can be a Batch of ``torch.Tensor`` or other objects, which will be stored in the replay buffer and can be accessed in the policy update process (e.g. in ``policy.learn()``, ``batch.policy`` is what you need).
 
 
 Setup Policy
@@ -17,7 +17,7 @@ def test_replaybuffer(size=10, bufsize=20):
         obs_next, rew, done, info = env.step(a)
         buf.add(obs, a, rew, done, obs_next, info)
         obs = obs_next
-        assert len(buf) == min(bufsize, i + 1), print(len(buf), i)
+        assert len(buf) == min(bufsize, i + 1)
     data, indice = buf.sample(bufsize * 2)
     assert (indice < len(buf)).all()
     assert (data.obs < size).all()
@@ -40,10 +40,10 @@ def test_stack(size=5, bufsize=9, stack_num=4):
         if done:
             obs = env.reset(1)
     indice = np.arange(len(buf))
-    assert abs(buf.get(indice, 'obs') - np.array([
+    assert np.allclose(buf.get(indice, 'obs'), np.array([
         [1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4],
         [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3],
-        [3, 3, 3, 3], [3, 3, 3, 4], [1, 1, 1, 1]])).sum() < 1e-6
+        [3, 3, 3, 3], [3, 3, 3, 4], [1, 1, 1, 1]]))
     print(buf)
@@ -63,7 +63,7 @@ def test_priortized_replaybuffer(size=32, bufsize=15):
             assert len(data) == len(buf)
         else:
             assert len(data) == len(buf) // 2
-        assert len(buf) == min(bufsize, i + 1), print(len(buf), i)
+        assert len(buf) == min(bufsize, i + 1)
         assert np.isclose(buf._weight_sum, (buf.weight).sum())
         data, indice = buf.sample(len(buf) // 2)
         buf.update_weight(indice, -data.weight / 2)
@@ -25,10 +25,6 @@ class MyPolicy(BasePolicy):
         pass
 
 
-def equal(a, b):
-    return abs(np.array(a) - np.array(b)).sum() < 1e-6
-
-
 def preprocess_fn(**kwargs):
     # modify info before adding into the buffer
     if kwargs.get('info', None) is not None:
@@ -70,28 +66,30 @@ def test_collector():
     c0 = Collector(policy, env, ReplayBuffer(size=100, ignore_obs_next=False),
                    preprocess_fn)
     c0.collect(n_step=3, log_fn=logger.log)
-    assert equal(c0.buffer.obs[:3], [0, 1, 0])
-    assert equal(c0.buffer[:3].obs_next, [1, 2, 1])
+    assert np.allclose(c0.buffer.obs[:3], [0, 1, 0])
+    assert np.allclose(c0.buffer[:3].obs_next, [1, 2, 1])
     c0.collect(n_episode=3, log_fn=logger.log)
-    assert equal(c0.buffer.obs[:8], [0, 1, 0, 1, 0, 1, 0, 1])
-    assert equal(c0.buffer[:8].obs_next, [1, 2, 1, 2, 1, 2, 1, 2])
+    assert np.allclose(c0.buffer.obs[:8], [0, 1, 0, 1, 0, 1, 0, 1])
+    assert np.allclose(c0.buffer[:8].obs_next, [1, 2, 1, 2, 1, 2, 1, 2])
     c1 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False),
                    preprocess_fn)
     c1.collect(n_step=6)
-    assert equal(c1.buffer.obs[:11], [0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3])
-    assert equal(c1.buffer[:11].obs_next, [1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4])
+    assert np.allclose(c1.buffer.obs[:11], [0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3])
+    assert np.allclose(c1.buffer[:11].obs_next,
+                       [1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4])
     c1.collect(n_episode=2)
-    assert equal(c1.buffer.obs[11:21], [0, 1, 2, 3, 4, 0, 1, 0, 1, 2])
-    assert equal(c1.buffer[11:21].obs_next, [1, 2, 3, 4, 5, 1, 2, 1, 2, 3])
+    assert np.allclose(c1.buffer.obs[11:21], [0, 1, 2, 3, 4, 0, 1, 0, 1, 2])
+    assert np.allclose(c1.buffer[11:21].obs_next,
+                       [1, 2, 3, 4, 5, 1, 2, 1, 2, 3])
     c2 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False),
                    preprocess_fn)
     c2.collect(n_episode=[1, 2, 2, 2])
-    assert equal(c2.buffer.obs_next[:26], [
+    assert np.allclose(c2.buffer.obs_next[:26], [
         1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5,
         1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5])
     c2.reset_env()
     c2.collect(n_episode=[2, 2, 2, 2])
-    assert equal(c2.buffer.obs_next[26:54], [
+    assert np.allclose(c2.buffer.obs_next[26:54], [
         1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4, 1, 2, 3, 4, 5,
         1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5])
@@ -115,7 +113,7 @@ def test_collector_with_dict_state():
     batch = c1.sample(10)
     print(batch)
     c0.buffer.update(c1.buffer)
-    assert equal(c0.buffer[:len(c0.buffer)].obs.index, [
+    assert np.allclose(c0.buffer[:len(c0.buffer)].obs.index, [
         0., 1., 2., 3., 4., 0., 1., 2., 3., 4., 0., 1., 2., 3., 4., 0., 1.,
         0., 1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4., 0., 1., 0.,
         1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4.])
@@ -47,7 +47,7 @@ class ReplayBuffer(object):
     :class:`~tianshou.data.ReplayBuffer` also supports frame_stack sampling
     (typically for RNN usage, see issue#19), ignoring storing the next
     observation (save memory in atari tasks), and multi-modal observation (see
-    issue#38, need version >= 0.2.3):
+    issue#38):
     ::
 
         >>> buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True)
@@ -250,7 +250,7 @@ class ReplayBuffer(object):
         # set last frame done to True
         last_index = (self._index - 1 + self._size) % self._size
         last_done, self.done[last_index] = self.done[last_index], True
-        if key == 'obs_next' and not self._save_s_:
+        if key == 'obs_next' and (not self._save_s_ or self.obs_next is None):
             indice += 1 - self.done[indice].astype(np.int)
             indice[indice == self._size] = 0
             key = 'obs'
@@ -26,13 +26,19 @@ class BasePolicy(ABC, nn.Module):
     Most of the policy needs a neural network to predict the action and an
     optimizer to optimize the policy. The rules of self-defined networks are:
 
-    1. Input: observation ``obs`` (may be a ``numpy.ndarray`` or \
-        ``torch.Tensor``), hidden state ``state`` (for RNN usage), and other \
-        information ``info`` provided by the environment.
-    2. Output: some ``logits`` and the next hidden state ``state``. The logits\
-        could be a tuple instead of a ``torch.Tensor``. It depends on how the \
-        policy process the network output. For example, in PPO, the return of \
-        the network might be ``(mu, sigma), state`` for Gaussian policy.
+    1. Input: observation ``obs`` (may be a ``numpy.ndarray``, a \
+        ``torch.Tensor``, a dict or any others), hidden state ``state`` (for \
+        RNN usage), and other information ``info`` provided by the \
+        environment.
+    2. Output: some ``logits``, the next hidden state ``state``, and the \
+        intermediate result during policy forwarding procedure ``policy``. The\
+        ``logits`` could be a tuple instead of a ``torch.Tensor``. It depends \
+        on how the policy process the network output. For example, in PPO, the\
+        return of the network might be ``(mu, sigma), state`` for Gaussian \
+        policy. The ``policy`` can be a Batch of torch.Tensor or other things,\
+        which will be stored in the replay buffer, and can be accessed in the \
+        policy update process (e.g. in ``policy.learn()``, the \
+        ``batch.policy`` is what you need).
 
     Since :class:`~tianshou.policy.BasePolicy` inherits ``torch.nn.Module``,
     you can use :class:`~tianshou.policy.BasePolicy` almost the same as