diff --git a/docs/tutorials/cheatsheet.rst b/docs/tutorials/cheatsheet.rst
index 1a89645..61e1352 100644
--- a/docs/tutorials/cheatsheet.rst
+++ b/docs/tutorials/cheatsheet.rst
@@ -1,6 +1,192 @@
 Cheat Sheet
 ===========
 
-This page shows some code snippets of how to use Tianshou to develop new algorithms.
+This page shows some code snippets of how to use Tianshou to develop new algorithms and to apply existing algorithms to new scenarios.
 
-TODO
+By the way, some of these issues can also be resolved with a ``gym.Wrapper``, which operates on the policy-environment interaction.
+
+.. _network_api:
+
+Build Policy Network
+--------------------
+
+See :ref:`build_the_network`.
+
+.. _new_policy:
+
+Build New Policy
+----------------
+
+See :class:`~tianshou.policy.BasePolicy`.
+
+Customize Training Process
+--------------------------
+
+See :ref:`customized_trainer`.
+
+.. _parallel_sampling:
+
+Parallel Sampling
+-----------------
+
+Use :class:`~tianshou.env.VectorEnv` or :class:`~tianshou.env.SubprocVectorEnv`.
+::
+
+    env_fns = [
+        lambda: MyTestEnv(size=2),
+        lambda: MyTestEnv(size=3),
+        lambda: MyTestEnv(size=4),
+        lambda: MyTestEnv(size=5),
+    ]
+    venv = SubprocVectorEnv(env_fns)
+
+where ``env_fns`` is a list of callables, each of which constructs an environment. The above code can also be written with a for-loop:
+::
+
+    env_fns = [lambda x=i: MyTestEnv(size=x) for i in [2, 3, 4, 5]]
+    venv = SubprocVectorEnv(env_fns)
+
+.. _rnn_training:
+
+RNN-style Training
+------------------
+
+This is related to `Issue 19 `_.
+
+First, add an argument ``stack_num`` to :class:`~tianshou.data.ReplayBuffer`:
+::
+
+    buf = ReplayBuffer(size=size, stack_num=stack_num)
+
+Then, change the network to a recurrent style, for example, class ``Recurrent`` in `code snippet 1 `_, or ``RecurrentActor`` and ``RecurrentCritic`` in `code snippet 2 `_.
+
+The above code supports only stacked observations. If you want to use stacked actions (for Q(stacked-s, stacked-a)), stacked rewards, or other stacked variables, you can add a ``gym.Wrapper`` to modify the state representation. For example, suppose we add a wrapper that maps the [s, a] pair to a new state:
+
+- Before: (s, a, s', r, d) is stored in the replay buffer, and only stacked s can be obtained;
+- After applying the wrapper: ([s, a], a, [s', a'], r, d) is stored in the replay buffer, and both stacked s and stacked a can be obtained.
+
+.. _self_defined_env:
+
+User-defined Environment and Different State Representation
+------------------------------------------------------------
+
+This is related to `Issue 38 `_ and `Issue 69 `_.
+
+First of all, your self-defined environment must follow Gym's API. Some of its methods are listed below:
+
+- reset() -> state
+
+- step(action) -> state, reward, done, info
+
+- seed(s) -> None
+
+- render(mode) -> None
+
+- close() -> None
+
+The state can be a ``numpy.ndarray`` or a Python dictionary. Take ``FetchReach-v1`` as an example:
+::
+
+    >>> e = gym.make('FetchReach-v1')
+    >>> e.reset()
+    {'observation': array([ 1.34183265e+00,  7.49100387e-01,  5.34722720e-01,  1.97805133e-04,
+             7.15193042e-05,  7.73933014e-06,  5.51992816e-08, -2.42927453e-06,
+             4.73325650e-06, -2.28455228e-06]),
+     'achieved_goal': array([1.34183265, 0.74910039, 0.53472272]),
+     'desired_goal': array([1.24073906, 0.77753463, 0.63457791])}
+
+This shows that the state is a dictionary which has 3 keys.
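+
+For reference, here is a minimal sketch of a self-defined environment that follows this API and returns a similar dictionary state. The class name ``MyGoalEnv``, the sizes, and the reward rule are purely illustrative and are not part of Tianshou or Gym:
+::
+
+    import gym
+    import numpy as np
+
+
+    class MyGoalEnv(gym.Env):
+        """A toy goal-based environment with a dict state (illustrative only)."""
+
+        def __init__(self):
+            self.pos = np.zeros(3)
+            self.goal = np.zeros(3)
+
+        def _get_state(self):
+            # build a fresh dict (and fresh arrays) on every call, so that the
+            # states stored in the replay buffer do not share memory
+            return {
+                'observation': self.pos.copy(),
+                'achieved_goal': self.pos.copy(),
+                'desired_goal': self.goal.copy(),
+            }
+
+        def reset(self):
+            self.goal = np.random.uniform(-1.0, 1.0, size=3)
+            self.pos = np.zeros(3)
+            return self._get_state()
+
+        def step(self, action):
+            self.pos = self.pos + np.asarray(action)
+            reward = -np.linalg.norm(self.pos - self.goal)
+            done = bool(reward > -0.05)
+            return self._get_state(), reward, done, {}
+
+        def seed(self, s=None):
+            np.random.seed(s)
+
+        def render(self, mode='human'):
+            pass
+
+        def close(self):
+            pass
+
+Returning freshly-copied arrays from ``_get_state`` also avoids the shallow-copy issue discussed at the end of this section.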
+Such a dictionary state will be stored in :class:`~tianshou.data.ReplayBuffer` as:
+::
+
+    >>> from tianshou.data import ReplayBuffer
+    >>> b = ReplayBuffer(size=3)
+    >>> b.add(obs=e.reset(), act=0, rew=0, done=0)
+    >>> print(b)
+    ReplayBuffer(
+        act: array([0, 0, 0]),
+        done: array([0, 0, 0]),
+        info: Batch(),
+        obs: Batch(
+                 achieved_goal: array([[1.34183265, 0.74910039, 0.53472272],
+                                       [0.        , 0.        , 0.        ],
+                                       [0.        , 0.        , 0.        ]]),
+                 desired_goal: array([[1.42154265, 0.62505137, 0.62929863],
+                                      [0.        , 0.        , 0.        ],
+                                      [0.        , 0.        , 0.        ]]),
+                 observation: array([[ 1.34183265e+00,  7.49100387e-01,  5.34722720e-01,
+                                       1.97805133e-04,  7.15193042e-05,  7.73933014e-06,
+                                       5.51992816e-08, -2.42927453e-06,  4.73325650e-06,
+                                      -2.28455228e-06],
+                                     [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
+                                       0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
+                                       0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
+                                       0.00000000e+00],
+                                     [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
+                                       0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
+                                       0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
+                                       0.00000000e+00]]),
+             ),
+        policy: Batch(),
+        rew: array([0, 0, 0]),
+    )
+    >>> print(b.obs.achieved_goal)
+    [[1.34183265 0.74910039 0.53472272]
+     [0.         0.         0.        ]
+     [0.         0.         0.        ]]
+
+And the data batch sampled from this replay buffer:
+::
+
+    >>> batch, indice = b.sample(2)
+    >>> batch.keys()
+    ['act', 'done', 'info', 'obs', 'obs_next', 'policy', 'rew']
+    >>> batch.obs[-1]
+    Batch(
+        achieved_goal: array([1.34183265, 0.74910039, 0.53472272]),
+        desired_goal: array([1.42154265, 0.62505137, 0.62929863]),
+        observation: array([ 1.34183265e+00,  7.49100387e-01,  5.34722720e-01,  1.97805133e-04,
+                             7.15193042e-05,  7.73933014e-06,  5.51992816e-08, -2.42927453e-06,
+                             4.73325650e-06, -2.28455228e-06]),
+    )
+    >>> batch.obs.desired_goal[-1]  # recommended
+    array([1.42154265, 0.62505137, 0.62929863])
+    >>> batch.obs[-1].desired_goal  # not recommended
+    array([1.42154265, 0.62505137, 0.62929863])
+    >>> batch[-1].obs.desired_goal  # not recommended
+    array([1.42154265, 0.62505137, 0.62929863])
+
+Thus, in your self-defined network, just change the ``forward`` function to:
+::
+
+    def forward(self, s, state=None, info={}):
+        # s is a Batch with the same keys as the dict state
+        observation = s.observation
+        achieved_goal = s.achieved_goal
+        desired_goal = s.desired_goal
+        ...
+
+For a self-defined class, the replay buffer will store a reference to it in a ``numpy.ndarray``, e.g.:
+::
+
+    >>> import networkx as nx
+    >>> b = ReplayBuffer(size=3)
+    >>> b.add(obs=nx.Graph(), act=0, rew=0, done=0)
+    >>> print(b)
+    ReplayBuffer(
+        act: array([0, 0, 0]),
+        done: array([0, 0, 0]),
+        info: Batch(),
+        obs: array([<networkx.classes.graph.Graph object at ...>, None,
+                    None], dtype=object),
+        policy: Batch(),
+        rew: array([0, 0, 0]),
+    )
+
+But the state stored in the buffer may be a shallow copy. To make sure each state stored in the buffer is distinct, please return a deep copy of the state in your env:
+::
+
+    def reset(self):
+        return copy.deepcopy(self.graph)
+
+    def step(self, a):
+        ...
+        return copy.deepcopy(self.graph), reward, done, {}
diff --git a/docs/tutorials/dqn.rst b/docs/tutorials/dqn.rst
index a5d97c8..0d43031 100644
--- a/docs/tutorials/dqn.rst
+++ b/docs/tutorials/dqn.rst
@@ -42,6 +42,7 @@ Here, we set up 8 environments in ``train_envs`` and 100 environments in ``test_
 
 For the demonstration, here we use the second block of codes.
 
+.. _build_the_network:
 
 Build the Network
 -----------------
@@ -75,8 +76,8 @@ Tianshou supports any user-defined PyTorch networks and optimizers but with the
 
 The rules of self-defined networks are:
 
-1. Input: observation ``obs`` (may be a ``numpy.ndarray`` or ``torch.Tensor``), hidden state ``state`` (for RNN usage), and other information ``info`` provided by the environment.
-2. Output: some ``logits`` and the next hidden state ``state``. The logits could be a tuple instead of a ``torch.Tensor``. It depends on how the policy process the network output. For example, in PPO :cite:`PPO`, the return of the network might be ``(mu, sigma), state`` for Gaussian policy.
+1. Input: observation ``obs`` (may be a ``numpy.ndarray``, ``torch.Tensor``, dict, or self-defined class), hidden state ``state`` (for RNN usage), and other information ``info`` provided by the environment.
+2. Output: some ``logits``, the next hidden state ``state``, and the intermediate result during the policy forwarding procedure, ``policy``. The logits could be a tuple instead of a ``torch.Tensor``. It depends on how the policy processes the network output. For example, in PPO :cite:`PPO`, the return of the network might be ``(mu, sigma), state`` for a Gaussian policy. The ``policy`` can be a Batch of ``torch.Tensor`` or other objects, which will be stored in the replay buffer and can be accessed in the policy update process (e.g. in ``policy.learn()``, ``batch.policy`` is what you need).
 
 Setup Policy
diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py
index af6fbdc..6474d44 100644
--- a/test/base/test_buffer.py
+++ b/test/base/test_buffer.py
@@ -17,7 +17,7 @@ def test_replaybuffer(size=10, bufsize=20):
         obs_next, rew, done, info = env.step(a)
         buf.add(obs, a, rew, done, obs_next, info)
         obs = obs_next
-        assert len(buf) == min(bufsize, i + 1), print(len(buf), i)
+        assert len(buf) == min(bufsize, i + 1)
     data, indice = buf.sample(bufsize * 2)
     assert (indice < len(buf)).all()
     assert (data.obs < size).all()
@@ -40,10 +40,10 @@ def test_stack(size=5, bufsize=9, stack_num=4):
         if done:
             obs = env.reset(1)
     indice = np.arange(len(buf))
-    assert abs(buf.get(indice, 'obs') - np.array([
+    assert np.allclose(buf.get(indice, 'obs'), np.array([
         [1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4],
         [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3],
-        [3, 3, 3, 3], [3, 3, 3, 4], [1, 1, 1, 1]])).sum() < 1e-6
+        [3, 3, 3, 3], [3, 3, 3, 4], [1, 1, 1, 1]]))
     print(buf)
 
 
@@ -63,7 +63,7 @@ def test_priortized_replaybuffer(size=32, bufsize=15):
            assert len(data) == len(buf)
        else:
            assert len(data) == len(buf) // 2
-        assert len(buf) == min(bufsize, i + 1), print(len(buf), i)
+        assert len(buf) == min(bufsize, i + 1)
     assert np.isclose(buf._weight_sum, (buf.weight).sum())
     data, indice = buf.sample(len(buf) // 2)
     buf.update_weight(indice, -data.weight / 2)
diff --git a/test/base/test_collector.py b/test/base/test_collector.py
index 9e40db2..95d8da1 100644
--- a/test/base/test_collector.py
+++ b/test/base/test_collector.py
@@ -25,10 +25,6 @@ class MyPolicy(BasePolicy):
         pass
 
 
-def equal(a, b):
-    return abs(np.array(a) - np.array(b)).sum() < 1e-6
-
-
 def preprocess_fn(**kwargs):
     # modify info before adding into the buffer
     if kwargs.get('info', None) is not None:
@@ -70,28 +66,30 @@ def test_collector():
     c0 = Collector(policy, env, ReplayBuffer(size=100, ignore_obs_next=False), preprocess_fn)
     c0.collect(n_step=3, log_fn=logger.log)
-    assert equal(c0.buffer.obs[:3], [0, 1, 0])
-    assert equal(c0.buffer[:3].obs_next, [1, 2, 1])
+    assert np.allclose(c0.buffer.obs[:3], [0, 1, 0])
+    assert np.allclose(c0.buffer[:3].obs_next, [1, 2, 1])
     c0.collect(n_episode=3, log_fn=logger.log)
-    assert equal(c0.buffer.obs[:8], [0, 1, 0, 1, 0, 1, 0, 1])
-    assert equal(c0.buffer[:8].obs_next, [1, 2, 1, 2,
1, 2, 1, 2]) + assert np.allclose(c0.buffer.obs[:8], [0, 1, 0, 1, 0, 1, 0, 1]) + assert np.allclose(c0.buffer[:8].obs_next, [1, 2, 1, 2, 1, 2, 1, 2]) c1 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False), preprocess_fn) c1.collect(n_step=6) - assert equal(c1.buffer.obs[:11], [0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3]) - assert equal(c1.buffer[:11].obs_next, [1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4]) + assert np.allclose(c1.buffer.obs[:11], [0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3]) + assert np.allclose(c1.buffer[:11].obs_next, + [1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4]) c1.collect(n_episode=2) - assert equal(c1.buffer.obs[11:21], [0, 1, 2, 3, 4, 0, 1, 0, 1, 2]) - assert equal(c1.buffer[11:21].obs_next, [1, 2, 3, 4, 5, 1, 2, 1, 2, 3]) + assert np.allclose(c1.buffer.obs[11:21], [0, 1, 2, 3, 4, 0, 1, 0, 1, 2]) + assert np.allclose(c1.buffer[11:21].obs_next, + [1, 2, 3, 4, 5, 1, 2, 1, 2, 3]) c2 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False), preprocess_fn) c2.collect(n_episode=[1, 2, 2, 2]) - assert equal(c2.buffer.obs_next[:26], [ + assert np.allclose(c2.buffer.obs_next[:26], [ 1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5]) c2.reset_env() c2.collect(n_episode=[2, 2, 2, 2]) - assert equal(c2.buffer.obs_next[26:54], [ + assert np.allclose(c2.buffer.obs_next[26:54], [ 1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4, 1, 2, 3, 4, 5, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5]) @@ -115,7 +113,7 @@ def test_collector_with_dict_state(): batch = c1.sample(10) print(batch) c0.buffer.update(c1.buffer) - assert equal(c0.buffer[:len(c0.buffer)].obs.index, [ + assert np.allclose(c0.buffer[:len(c0.buffer)].obs.index, [ 0., 1., 2., 3., 4., 0., 1., 2., 3., 4., 0., 1., 2., 3., 4., 0., 1., 0., 1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4., 0., 1., 0., 1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4.]) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 3ab4ed1..92e34a6 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -47,7 +47,7 @@ class ReplayBuffer(object): :class:`~tianshou.data.ReplayBuffer` also supports frame_stack sampling (typically for RNN usage, see issue#19), ignoring storing the next observation (save memory in atari tasks), and multi-modal observation (see - issue#38, need version >= 0.2.3): + issue#38): :: >>> buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True) @@ -250,7 +250,7 @@ class ReplayBuffer(object): # set last frame done to True last_index = (self._index - 1 + self._size) % self._size last_done, self.done[last_index] = self.done[last_index], True - if key == 'obs_next' and not self._save_s_: + if key == 'obs_next' and (not self._save_s_ or self.obs_next is None): indice += 1 - self.done[indice].astype(np.int) indice[indice == self._size] = 0 key = 'obs' diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index dba4357..e7fa474 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -26,13 +26,19 @@ class BasePolicy(ABC, nn.Module): Most of the policy needs a neural network to predict the action and an optimizer to optimize the policy. The rules of self-defined networks are: - 1. Input: observation ``obs`` (may be a ``numpy.ndarray`` or \ - ``torch.Tensor``), hidden state ``state`` (for RNN usage), and other \ - information ``info`` provided by the environment. - 2. Output: some ``logits`` and the next hidden state ``state``. The logits\ - could be a tuple instead of a ``torch.Tensor``. It depends on how the \ - policy process the network output. 
For example, in PPO, the return of \
-    the network might be ``(mu, sigma), state`` for Gaussian policy.
+    1. Input: observation ``obs`` (may be a ``numpy.ndarray``, a \
+    ``torch.Tensor``, a dict, or any other type), hidden state ``state`` \
+    (for RNN usage), and other information ``info`` provided by the \
+    environment.
+    2. Output: some ``logits``, the next hidden state ``state``, and the \
+    intermediate result during the policy forwarding procedure \
+    ``policy``. The ``logits`` could be a tuple instead of a \
+    ``torch.Tensor``. It depends on how the policy processes the network \
+    output. For example, in PPO, the return of the network might be \
+    ``(mu, sigma), state`` for a Gaussian policy. The ``policy`` can be a \
+    Batch of ``torch.Tensor`` or other objects, which will be stored in \
+    the replay buffer, and can be accessed in the policy update process \
+    (e.g. in ``policy.learn()``, ``batch.policy`` is what you need).
 
     Since :class:`~tianshou.policy.BasePolicy` inherits ``torch.nn.Module``, you can use :class:`~tianshou.policy.BasePolicy` almost the same as