Cheat Sheet
===========

This page shows some code snippets of how to use Tianshou to develop new
algorithms and apply existing algorithms to new scenarios.

Many of these use cases can also be handled with a ``gymnasium.Wrapper``, which is a fairly
universal way to customize the policy-environment interaction. Alternatively, you can use the
batch processor :ref:`preprocess_fn` or the vectorized environment wrapper
:class:`~tianshou.env.VectorEnvWrapper`.

.. _network_api:

Build Policy Network
--------------------

See :ref:`build_the_network`.


.. _new_policy:

Build New Policy
----------------

See :class:`~tianshou.policy.BasePolicy`.


.. _eval_policy:

Manually Evaluate Policy
------------------------

If you'd like to manually see the action generated by a well-trained agent:

::

    # assume obs is a single environment observation
    action = policy(Batch(obs=np.array([obs]))).act[0]


.. _customize_training:

Customize Training Process
--------------------------

See :ref:`customized_trainer`.


.. _resume_training:

Resume Training Process
-----------------------

This is related to `Issue 349 <https://github.com/thu-ml/tianshou/issues/349>`_.

To resume the training process from an existing checkpoint, you need to do the following during training:

1. Make sure you write a ``save_checkpoint_fn`` which saves everything needed in the training process, i.e., policy, optim, and buffer, and pass it to the trainer;
2. Use ``TensorboardLogger``;
3. To adjust the save frequency, specify ``save_interval`` when initializing TensorboardLogger.

And to successfully resume from a checkpoint:

1. Load everything needed in the training process **before trainer initialization**, i.e., policy, optim, and buffer;
2. Set ``resume_from_log=True`` when initializing the trainer.
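
Below is a minimal sketch of how these pieces fit together; the log path, hyper-parameters, and the surrounding policy/optim/buffer/collector setup are placeholder assumptions, not a complete script:

::

    import os

    import torch
    from torch.utils.tensorboard import SummaryWriter

    from tianshou.data import ReplayBuffer
    from tianshou.trainer import offpolicy_trainer
    from tianshou.utils import TensorboardLogger

    log_path = "log/resume_example"  # placeholder
    writer = SummaryWriter(log_path)
    logger = TensorboardLogger(writer, save_interval=1)  # save a checkpoint every epoch

    def save_checkpoint_fn(epoch, env_step, gradient_step):
        # save everything needed to resume: policy, optimizer, and buffer
        ckpt_path = os.path.join(log_path, "checkpoint.pth")
        torch.save({"model": policy.state_dict(), "optim": optim.state_dict()}, ckpt_path)
        buffer.save_hdf5(os.path.join(log_path, "buffer.hdf5"))
        return ckpt_path

    if resume:  # restore everything *before* creating the trainer
        checkpoint = torch.load(os.path.join(log_path, "checkpoint.pth"))
        policy.load_state_dict(checkpoint["model"])
        optim.load_state_dict(checkpoint["optim"])
        buffer = ReplayBuffer.load_hdf5(os.path.join(log_path, "buffer.hdf5"))

    result = offpolicy_trainer(
        policy, train_collector, test_collector,
        max_epoch=10, step_per_epoch=10000, step_per_collect=10,
        update_per_step=0.1, episode_per_test=100, batch_size=64,
        logger=logger,
        save_checkpoint_fn=save_checkpoint_fn,
        resume_from_log=resume,
    )
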

We provide an example to show how these steps work: check out `test_c51.py <https://github.com/thu-ml/tianshou/blob/master/test/discrete/test_c51.py>`_, `test_ppo.py <https://github.com/thu-ml/tianshou/blob/master/test/continuous/test_ppo.py>`_, or `test_discrete_bcq.py <https://github.com/thu-ml/tianshou/blob/master/test/offline/test_discrete_bcq.py>`_ by running

.. code-block:: console

    $ python3 test/discrete/test_c51.py  # train some epochs
    $ python3 test/discrete/test_c51.py --resume  # restore from the existing log and continue training

To correctly render the data (including several tfevent files), we highly recommend using ``tensorboard >= 2.5.0`` (see `here <https://github.com/thu-ml/tianshou/pull/350#issuecomment-829123378>`_ for the reason). Otherwise, it may cause an overlapping issue that you would need to handle manually.


.. _parallel_sampling:

Parallel Sampling
-----------------

Tianshou provides the following classes for vectorized environments:

- :class:`~tianshou.env.DummyVectorEnv` is for pseudo-parallel simulation (implemented with a for-loop, useful for debugging).
- :class:`~tianshou.env.SubprocVectorEnv` uses multiple processes for parallel simulation. This is the most common choice for parallel simulation.
- :class:`~tianshou.env.ShmemVectorEnv` has a similar implementation to :class:`~tianshou.env.SubprocVectorEnv`, but is optimized (in terms of both memory footprint and simulation speed) for environments with large observations such as images.
- :class:`~tianshou.env.RayVectorEnv` is currently the only choice for parallel simulation in a cluster with multiple machines.

Although these classes are optimized for different scenarios, they have exactly the same APIs because they are sub-classes of :class:`~tianshou.env.BaseVectorEnv`. Just provide a list of functions that return environments when called, and you are all set.

::

    env_fns = [lambda x=i: MyTestEnv(size=x) for i in [2, 3, 4, 5]]
    venv = SubprocVectorEnv(env_fns)  # DummyVectorEnv, ShmemVectorEnv, or RayVectorEnv, whichever you like.
    venv.reset()  # returns the initial observations of each environment
    venv.step(actions)  # provide actions for each environment and get their results

.. sidebar:: An example of sync/async VectorEnv (steps with the same color end up in one batch that is processed by the policy at the same time).

    .. Figure:: ../_static/images/async.png

By default, parallel environment simulation is synchronous: a step is done after all environments have finished a step. Synchronous simulation works well if each environment step costs roughly the same amount of time.

If the time cost of environment steps varies a lot (e.g., 90% of steps cost 1s, but 10% cost 10s), so that slow environments hold the fast ones back, asynchronous simulation can be used (related to `Issue 103 <https://github.com/thu-ml/tianshou/issues/103>`_). The idea is to step the environments that have finished without waiting for the slow ones.

Asynchronous simulation is a built-in functionality of
:class:`~tianshou.env.BaseVectorEnv`. Just provide ``wait_num`` or ``timeout``
(or both) and async simulation works.

::

    env_fns = [lambda x=i: MyTestEnv(size=x, sleep=x) for i in [2, 3, 4, 5]]
    # DummyVectorEnv, ShmemVectorEnv, or RayVectorEnv, whichever you like.
    venv = SubprocVectorEnv(env_fns, wait_num=3, timeout=0.2)
    venv.reset()  # returns the initial observations of each environment
    # returns "wait_num" steps or finished steps after "timeout" seconds,
    # whichever occurs first.
    venv.step(actions, ready_id)

If we have 4 envs and set ``wait_num = 3``, each step only returns the results of 3 of these 4 envs.

You can treat the ``timeout`` parameter as a dynamic ``wait_num``. In each vectorized step it only returns the environments that finished within the given time. If there is no such environment, it will wait until any of them finishes.

The figure on the right gives an intuitive comparison between synchronous and asynchronous simulation.

.. note::

    The async simulation collector would cause some exceptions when used as
    ``test_collector`` in :doc:`/api/tianshou.trainer` (related to
    `Issue 700 <https://github.com/thu-ml/tianshou/issues/700>`_). Please use
    the sync version for ``test_collector`` instead.

.. warning::

    If you use your own environment, please make sure the ``seed`` method is set up properly, e.g.,

    ::

        def seed(self, seed):
            np.random.seed(seed)

    Otherwise, the outputs of these envs may be the same as each other.

.. _envpool_integration:

EnvPool Integration
-------------------

`EnvPool <https://github.com/sail-sg/envpool/>`_ is a C++-based vectorized environment implementation and is much faster than the above solutions. Its APIs are almost the same as those of the above four classes, which means you can directly switch the vectorized environment to envpool and get an immediate speed-up.

Currently it supports
`Atari <https://github.com/thu-ml/tianshou/tree/master/examples/atari#envpool>`_,
`Mujoco <https://github.com/thu-ml/tianshou/tree/master/examples/mujoco#envpool>`_,
`VizDoom <https://github.com/thu-ml/tianshou/tree/master/examples/vizdoom#envpool>`_,
toy_text and classic_control environments. For more information, please refer to `EnvPool's documentation <https://envpool.readthedocs.io/en/latest/>`_.

::

    # install envpool: pip3 install envpool
    import envpool

    envs = envpool.make_gymnasium("CartPole-v0", num_envs=10)
    collector = Collector(policy, envs, buffer)

Here are some other `examples <https://github.com/sail-sg/envpool/tree/master/examples/tianshou_examples>`_.

.. _preprocess_fn:

Handle Batched Data Stream in Collector
---------------------------------------

This is related to `Issue 42 <https://github.com/thu-ml/tianshou/issues/42>`_.

If you want to gather log statistics from the data stream, pre-process batched images, or modify the reward with the given env info, use ``preprocess_fn`` in :class:`~tianshou.data.Collector`. This is a hook that will be called before the data is added to the buffer.

It receives the keys "obs" and "env_id" when the collector resets the environment, and receives six keys "obs_next", "rew", "done", "info", "policy", "env_id" on a normal env step. It returns either a dict or a :class:`~tianshou.data.Batch` with the modified keys and values.

These variables are intended to gather all the information required to keep track of a simulation step, namely the (observation, action, reward, done flag, next observation, info, intermediate result of the policy) at time t, for the whole duration of the simulation.

For example, you can write your hook as:

::

    import numpy as np
    from collections import deque

    from tianshou.data import Batch


    class MyProcessor:
        def __init__(self, size=100):
            self.episode_log = None
            self.main_log = deque(maxlen=size)
            self.main_log.append(0)
            self.baseline = 0

        def preprocess_fn(self, **kwargs):
            """change reward to zero mean"""
            # if obs && env_id exist -> reset
            # if obs_next/act/rew/done/policy/env_id exist -> normal step
            if 'rew' not in kwargs:
                # it is called after env.reset(), it can only process the obs
                return Batch()  # none of the variables need to be updated
            else:
                n = len(kwargs['rew'])  # the number of envs in collector
                if self.episode_log is None:
                    self.episode_log = [[] for i in range(n)]
                for i in range(n):
                    self.episode_log[i].append(kwargs['rew'][i])
                    kwargs['rew'][i] -= self.baseline
                for i in range(n):
                    if kwargs['done'][i]:
                        self.main_log.append(np.mean(self.episode_log[i]))
                        self.episode_log[i] = []
                        self.baseline = np.mean(self.main_log)
                return Batch(rew=kwargs['rew'])

And finally,

::

    test_processor = MyProcessor(size=100)
    collector = Collector(policy, env, buffer, preprocess_fn=test_processor.preprocess_fn)

Some examples are in `test/base/test_collector.py <https://github.com/thu-ml/tianshou/blob/master/test/base/test_collector.py>`_.

Another solution is to create a vector environment wrapper through :class:`~tianshou.env.VectorEnvWrapper`, e.g.

::

    import numpy as np
    from collections import deque

    from tianshou.env import VectorEnvWrapper


    class MyWrapper(VectorEnvWrapper):
        def __init__(self, venv, size=100):
            super().__init__(venv)
            self.episode_log = None
            self.main_log = deque(maxlen=size)
            self.main_log.append(0)
            self.baseline = 0

        def step(self, action, env_id):
            obs, rew, term, trunc, info = self.venv.step(action, env_id)
            n = len(rew)
            if self.episode_log is None:
                self.episode_log = [[] for i in range(n)]
            for i in range(n):
                self.episode_log[i].append(rew[i])
                rew[i] -= self.baseline
            for i in range(n):
                if term[i] or trunc[i]:
                    self.main_log.append(np.mean(self.episode_log[i]))
                    self.episode_log[i] = []
                    self.baseline = np.mean(self.main_log)
            return obs, rew, term, trunc, info


    env = MyWrapper(env, size=100)
    collector = Collector(policy, env, buffer)

We provide an observation normalization vector env wrapper: :class:`~tianshou.env.VectorEnvNormObs`.
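
A short usage sketch (here ``env_fns`` and ``test_env_fns`` are assumed to be lists of environment factory functions as above):

::

    from tianshou.env import DummyVectorEnv, VectorEnvNormObs

    train_envs = VectorEnvNormObs(DummyVectorEnv(env_fns))
    # reuse the running statistics from training, but do not update them during testing
    test_envs = VectorEnvNormObs(DummyVectorEnv(test_env_fns), update_obs_rms=False)
    test_envs.set_obs_rms(train_envs.get_obs_rms())
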
.. _rnn_training:

RNN-style Training
------------------

This is related to `Issue 19 <https://github.com/thu-ml/tianshou/issues/19>`_.

First, add an argument "stack_num" to :class:`~tianshou.data.ReplayBuffer`, :class:`~tianshou.data.VectorReplayBuffer`, or any other type of buffer you are using, like:

::

    buf = ReplayBuffer(size=size, stack_num=stack_num)

Then, change the network to a recurrent style, for example, :class:`~tianshou.utils.net.common.Recurrent`, :class:`~tianshou.utils.net.continuous.RecurrentActorProb` and :class:`~tianshou.utils.net.continuous.RecurrentCritic`.
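
For instance, a minimal sketch of constructing the built-in recurrent network (the layer count here is illustrative, and ``state_shape``, ``action_shape``, ``device`` are assumed to be defined as in the usual training scripts):

::

    from tianshou.utils.net.common import Recurrent

    # an LSTM-based network that consumes the stacked observations from the buffer
    net = Recurrent(layer_num=2, state_shape=state_shape,
                    action_shape=action_shape, device=device).to(device)
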
The above code supports only stacked observations. If you want to use stacked actions (for Q(stacked-s, stacked-a)), stacked rewards, or other stacked variables, you can add a ``gymnasium.Wrapper`` to modify the state representation. For example, if we add a wrapper that maps the ``[s, a]`` pair to a new state:

- Before: ``(s, a, s', r, d)`` stored in the replay buffer, and get stacked s;
- After applying the wrapper: ``([s, a], a, [s', a'], r, d)`` stored in the replay buffer, and get both stacked s and a.
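
A minimal sketch of such a wrapper, assuming a flat ``Box`` observation space and a ``Discrete`` action space (the class name and the one-hot encoding of the previous action are just one possible choice):

::

    import gymnasium as gym
    import numpy as np


    class LastActionWrapper(gym.Wrapper):
        """Append a one-hot encoding of the previous action to the observation."""

        def __init__(self, env):
            super().__init__(env)
            assert isinstance(env.action_space, gym.spaces.Discrete)
            self._act_dim = int(env.action_space.n)
            low = np.concatenate([env.observation_space.low, np.zeros(self._act_dim)])
            high = np.concatenate([env.observation_space.high, np.ones(self._act_dim)])
            self.observation_space = gym.spaces.Box(low=low, high=high, dtype=np.float32)

        def _augment(self, obs, act):
            # build [s, a] by concatenating the observation with the one-hot action
            one_hot = np.zeros(self._act_dim, dtype=np.float32)
            if act is not None:
                one_hot[act] = 1.0
            return np.concatenate([obs, one_hot]).astype(np.float32)

        def reset(self, **kwargs):
            obs, info = self.env.reset(**kwargs)
            return self._augment(obs, None), info

        def step(self, act):
            obs, rew, terminated, truncated, info = self.env.step(act)
            return self._augment(obs, act), rew, terminated, truncated, info
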
.. _multi_gpu:

Multi-GPU Training
------------------

To enable training an RL agent with multiple GPUs for a standard environment (i.e., without a nested observation) with the default networks provided by Tianshou:

1. Import :class:`~tianshou.utils.net.common.DataParallelNet` from ``tianshou.utils.net.common``;
2. Change the ``device`` argument to ``None`` in the existing networks such as ``Net``, ``Actor``, ``Critic``, ``ActorProb``;
3. Apply the ``DataParallelNet`` wrapper to these networks.

::

    from tianshou.utils.net.common import Net, DataParallelNet
    from tianshou.utils.net.discrete import Actor, Critic

    actor = DataParallelNet(Actor(net, args.action_shape, device=None).to(args.device))
    critic = DataParallelNet(Critic(net, device=None).to(args.device))

Yes, that's all! This general approach can be applied to almost all kinds of algorithms implemented in Tianshou.

We provide a complete script to show how to run multi-GPU training: `test/discrete/test_ppo.py <https://github.com/thu-ml/tianshou/blob/master/test/discrete/test_ppo.py>`_.

As for other cases, such as customized networks or environments that have a nested observation, here are the rules:

1. The data format transformation (numpy -> cuda) is done in the ``DataParallelNet`` wrapper; your customized network should not apply any kind of data format transformation;
2. Create a similar class that inherits ``DataParallelNet``, which is only in charge of the data format transformation (numpy -> cuda);
3. Do the same things as above.
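
Below is a rough, hypothetical sketch of rule 2 for a dict observation; the class name ``NestedDataParallelNet`` and the assumption that ``DataParallelNet`` keeps the wrapped data-parallel module as ``self.net`` are ours, so adapt it to your actual observation structure:

::

    import torch

    from tianshou.utils.net.common import DataParallelNet


    class NestedDataParallelNet(DataParallelNet):
        """Hypothetical wrapper: convert each field of a dict observation to a
        cuda tensor, then dispatch to the wrapped data-parallel network."""

        def forward(self, obs, *args, **kwargs):
            # numpy -> cuda happens here only; the inner network receives tensors
            obs = {
                k: torch.as_tensor(v, dtype=torch.float32).cuda()
                for k, v in obs.items()
            }
            return self.net(obs=obs, *args, **kwargs)
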
.. _self_defined_env:

User-defined Environment and Different State Representation
------------------------------------------------------------

This is related to `Issue 38 <https://github.com/thu-ml/tianshou/issues/38>`_ and `Issue 69 <https://github.com/thu-ml/tianshou/issues/69>`_.

First of all, your self-defined environment must follow Gymnasium's API; some of the required methods and attributes are listed below:

- reset() -> state, info
- step(action) -> state, reward, terminated, truncated, info
- seed(s) -> List[int]
- render(mode) -> Any
- close() -> None
- observation_space: gym.Space
- action_space: gym.Space

The state can be a ``numpy.ndarray`` or a Python dictionary. Take "FetchReach-v1" as an example:

::

    >>> e = gym.make('FetchReach-v1')
    >>> e.reset()
    {'observation': array([ 1.34183265e+00,  7.49100387e-01,  5.34722720e-01,  1.97805133e-04,
             7.15193042e-05,  7.73933014e-06,  5.51992816e-08, -2.42927453e-06,
             4.73325650e-06, -2.28455228e-06]),
     'achieved_goal': array([1.34183265, 0.74910039, 0.53472272]),
     'desired_goal': array([1.24073906, 0.77753463, 0.63457791])}

It shows that the state is a dictionary which has 3 keys. It will be stored in :class:`~tianshou.data.ReplayBuffer` as:

::

    >>> from tianshou.data import Batch, ReplayBuffer
    >>> b = ReplayBuffer(size=3)
    >>> b.add(Batch(obs=e.reset(), act=0, rew=0, done=0))
    >>> print(b)
    ReplayBuffer(
        act: array([0, 0, 0]),
        done: array([False, False, False]),
        obs: Batch(
                 achieved_goal: array([[1.34183265, 0.74910039, 0.53472272],
                                       [0.        , 0.        , 0.        ],
                                       [0.        , 0.        , 0.        ]]),
                 desired_goal: array([[1.42154265, 0.62505137, 0.62929863],
                                      [0.        , 0.        , 0.        ],
                                      [0.        , 0.        , 0.        ]]),
                 observation: array([[ 1.34183265e+00,  7.49100387e-01,  5.34722720e-01,
                                       1.97805133e-04,  7.15193042e-05,  7.73933014e-06,
                                       5.51992816e-08, -2.42927453e-06,  4.73325650e-06,
                                      -2.28455228e-06],
                                     [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
                                       0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
                                       0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
                                       0.00000000e+00],
                                     [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
                                       0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
                                       0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
                                       0.00000000e+00]]),
             ),
        rew: array([0, 0, 0]),
    )
    >>> print(b.obs.achieved_goal)
    [[1.34183265 0.74910039 0.53472272]
     [0.         0.         0.        ]
     [0.         0.         0.        ]]

And the data batch sampled from this replay buffer:

::

    >>> batch, indices = b.sample(2)
    >>> batch.keys()
    ['act', 'done', 'info', 'obs', 'obs_next', 'policy', 'rew']
    >>> batch.obs[-1]
    Batch(
        achieved_goal: array([1.34183265, 0.74910039, 0.53472272]),
        desired_goal: array([1.42154265, 0.62505137, 0.62929863]),
        observation: array([ 1.34183265e+00,  7.49100387e-01,  5.34722720e-01,  1.97805133e-04,
                             7.15193042e-05,  7.73933014e-06,  5.51992816e-08, -2.42927453e-06,
                             4.73325650e-06, -2.28455228e-06]),
    )
    >>> batch.obs.desired_goal[-1]  # recommended
    array([1.42154265, 0.62505137, 0.62929863])
    >>> batch.obs[-1].desired_goal  # not recommended
    array([1.42154265, 0.62505137, 0.62929863])
    >>> batch[-1].obs.desired_goal  # not recommended
    array([1.42154265, 0.62505137, 0.62929863])

Thus, in your self-defined network, just change the ``forward`` function as:

::

    def forward(self, s, ...):
        # s is a batch
        observation = s.observation
        achieved_goal = s.achieved_goal
        desired_goal = s.desired_goal
        ...

For a self-defined class, the replay buffer will store the reference in a ``numpy.ndarray``, e.g.:

::

    >>> import networkx as nx
    >>> b = ReplayBuffer(size=3)
    >>> b.add(Batch(obs=nx.Graph(), act=0, rew=0, done=0))
    >>> print(b)
    ReplayBuffer(
        act: array([0, 0, 0]),
        done: array([0, 0, 0]),
        info: Batch(),
        obs: array([<networkx.classes.graph.Graph object at 0x7f5c607826a0>, None,
                    None], dtype=object),
        policy: Batch(),
        rew: array([0, 0, 0]),
    )

But the state stored in the buffer may be a shallow copy. To make sure each state stored in the buffer is distinct, please return a deep copy of the state in your env:

::

    def reset(self):
        return copy.deepcopy(self.graph), {}

    def step(self, action):
        ...
        return copy.deepcopy(self.graph), reward, terminated, truncated, {}

.. note::

    Please make sure this variable is numpy-compatible, e.g., np.array([variable]) will not result in an empty array. Otherwise, ReplayBuffer cannot create a numpy array to store it.

.. _marl_example:

Multi-Agent Reinforcement Learning
----------------------------------

This is related to `Issue 121 <https://github.com/thu-ml/tianshou/issues/121>`_. The discussion is still ongoing.

With the flexible core APIs, Tianshou can support multi-agent reinforcement learning with minimal effort.

Currently, we support three types of multi-agent reinforcement learning paradigms:

1. Simultaneous move: at each timestep, all the agents take their actions (example: MOBA games)
2. Cyclic move: players take actions in turn (example: Go game)
3. Conditional move: at each timestep, the environment conditionally selects an agent to take an action (example: `Pig Game <https://en.wikipedia.org/wiki/Pig_(dice_game)>`_)

We mainly address these multi-agent RL problems by converting them into traditional RL formulations.

For simultaneous move, the solution is simple: we can just add a ``num_agent`` dimension to state, action, and reward. Nothing else needs to change.

For 2 & 3 (cyclic move and conditional move), they can be unified into a single framework: at each timestep, the environment selects an agent with id ``agent_id`` to play. Since the multiple agents are usually wrapped into one object (which we call the "abstract agent"), we can pass the ``agent_id`` to the "abstract agent", leaving it to further call the specific agent.

In addition, legal actions in multi-agent RL often vary with the timestep (just like in Go), so the environment should also pass a legal action mask to the "abstract agent", where the mask is a boolean array in which "True" marks available actions and "False" marks illegal actions at the current step. Below is a figure that explains the abstract agent.

.. image:: /_static/images/marl.png
    :align: center
    :height: 300

The above description gives rise to the following formulation of multi-agent RL:

::

    act = policy(state, agent_id, mask)
    (next_state, next_agent_id, next_mask), reward = env.step(act)

By constructing a new state ``state_ = (state, agent_id, mask)``, we can essentially return to the typical formulation of RL:

::

    act = policy(state_)
    next_state_, reward = env.step(act)

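In Tianshou's multi-agent support (e.g. :class:`~tianshou.env.PettingZooEnv`), this combined state is represented as a dict observation; a rough sketch of what the policy receives at each step could look like the following (the field values here are purely illustrative):

::

    obs = {
        "agent_id": "player_1",                 # whose turn it is
        "obs": board_state,                     # the raw observation for that agent
        "mask": np.array([True, False, True]),  # legal actions at the current step
    }
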
Following this idea, we write a tiny example of playing `Tic Tac Toe <https://en.wikipedia.org/wiki/Tic-tac-toe>`_ against a random player by using a Q-learning algorithm. The tutorial is at :doc:`/tutorials/tictactoe`.