Michael Panchenko 07702fc007
Improved typing and reduced duplication (#912)
# Goals of the PR

The PR introduces **no changes to functionality**, apart from improved
input validation here and there. The main goals are to reduce some
complexity of the code, to improve types and IDE completions, and to
extend documentation and block comments where appropriate. Because of
the change to the trainer interfaces, many files are affected (more
details below), but still the overall changes are "small" in a certain
sense.

## Major Change 1 - BatchProtocol

**TL;DR:** One can now annotate which fields a batch is expected to
have when passed as an input parameter and which fields a returned batch
has. This should be useful for reading the code, getting meaningful IDE
support, and catching bugs with mypy. This annotation strategy will
continue to work if Batch is replaced by TensorDict or by something else.

**In more detail:** Batch itself has no fields, so using it for
annotations conveys little information. Batches with fields are
not separate classes but direct instances of Batch, so there
is no type that could be used for annotation. Fortunately, Python's
`Protocol` comes to the rescue. With these changes we can now do
things like

```python
class ActionBatchProtocol(BatchProtocol):
    logits: Sequence[Union[tuple, torch.Tensor]]
    dist: torch.distributions.Distribution
    act: torch.Tensor
    state: Optional[torch.Tensor]


class RolloutBatchProtocol(BatchProtocol):
    obs: torch.Tensor
    obs_next: torch.Tensor
    info: Dict[str, Any]
    rew: torch.Tensor
    terminated: torch.Tensor
    truncated: torch.Tensor

class PGPolicy(BasePolicy):
    ...

    def forward(
        self,
        batch: RolloutBatchProtocol,
        state: Optional[Union[dict, Batch, np.ndarray]] = None,
        **kwargs: Any,
    ) -> ActionBatchProtocol:
        ...
```

The IDE and mypy are now very helpful in finding errors and in
auto-completion, whereas before the tools couldn't assist in that at
all.
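
The idea can be tried out without Tianshou. The following is a minimal sketch using only the standard library; `FieldBag`, `ObsRewProtocol` and `scaled_obs_sum` are hypothetical names invented for illustration and are not part of this PR. Because the fields are set dynamically, a `cast` is used here to tell the type checker which protocol the object is meant to satisfy.

```python
from typing import Any, List, Protocol, cast


class FieldBag:
    """Hypothetical stand-in for Batch: fields are attached dynamically."""

    def __init__(self, **fields: Any) -> None:
        self.__dict__.update(fields)


class ObsRewProtocol(Protocol):
    """Declares which fields a function expects, without a concrete class."""

    obs: List[float]
    rew: float


def scaled_obs_sum(batch: ObsRewProtocol) -> float:
    # The IDE can now complete `batch.obs` and `batch.rew`,
    # and mypy flags access to undeclared fields such as `batch.foo`.
    return sum(batch.obs) * batch.rew


# The fields exist only at runtime, so the static type is asserted via cast.
batch = cast(ObsRewProtocol, FieldBag(obs=[1.0, 2.0], rew=0.5))
result = scaled_obs_sum(batch)
```

The protocol only describes structure, so the same annotation would keep working if the underlying container class were swapped out.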

## Major Change 2 - remove duplication in trainer package

**TL;DR:** There was a lot of duplication between `BaseTrainer` and its
subclasses. Even worse, it was almost-duplication. There was also
interface fragmentation through things like `onpolicy_trainer`. Now this
duplication is gone and all downstream code was adjusted.

**In more detail:** Since this change affects a lot of code, I would
like to explain why I thought it to be necessary.

1. The subclasses of `BaseTrainer` just duplicated docstrings and
constructors. What's worse, they changed the order of args there, even
turning some kwargs of BaseTrainer into args. They also had the arg
`learning_type` which was passed as kwarg to the base class and was
unused there. This made things difficult to maintain, and in fact some
errors were already present in the duplicated docstrings.
2. The "functions" a la `onpolicy_trainer`, which just called
`OnpolicyTrainer.run`, not only introduced interface fragmentation but
also completely obfuscated the docstrings and interfaces. They themselves
had no docstring and their interface was just `*args, **kwargs`, which
makes it impossible to understand what they do and what can be
passed without reading their implementation, then reading the docstring
of the associated class, etc. Needless to say, mypy and IDEs provide no
support for such functions. Nevertheless, they were used everywhere in
the code-base. I didn't find the sacrifices in clarity and complexity
justified just for the sake of not having to write `.run()` after
instantiating a trainer.
3. The trainers are all very similar to each other. Since I needed a new
trainer for my application, I wanted to understand their
structure. The similarity, however, was hard to discover since they were
all in separate modules and there was so much duplication. I kept
staring at the constructors for a while until I figured out that
essentially no changes to the superclass were introduced. Now they are
all in the same module and the similarities/differences between them are
much easier to grasp (in my opinion).
4. Because of (1), I had to manually change and check a lot of code,
which was very tedious and boring. This kind of work won't be necessary
in the future, since now IDEs can be used for changing signatures,
renaming args and kwargs, changing class names and so on.

I have some more reasons, but maybe the above ones are convincing
enough.
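
To illustrate point 2, here is a small sketch with hypothetical names (`ToyTrainer` and `toy_trainer` are invented for this example and are not Tianshou classes). It shows how a `*args, **kwargs` wrapper hides the signature that tools could otherwise inspect.

```python
import inspect


class ToyTrainer:
    """Hypothetical trainer with an explicit, documented constructor."""

    def __init__(self, max_epoch: int, batch_size: int = 64) -> None:
        self.max_epoch = max_epoch
        self.batch_size = batch_size

    def run(self) -> dict:
        return {"epochs": self.max_epoch, "batch_size": self.batch_size}


def toy_trainer(*args, **kwargs) -> dict:
    # Wrapper in the style of the removed functions: saves typing `.run()`
    # but exposes no parameter names, types or defaults.
    return ToyTrainer(*args, **kwargs).run()


# What an IDE or type checker gets to see for each entry point.
wrapper_sig = str(inspect.signature(toy_trainer))
class_sig = str(inspect.signature(ToyTrainer))
print("wrapper:", wrapper_sig)
print("class:  ", class_sig)
```

Both entry points return the same result, but only the class signature names `max_epoch` and `batch_size`.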

## Minor changes: improved input validation and types

I added input validation for things like `state` and `action_scaling`
(which only makes sense for continuous envs). After adding this, some
tests failed to pass this validation. There I added
`action_scaling=isinstance(env.action_space, Box)`, after which tests
were green. I don't know why the tests were green before, since action
scaling doesn't make sense for discrete actions. I guess some aspect was
not tested and didn't crash.

I also added `Literal` in some places, in particular for
`action_bound_method`. Passing an empty string is no longer allowed;
one should pass `None` instead. Here, too, there is input
validation with clear error messages.
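
A rough sketch of this kind of validation is shown below. The function `check_policy_args`, the `is_continuous` flag and the allowed values `"clip"` and `"tanh"` are assumptions made for illustration; the actual checks in the PR live in the policy constructors and may differ.

```python
from typing import Literal, Optional, get_args

# Assumed set of allowed values, for illustration only.
BoundMethod = Literal["clip", "tanh"]


def check_policy_args(
    is_continuous: bool,
    action_scaling: bool,
    action_bound_method: Optional[BoundMethod],
) -> None:
    """Raise ValueError with a clear message on inconsistent arguments."""
    if action_scaling and not is_continuous:
        raise ValueError(
            "action_scaling can only be True for continuous action spaces"
        )
    allowed = get_args(BoundMethod)
    if action_bound_method is not None and action_bound_method not in allowed:
        raise ValueError(
            f"action_bound_method must be one of {allowed} or None, "
            f"got {action_bound_method!r}"
        )


# Valid: continuous env with scaling and a known bound method.
check_policy_args(True, True, "clip")
# Valid: discrete env, no scaling, no bounding.
check_policy_args(False, False, None)
```

With the `Literal` annotation, mypy already rejects an empty string at the call site; the runtime check covers callers that do not use a type checker.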

@Trinkle23897 The functional tests are green. I didn't want to fix the
formatting, since it will change in the next PR that will solve #914
anyway. I also found a whole bunch of code in `docs/_static`, which I
just deleted (shouldn't it be copied from the sources during docs build
instead of committed?). I also haven't adjusted the documentation yet,
which atm still mentions the trainers of the type
`onpolicy_trainer(...)` instead of `OnpolicyTrainer(...).run()`.

## Breaking Changes

The adjustments to the trainer package introduce breaking changes as
duplicated interfaces are deleted. However, it should be very easy for
users to adjust to them.

---------

Co-authored-by: Michael Panchenko <m.panchenko@appliedai.de>
2023-08-22 09:54:46 -07:00
2020-06-10 12:06:56 +08:00



⚠️ Transition to Gymnasium: The maintainers of OpenAI Gym have recently released Gymnasium, which is where future maintenance of OpenAI Gym will be taking place. Tianshou has transitioned to internally using Gymnasium environments. You can still use OpenAI Gym environments with Tianshou vector environments, but they will be wrapped in a compatibility layer, which could be a source of issues. We recommend that you update your environment code to Gymnasium. If you want to continue using OpenAI Gym with Tianshou, you need to manually install Gym and Shimmy (the compatibility layer).

Tianshou (天授) is a reinforcement learning platform based on pure PyTorch. Unlike existing reinforcement learning libraries, which are mainly based on TensorFlow, have many nested classes, an unfriendly API, or slow speed, Tianshou provides a fast, modularized framework and a Pythonic API for building deep reinforcement learning agents with the least number of lines of code. The supported interface algorithms currently include:

Here are Tianshou's other features:

  • Elegant framework, using only ~4000 lines of code
  • State-of-the-art MuJoCo benchmark for REINFORCE/A2C/TRPO/PPO/DDPG/TD3/SAC algorithms
  • Support vectorized environment (synchronous or asynchronous) for all algorithms
  • Support super-fast vectorized environment EnvPool for all algorithms
  • Support recurrent state representation in actor network and critic network (RNN-style training for POMDP)
  • Support any type of environment state/action (e.g. a dict, a self-defined class, ...)
  • Support customized training process
  • Support n-step returns estimation and prioritized experience replay for all Q-learning based algorithms; GAE, nstep and PER are very fast thanks to numba jit function and vectorized numpy operation
  • Support multi-agent RL
  • Support both TensorBoard and W&B log tools
  • Support multi-GPU training
  • Comprehensive documentation, PEP8 code-style checking, type checking and thorough tests

In Chinese, Tianshou means "divinely ordained" and, by extension, an innate gift that one is born with. Tianshou is a reinforcement learning platform, and RL algorithms do not learn from humans. The name therefore signifies that there is no teacher to learn from; the agent learns by itself through constant interaction with the environment.

“天授”意指上天所授,引申为与生具有的天赋。天授是强化学习平台,而强化学习算法并不是向人类学习的,所以取“天授”意思是没有老师来教,而是自己通过跟环境不断交互来进行学习。

Installation

Tianshou is currently hosted on PyPI and conda-forge. It requires Python >= 3.8.

You can simply install Tianshou from PyPI with the following command:

$ pip install tianshou

If you use Anaconda or Miniconda, you can install Tianshou from conda-forge through the following command:

$ conda install tianshou -c conda-forge

You can also install the newest version directly from GitHub:

$ pip install git+https://github.com/thu-ml/tianshou.git@master --upgrade

After installation, open your python console and type

import tianshou
print(tianshou.__version__)

If no error occurs, you have successfully installed Tianshou.

Documentation

The tutorials and API documentation are hosted on tianshou.readthedocs.io.

The example scripts are under the test/ and examples/ folders.

The Chinese documentation is at https://tianshou.readthedocs.io/zh/master/

Why Tianshou?

Comprehensive Functionality

RL Platform GitHub Stars # of Alg. (1) Custom Env Batch Training RNN Support Nested Observation Backend
Baselines GitHub stars 9 ✔️ (gym) (2) ✔️ TF1
Stable-Baselines GitHub stars 11 ✔️ (gym) (2) ✔️ TF1
Stable-Baselines3 GitHub stars 7 (3) ✔️ (gym) (2) ✔️ PyTorch
Ray/RLlib GitHub stars 16 ✔️ ✔️ ✔️ ✔️ TF/PyTorch
SpinningUp GitHub stars 6 ✔️ (gym) (2) PyTorch
Dopamine GitHub stars 7 TF/JAX
ACME GitHub stars 14 ✔️ (dm_env) ✔️ ✔️ ✔️ TF/JAX
keras-rl GitHub stars 7 ✔️ (gym) Keras
rlpyt GitHub stars 11 ✔️ ✔️ ✔️ PyTorch
ChainerRL GitHub stars 18 ✔️ (gym) ✔️ ✔️ Chainer
Sample Factory GitHub stars 1 (4) ✔️ (gym) ✔️ ✔️ ✔️ PyTorch
Tianshou GitHub stars 20 ✔️ (Gymnasium) ✔️ ✔️ ✔️ PyTorch

(1): access date: 2021-08-08

(2): not all algorithms support this feature

(3): TQC and QR-DQN in sb3-contrib instead of main repo

(4): super fast APPO!

High quality software engineering standard

RL Platform Documentation Code Coverage Type Hints Last Update
Baselines GitHub last commit
Stable-Baselines Documentation Status coverage GitHub last commit
Stable-Baselines3 Documentation Status coverage report ✔️ GitHub last commit
Ray/RLlib (1) ✔️ GitHub last commit
SpinningUp GitHub last commit
Dopamine GitHub last commit
ACME (1) ✔️ GitHub last commit
keras-rl Documentation (1) GitHub last commit
rlpyt Docs codecov GitHub last commit
ChainerRL Documentation Status Coverage Status GitHub last commit
Sample Factory codecov GitHub last commit
Tianshou Read the Docs codecov ✔️ GitHub last commit

(1): it has continuous integration but the coverage rate is not available

Reproducible and High Quality Result

Tianshou has its own tests. Unlike other platforms, the tests include the full agent training procedure for all of the implemented algorithms. A test fails if it cannot train an agent to perform well enough within a limited number of epochs on toy scenarios. These tests secure the reproducibility of our platform. Check out the GitHub Actions page for more detail.

The Atari/MuJoCo benchmark results are under the examples/atari/ and examples/mujoco/ folders. Our MuJoCo results outperform most existing benchmarks.

Modularized Policy

We decouple all of the algorithms roughly into the following parts:

  • __init__: initialize the policy;
  • forward: to compute actions over given observations;
  • process_fn: to preprocess data from replay buffer (since we have reformulated all algorithms to replay-buffer based algorithms);
  • learn: to learn from a given batch data;
  • post_process_fn: to update the replay buffer from the learning process (e.g., prioritized replay buffer needs to update the weight);
  • update: the main interface for training, i.e., process_fn -> learn -> post_process_fn.

Within this API, we can interact with different policies conveniently.
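
The flow of update can be sketched in a few lines. `ToyPolicy` below is a hypothetical stand-in written for illustration; it does not inherit from Tianshou's policy base class, and the real method signatures take additional arguments such as the buffer and indices.

```python
from typing import Any, Dict, List


class ToyPolicy:
    """Hypothetical policy that only records the order of the update steps."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def process_fn(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        # Preprocess sampled data, e.g. compute returns.
        self.calls.append("process_fn")
        return {**batch, "returns": sum(batch["rew"])}

    def learn(self, batch: Dict[str, Any]) -> Dict[str, float]:
        # Perform the learning step and report statistics.
        self.calls.append("learn")
        return {"loss": -float(batch["returns"])}

    def post_process_fn(self, batch: Dict[str, Any]) -> None:
        # Write results back, e.g. updated priorities.
        self.calls.append("post_process_fn")

    def update(self, batch: Dict[str, Any]) -> Dict[str, float]:
        # Main training interface: process_fn -> learn -> post_process_fn.
        batch = self.process_fn(batch)
        stats = self.learn(batch)
        self.post_process_fn(batch)
        return stats


policy = ToyPolicy()
stats = policy.update({"rew": [1.0, 0.5]})
```

Keeping the three steps as separate methods lets an algorithm override only the part that differs, while callers use the single update entry point.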

Quick Start

This is an example of Deep Q Network. You can also run the full script at test/discrete/test_dqn.py.

First, import some relevant packages:

import gymnasium as gym
import torch, numpy as np, torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
import tianshou as ts

Define some hyper-parameters:

task = 'CartPole-v0'
lr, epoch, batch_size = 1e-3, 10, 64
train_num, test_num = 10, 100
gamma, n_step, target_freq = 0.9, 3, 320
buffer_size = 20000
eps_train, eps_test = 0.1, 0.05
step_per_epoch, step_per_collect = 10000, 10
logger = ts.utils.TensorboardLogger(SummaryWriter('log/dqn'))  # TensorBoard is supported!
# For other loggers: https://tianshou.readthedocs.io/en/master/tutorials/logger.html

Make environments:

# you can also try with SubprocVectorEnv
train_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(train_num)])
test_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(test_num)])

Define the network:

from tianshou.utils.net.common import Net
# you can define other net by following the API:
# https://tianshou.readthedocs.io/en/master/tutorials/dqn.html#build-the-network
env = gym.make(task)
state_shape = env.observation_space.shape or env.observation_space.n
action_shape = env.action_space.shape or env.action_space.n
net = Net(state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128, 128, 128])
optim = torch.optim.Adam(net.parameters(), lr=lr)

Setup policy and collectors:

policy = ts.policy.DQNPolicy(net, optim, gamma, n_step, target_update_freq=target_freq)
train_collector = ts.data.Collector(policy, train_envs, ts.data.VectorReplayBuffer(buffer_size, train_num), exploration_noise=True)
test_collector = ts.data.Collector(policy, test_envs, exploration_noise=True)  # because DQN uses epsilon-greedy method

Let's train it:

result = ts.trainer.OffpolicyTrainer(
    policy=policy,
    train_collector=train_collector,
    test_collector=test_collector,
    max_epoch=epoch,
    step_per_epoch=step_per_epoch,
    step_per_collect=step_per_collect,
    episode_per_test=test_num,
    batch_size=batch_size,
    update_per_step=1 / step_per_collect,
    train_fn=lambda epoch, env_step: policy.set_eps(eps_train),
    test_fn=lambda epoch, env_step: policy.set_eps(eps_test),
    stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold,
    logger=logger,
).run()
print(f'Finished training! Use {result["duration"]}')

Save / load the trained policy (it's exactly the same as PyTorch nn.Module):

torch.save(policy.state_dict(), 'dqn.pth')
policy.load_state_dict(torch.load('dqn.pth'))

Watch the performance at 35 FPS:

policy.eval()
policy.set_eps(eps_test)
collector = ts.data.Collector(policy, env, exploration_noise=True)
collector.collect(n_episode=1, render=1 / 35)

Look at the results saved in TensorBoard (run the following in your terminal):

$ tensorboard --logdir log/dqn

You can check out the documentation for advanced usage.

It's worth a try: here is a test on a laptop (i7-8750H + GTX1060). It takes only 3 seconds to train an agent based on vanilla policy gradient on the CartPole-v0 task (results for a given seed may differ across platforms and devices):

$ python3 test/discrete/test_pg.py --seed 0 --render 0.03

Contributing

Tianshou is still under development. More algorithms and features are going to be added and we always welcome contributions to help make Tianshou better. If you would like to contribute, please check out this link.

Citing Tianshou

If you find Tianshou useful, please cite it in your publications.

@article{tianshou,
  author  = {Jiayi Weng and Huayu Chen and Dong Yan and Kaichao You and Alexis Duburcq and Minghao Zhang and Yi Su and Hang Su and Jun Zhu},
  title   = {Tianshou: A Highly Modularized Deep Reinforcement Learning Library},
  journal = {Journal of Machine Learning Research},
  year    = {2022},
  volume  = {23},
  number  = {267},
  pages   = {1--6},
  url     = {http://jmlr.org/papers/v23/21-1127.html}
}

Acknowledgment

Tianshou was previously a reinforcement learning platform based on TensorFlow. You can check out the branch priv for more detail. Many thanks to Haosheng Zou's pioneering work for Tianshou before version 0.1.1.

We would like to thank TSAIL and Institute for Artificial Intelligence, Tsinghua University for providing such an excellent AI research platform.
