This PR closes #938. It introduces all the fundamental concepts and abstractions, and it already covers the majority of the algorithms. It is not a complete and finalised product, however, and we recommend that the high-level API remain in alpha stage for some time, as already suggested in the issue.

The changes in this PR are described on a [wiki page](https://github.com/aai-institute/tianshou/wiki/High-Level-API), a copy of which is provided below. (The original page is perhaps more readable, because it does not render line breaks verbatim.)

# Introducing the Tianshou High-Level API

The new high-level library was created based on object-oriented design principles with two primary design goals:

* **ease of use** for the end user (without sacrificing generality)

  This is achieved through:
  * a single, well-defined point of interaction (`ExperimentBuilder`) which uses declarative semantics, allowing the user to focus on what to do rather than how to do it.
  * easily injectable parametrisation. For complex parametrisation involving objects, the respective library classes are easily discoverable, keeping the need to browse reference documentation (or, even worse, inspect code or class hierarchies) to an absolute minimum.
  * reduced points of failure. Because the high-level API is at a higher level of abstraction, where more knowledge is available, we can centrally define reasonable defaults and apply consistency checks in order to ensure that illegal configurations result in meaningful errors (and are completely avoided as long as the user does not modify default behaviour). For example, we can consider interactions between the nature of the action space and the neural networks being used.

* **maintainability** for developers

  This is achieved through:
  * a modular design with strong separation of concerns
  * a high level of factorisation, which largely avoids duplication, partly through the use of mixins and multiple inheritance.
    This invariably makes the code slightly more complex, yet it greatly reduces the lines of code to be written/updated, so it is a reasonable compromise in this case.

## Changeset

The entire high-level library is in its own subpackage `tianshou.highlevel`, and **almost no changes were made to the original library** in order to support the new APIs. For the most part, only typing-related changes were made, which have aligned type annotations with existing example applications or have made explicit interfaces that were previously implicit.

Furthermore, some helper modules were added to the `tianshou.utils` package (all of which were copied from the [sensAI library](https://github.com/jambit/sensAI)).

Many example applications were added, based on the existing MuJoCo and Atari examples (see below).

## User-Facing Interface

### User Experience Example

To illustrate the UX, consider this video recording (IntelliJ IDEA):

[Video: IntelliJ IDEA screen recording]

Observe how conveniently relevant classes can be discovered via the IDE's auto-completion function. Discoverability is markedly enhanced by using a prefix-based naming convention, where classes that can be used as parameters use the base class name as a prefix, allowing all potentially relevant subclasses to be straightforwardly auto-completed.

### Declarative Semantics

A key design principle for the user-facing interface was to achieve *declarative semantics*, where the user is no longer concerned with generating a lengthy procedure that sequentially constructs components that build upon each other. Instead, the user focuses purely on *declaring* the properties of the learning task they would like to run.

* This essentially reduces boilerplate code to zero, as every part of the code defines essential, experiment-specific configuration.
* This makes it possible to centrally handle interdependent configuration and detect/avoid misspecification.
In order to enable the configuration of interdependent objects without requiring the user to instantiate the respective objects sequentially, we heavily employ the *factory pattern*.

### Experiment Builders

The end user's primary entry point is an `ExperimentBuilder`, which is specialised for each algorithm. As the name suggests, it uses the builder pattern in order to create an `Experiment` object, which is then used to run the learning task.

* At builder construction, the user is required to provide only essential configuration, particularly the environment factory.
* The bulk of the algorithm-specific parameters can be provided via an algorithm-specific parameter object. For instance, `PPOExperimentBuilder` has the method `with_ppo_params`, which expects an object of type `PPOParams`.
* Parametrisation that requires the provision of more complex interfaces (e.g. where multiple specification variants exist) is handled via dedicated builder methods. For example, for the specification of the critic component in an actor-critic algorithm, the following group of functions is provided:
  * `with_critic_factory` (where the user can provide any (user-defined) factory for the critic component)
  * `with_critic_factory_default` (with which the user specifies that the default, `Net`-based critic architecture shall be used and has the option to parametrise it)
  * `with_critic_factory_use_actor` (with which the user indicates that the critic component shall reuse the preprocessing network from the actor component)

#### Examples

##### Minimal Example

In the simplest of cases, where the user wants to use the default parametrisation for everything, a user could run a PPO learning task as follows,

```python
experiment = PPOExperimentBuilder(MyEnvFactory()).build()
experiment.run()
```

where `MyEnvFactory` is a factory for the agent's environment. The default behaviour will adapt depending on whether the factory creates environments with discrete or continuous action spaces.
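The builder/factory interplay described above can be illustrated with a deliberately tiny, self-contained sketch. Every name in it (`ToyEnvFactory`, `ToyExperimentBuilder`, `default_actor_factory`, ...) is hypothetical and does not mirror Tianshou's actual classes or signatures; it only shows the idea that the builder stores factories and defers construction to `build()`, where a default can adapt to the environment's action space.

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class ToyEnv:
    # Hypothetical stand-in for an environment; only the action-space kind matters.
    discrete_actions: bool


class ToyEnvFactory:
    """Hypothetical environment factory: creates environments on demand."""

    def __init__(self, discrete_actions: bool) -> None:
        self.discrete_actions = discrete_actions

    def create_env(self) -> ToyEnv:
        return ToyEnv(self.discrete_actions)


# A (hypothetical) actor factory maps an environment to an actor description.
ActorFactory = Callable[[ToyEnv], str]


def default_actor_factory(env: ToyEnv) -> str:
    # The default adapts to the action space, as described above.
    return "categorical actor" if env.discrete_actions else "gaussian actor"


@dataclass
class ToyExperiment:
    env: ToyEnv
    actor: str


class ToyExperimentBuilder:
    def __init__(self, env_factory: ToyEnvFactory) -> None:
        self._env_factory = env_factory
        self._actor_factory: Optional[ActorFactory] = None

    def with_actor_factory(self, factory: ActorFactory) -> "ToyExperimentBuilder":
        # Only the factory is stored; nothing is constructed yet.
        self._actor_factory = factory
        return self

    def build(self) -> ToyExperiment:
        # Construction happens here, in dependency order: environment first, then actor.
        env = self._env_factory.create_env()
        actor_factory = self._actor_factory or default_actor_factory
        return ToyExperiment(env=env, actor=actor_factory(env))


experiment = ToyExperimentBuilder(ToyEnvFactory(discrete_actions=True)).build()
print(experiment.actor)  # categorical actor
```

Because only factories are stored until `build()` is called, the order of the `with_*` calls does not matter in this sketch, which is what makes a declarative style possible.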
##### Fully Parametrised MuJoCo Example

Importantly, the user still has the option to configure all the details. Consider this example, which is from the high-level version of the `mujoco_ppo` example:

```python
log_name = os.path.join(task, "ppo", str(experiment_config.seed), datetime_tag())

sampling_config = SamplingConfig(
    num_epochs=epoch,
    step_per_epoch=step_per_epoch,
    batch_size=batch_size,
    num_train_envs=training_num,
    num_test_envs=test_num,
    buffer_size=buffer_size,
    step_per_collect=step_per_collect,
    repeat_per_collect=repeat_per_collect,
)

env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=True)

experiment = (
    PPOExperimentBuilder(env_factory, experiment_config, sampling_config)
    .with_ppo_params(
        PPOParams(
            discount_factor=gamma,
            gae_lambda=gae_lambda,
            action_bound_method=bound_action_method,
            reward_normalization=rew_norm,
            ent_coef=ent_coef,
            vf_coef=vf_coef,
            max_grad_norm=max_grad_norm,
            value_clip=value_clip,
            advantage_normalization=norm_adv,
            eps_clip=eps_clip,
            dual_clip=dual_clip,
            recompute_advantage=recompute_adv,
            lr=lr,
            lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config)
            if lr_decay
            else None,
            dist_fn=DistributionFunctionFactoryIndependentGaussians(),
        ),
    )
    .with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True)
    .with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
    .build()
)
experiment.run(log_name)
```

This is functionally equivalent to the procedural, low-level example.
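`LRSchedulerFactoryLinear(sampling_config)` in the example above receives the sampling configuration because a linear decay schedule must know the total number of scheduler steps in advance. The plain-Python sketch below computes such a decay multiplier. The helper name is hypothetical, and the sketch assumes (rather than documents) that the scheduler is stepped once per collection and that the learning rate decays linearly to zero.

```python
import math
from typing import Callable


def linear_lr_multiplier(
    num_epochs: int, step_per_epoch: int, step_per_collect: int
) -> Callable[[int], float]:
    """Hypothetical helper: map a scheduler step index to an LR multiplier.

    Assumes one scheduler step per collection and a linear decay from 1.0
    to 0.0 over the entire training run.
    """
    # Number of collections (and hence scheduler steps) over all epochs.
    max_update_num = math.ceil(step_per_epoch / step_per_collect) * num_epochs

    def multiplier(update_idx: int) -> float:
        # Clamp at zero so that extra steps never yield a negative learning rate.
        return max(0.0, 1.0 - update_idx / max_update_num)

    return multiplier


# Values as in the sampling configuration shown in the experiment's string
# representation further down (100 epochs, 30000 steps/epoch, 2048 steps/collect).
mult = linear_lr_multiplier(num_epochs=100, step_per_epoch=30000, step_per_collect=2048)
print(mult(0), mult(750), mult(1500))  # 1.0 0.5 0.0
```

With PyTorch, a function of this shape could be passed as `lr_lambda` to `torch.optim.lr_scheduler.LambdaLR`.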
Compare the scripts here:

* [original low-level example](https://github.com/aai-institute/tianshou/blob/feat/high-level-api/examples/mujoco/mujoco_ppo.py)
* [new high-level example](https://github.com/aai-institute/tianshou/blob/feat/high-level-api/examples/mujoco/mujoco_ppo_hl.py)

In general, example applications of the high-level API can be found in the `examples/` folder, in scripts with the `_hl.py` suffix:

* [MuJoCo examples](https://github.com/aai-institute/tianshou/tree/feat/high-level-api/examples/mujoco)
* [Atari examples](https://github.com/aai-institute/tianshou/tree/feat/high-level-api/examples/atari)

### Experiments

The `Experiment` representation contains

* the agent factory,
* the environment factory,
* further definitions pertaining to storage & logging.

An experiment may be run several times, assigning a name (and corresponding storage location) to each run.

#### Persistence and Logging

Experiments can be serialized and later be reloaded.

```python
experiment = Experiment.from_directory("log/my_experiment")
```

Because the experiment representation is composed purely of configuration and factories, which themselves are composed purely of configuration and factories, persisted objects are compact and do not contain state.

Every experiment run produces the following artifacts:

* the serialized experiment
* the serialized best policy found during training
* a log file
* (optionally) user-defined data, as the persistence handlers are modular

Running a reloaded experiment can optionally resume training of the serialized policy.

All relevant objects have meaningful string representations that can appear in logs, which is conveniently achieved through the use of `ToStringMixin` (from sensAI). Its use furthermore prevents string representations of recurring objects from being printed more than once.
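The de-duplication behaviour can be sketched in a few lines. This is a hypothetical minimal mixin written for illustration, not sensAI's `ToStringMixin`: it renders objects as `ClassName[field=value, ...]` and, within one rendering, prints any object that was already rendered as `ClassName[<<]`, matching the `SamplingConfig[<<]` notation in the output that follows.

```python
class ToStringSketchMixin:
    """Hypothetical minimal mixin, for illustration only (not sensAI's ToStringMixin)."""

    def _render(self, seen: set[int]) -> str:
        name = type(self).__name__
        if id(self) in seen:
            # Already rendered earlier in this string: emit a back-reference marker.
            return f"{name}[<<]"
        seen.add(id(self))
        parts = []
        for key, value in vars(self).items():
            if isinstance(value, ToStringSketchMixin):
                # Nested mixin objects share the same `seen` set.
                parts.append(f"{key}={value._render(seen)}")
            else:
                parts.append(f"{key}={value!r}")
        return f"{name}[{', '.join(parts)}]"

    def __str__(self) -> str:
        return self._render(set())


class DemoSamplingConfig(ToStringSketchMixin):
    def __init__(self, num_epochs: int) -> None:
        self.num_epochs = num_epochs


class DemoAgentFactory(ToStringSketchMixin):
    def __init__(self, sampling_config: DemoSamplingConfig) -> None:
        self.sampling_config = sampling_config


class DemoExperiment(ToStringSketchMixin):
    def __init__(self, sampling_config: DemoSamplingConfig, agent_factory: DemoAgentFactory) -> None:
        self.sampling_config = sampling_config
        self.agent_factory = agent_factory


cfg = DemoSamplingConfig(num_epochs=100)
print(DemoExperiment(cfg, DemoAgentFactory(cfg)))
# DemoExperiment[sampling_config=DemoSamplingConfig[num_epochs=100], agent_factory=DemoAgentFactory[sampling_config=DemoSamplingConfig[<<]]]
```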
For example, consider this string representation, which was generated for the fully parametrised PPO experiment from the example above:

```
Experiment[
    config=ExperimentConfig(
        seed=42,
        device='cuda',
        policy_restore_directory=None,
        train=True,
        watch=True,
        watch_render=0.0,
        persistence_base_dir='log',
        persistence_enabled=True),
    sampling_config=SamplingConfig[
        num_epochs=100,
        step_per_epoch=30000,
        batch_size=64,
        num_train_envs=64,
        num_test_envs=10,
        buffer_size=4096,
        step_per_collect=2048,
        repeat_per_collect=10,
        update_per_step=1.0,
        start_timesteps=0,
        start_timesteps_random=False,
        replay_buffer_ignore_obs_next=False,
        replay_buffer_save_only_last_obs=False,
        replay_buffer_stack_num=1],
    env_factory=MujocoEnvFactory[
        task=Ant-v4,
        seed=42,
        obs_norm=True],
    agent_factory=PPOAgentFactory[
        sampling_config=SamplingConfig[<<],
        optim_factory=OptimizerFactoryAdam[
            weight_decay=0,
            eps=1e-08,
            betas=(0.9, 0.999)],
        policy_wrapper_factory=None,
        trainer_callbacks=TrainerCallbacks(
            epoch_callback_train=None,
            epoch_callback_test=None,
            stop_callback=None),
        params=PPOParams[
            gae_lambda=0.95,
            max_batchsize=256,
            lr=0.0003,
            lr_scheduler_factory=LRSchedulerFactoryLinear[sampling_config=SamplingConfig[<<]],
            action_scaling=default,
            action_bound_method=clip,
            discount_factor=0.99,
            reward_normalization=True,
            deterministic_eval=False,
            dist_fn=DistributionFunctionFactoryIndependentGaussians[],
            vf_coef=0.25,
            ent_coef=0.0,
            max_grad_norm=0.5,
            eps_clip=0.2,
            dual_clip=None,
            value_clip=False,
            advantage_normalization=False,
            recompute_advantage=True],
        actor_factory=ActorFactoryTransientStorageDecorator[
            actor_factory=ActorFactoryDefault[
                continuous_actor_type=ContinuousActorType.GAUSSIAN,
                continuous_unbounded=True,
                continuous_conditioned_sigma=False,
                hidden_sizes=[64, 64],
                hidden_activation=<class 'torch.nn.modules.activation.Tanh'>,
                discrete_softmax=True]],
        critic_factory=CriticFactoryDefault[
            hidden_sizes=[64, 64],
            hidden_activation=<class 'torch.nn.modules.activation.Tanh'>],
        critic_use_action=False],
    logger_factory=LoggerFactoryDefault[
        logger_type=tensorboard,
        wandb_project=None],
    env_config=None]
```

## Library Developer Perspective

The presentation thus far has focussed on the user's perspective. From the perspective of a Tianshou developer, it is important that the high-level API be clearly structured and maintainable. Here are the most relevant representations:

* **Policy parameters** are represented as dataclasses (base class `Params`). The goal is for the parameters to be ultimately passed to the corresponding policy class (e.g. `PPOParams` contains parameters for `PPOPolicy`).
* **Parameter transformation**: In part, the parameter dataclass attributes already correspond directly to policy class parameters. However, because the high-level interface must, in many cases, abstract away from the low-level interface, we establish the notion of a `ParamTransformer`, which transforms one or more parameters into the form that is required by the policy class: The idea is that the dictionary representation of the dataclass is successively transformed via `ParamTransformer`s such that the resulting dictionary can ultimately be used as keyword arguments for the policy. To achieve maintainability, the declaration of parameter transformations is colocated with the parameters they affect. Tests ensure that naming issues are detected.
* **Composition and inheritance**: We use inheritance and mixins to reduce duplication.
* **Factories** are an essential principle of the library. Because the creation of objects may depend on objects that are not yet created, a declarative approach necessitates that we transition from the objects themselves to factories.
  * The `EnvFactory` was already mentioned above, as it is a user-facing abstraction. Its purpose is to create the (vectorized) `Environments` that will be used in the experiments.
  * An `AgentFactory` is the central component that creates the policy, the trainer as well as the necessary collectors.
    To support a new type of policy, a subclass that handles the policy creation is required. In turn, the main task when implementing a new algorithm-specific `ExperimentBuilder` is the creation of the corresponding `AgentFactory`.
  * Several types of factories serve to parametrize policies and training processes, e.g.
    * `OptimizerFactory` for the creation of torch optimizers
    * `ActorFactory` for the creation of actor models
    * `CriticFactory` for the creation of critic models
    * `IntermediateModuleFactory` for the creation of models that produce intermediate/latent representations
    * `EnvParamFactory` for the creation of parameters based on properties of the environment
    * `NoiseFactory` for the creation of `BaseNoise` instances
    * `DistributionFunctionFactory` for the creation of functions that create torch distributions from tensors
    * `LRSchedulerFactory` for learning rate schedulers
    * `PolicyWrapperFactory` for policy wrappers that extend the functionality of the regular policy (e.g. intrinsic curiosity)
    * `AutoAlphaFactory` for automatically tuned regularization coefficients (as supported by SAC or REDQ)
  * A `LoggerFactory` handles the creation of the experiment logger, but the default implementation already handles the cases that were used in the examples.
* The `ExperimentBuilder` implementations make use of mixins to add common functionality. As mentioned above, the main task in an algorithm-specific specialization is to create the `AgentFactory`.
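The parameter-transformation idea can be sketched with plain dataclasses. All names below (`ToyParams`, `ReplaceLrWithOptim`, `DropParam`, `toy_policy`, `unknown_kwargs`) are hypothetical and simplified; the real `Params` and `ParamTransformer` classes in `tianshou.highlevel` have their own interfaces. The sketch converts the dataclass to a dict, transforms it in place with transformers declared next to the parameters, passes the result as keyword arguments, and adds a signature check of the kind a test could use to catch naming mistakes.

```python
import inspect
from dataclasses import asdict, dataclass
from typing import Any, Callable


class DropParam:
    """Hypothetical transformer: removes a key the policy does not accept."""

    def __init__(self, key: str) -> None:
        self.key = key

    def transform(self, params: dict[str, Any]) -> None:
        del params[self.key]


class ReplaceLrWithOptim:
    """Hypothetical transformer: turns a high-level `lr` into an optimizer argument."""

    def transform(self, params: dict[str, Any]) -> None:
        lr = params.pop("lr")
        params["optim"] = f"Adam(lr={lr})"  # string stand-in for a real optimizer


@dataclass
class ToyParams:
    discount_factor: float = 0.99
    lr: float = 1e-3
    log_interval: int = 10  # hypothetical high-level-only setting, not a policy argument

    def _transformers(self) -> list:
        # Transformations are declared right next to the parameters they affect.
        return [ReplaceLrWithOptim(), DropParam("log_interval")]

    def create_kwargs(self) -> dict[str, Any]:
        params = asdict(self)
        for transformer in self._transformers():
            transformer.transform(params)
        return params


def toy_policy(optim: str, discount_factor: float) -> str:
    # Stand-in for a low-level policy constructor.
    return f"policy(optim={optim}, discount_factor={discount_factor})"


def unknown_kwargs(fn: Callable[..., Any], kwargs: dict[str, Any]) -> set[str]:
    # Names that `fn` does not accept; a test can assert that this set is empty.
    return set(kwargs) - set(inspect.signature(fn).parameters)


kwargs = ToyParams(discount_factor=0.9).create_kwargs()
print(toy_policy(**kwargs))  # policy(optim=Adam(lr=0.001), discount_factor=0.9)
```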
⚠️ Dropped support of Gym: Tianshou no longer supports `gym`, and we recommend that you transition to Gymnasium. If you absolutely have to use gym, you can try using Shimmy (the compatibility layer), but Tianshou provides no guarantees that things will work then.
⚠️ Current Status: the Tianshou master branch is currently under heavy development, moving towards more features, improved interfaces, more documentation, and better compatibility with other RL libraries. You can view the relevant issues in the corresponding milestone. Stay tuned! (And expect breaking changes until the release is done.)
⚠️ Installing PyTorch: Because of a problem with PyTorch packaging and poetry in current releases, the newest version of PyTorch is not included in the Tianshou dependencies. You can still install the newest PyTorch with `pip` after Tianshou was installed with `poetry`. Here is a discussion between torch and poetry devs, who are trying to resolve it.
Tianshou (天授) is a reinforcement learning platform based on pure PyTorch. Unlike several existing reinforcement learning libraries, which are mainly based on TensorFlow, have many nested classes, unfriendly APIs, or slow speed, Tianshou provides a fast, modularized framework and a Pythonic API for building deep reinforcement learning agents with the least number of lines of code. The currently supported algorithms include:
- Deep Q-Network (DQN)
- Double DQN
- Dueling DQN
- Branching DQN
- Categorical DQN (C51)
- Rainbow DQN (Rainbow)
- Quantile Regression DQN (QRDQN)
- Implicit Quantile Network (IQN)
- Fully-parameterized Quantile Function (FQF)
- Policy Gradient (PG)
- Natural Policy Gradient (NPG)
- Advantage Actor-Critic (A2C)
- Trust Region Policy Optimization (TRPO)
- Proximal Policy Optimization (PPO)
- Deep Deterministic Policy Gradient (DDPG)
- Twin Delayed DDPG (TD3)
- Soft Actor-Critic (SAC)
- Randomized Ensembled Double Q-Learning (REDQ)
- Discrete Soft Actor-Critic (SAC-Discrete)
- Vanilla Imitation Learning
- Batch-Constrained deep Q-Learning (BCQ)
- Conservative Q-Learning (CQL)
- Twin Delayed DDPG with Behavior Cloning (TD3+BC)
- Discrete Batch-Constrained deep Q-Learning (BCQ-Discrete)
- Discrete Conservative Q-Learning (CQL-Discrete)
- Discrete Critic Regularized Regression (CRR-Discrete)
- Generative Adversarial Imitation Learning (GAIL)
- Prioritized Experience Replay (PER)
- Generalized Advantage Estimator (GAE)
- Posterior Sampling Reinforcement Learning (PSRL)
- Intrinsic Curiosity Module (ICM)
- Hindsight Experience Replay (HER)
Here are Tianshou's other features:
- Elegant framework, using few lines of code in the core abstractions
- State-of-the-art MuJoCo benchmark for REINFORCE/A2C/TRPO/PPO/DDPG/TD3/SAC algorithms
- Supports vectorized environments (synchronous or asynchronous) for all algorithms
- Supports the super-fast vectorized environment EnvPool for all algorithms
- Supports recurrent state representation in actor and critic networks (RNN-style training for POMDPs)
- Supports any type of environment state/action (e.g. a dict, a self-defined class, ...)
- Supports customized training processes
- Supports n-step return estimation and prioritized experience replay for all Q-learning based algorithms; GAE, n-step and PER are very fast thanks to numba JIT functions and vectorized NumPy operations
- Supports multi-agent RL
- Supports both TensorBoard and W&B logging tools
- Supports multi-GPU training
- Comprehensive documentation, PEP8 code-style checking, type checking and thorough tests
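As an illustration of what the GAE estimator mentioned above computes, here is a plain-Python sketch of the textbook backward recursion. It is not Tianshou's numba-compiled, vectorized implementation, and it makes the simplifying assumption that every episode end is a true terminal state (no distinction between termination and truncation).

```python
def gae_advantages(
    rewards: list[float],
    values: list[float],
    next_values: list[float],
    dones: list[bool],
    gamma: float = 0.99,
    gae_lambda: float = 0.95,
) -> list[float]:
    """Generalized Advantage Estimation via the standard backward recursion.

    values[t] is V(s_t) and next_values[t] is V(s_{t+1}). Simplifying assumption:
    every done flag is treated as a terminal state, so it stops both
    bootstrapping and the accumulation of future advantages.
    """
    advantages = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        not_done = 0.0 if dones[t] else 1.0
        # One-step TD error: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
        delta = rewards[t] + gamma * next_values[t] * not_done - values[t]
        # A_t = delta_t + gamma * lambda * A_{t+1}
        running = delta + gamma * gae_lambda * not_done * running
        advantages[t] = running
    return advantages


# With gamma = lambda = 1 and zero value estimates, advantages equal the reward-to-go.
print(gae_advantages([1, 1, 1], [0, 0, 0], [0, 0, 0], [False, False, True], gamma=1.0, gae_lambda=1.0))
# [3.0, 2.0, 1.0]
```

Setting `gae_lambda=0` reduces the estimate to the one-step TD errors, while `gae_lambda=1` yields Monte-Carlo returns minus the value baseline.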
In Chinese, Tianshou means "divinely ordained", and by extension refers to a gift one is born with. Tianshou is a reinforcement learning platform, and RL algorithms do not learn from humans. So the name "Tianshou" means that there is no teacher to learn from; instead, the agent learns by itself through constant interaction with the environment.
Installation
Tianshou is currently hosted on PyPI and conda-forge. It requires Python >= 3.11.
You can simply install Tianshou from PyPI with the following command:

```bash
$ pip install tianshou
```

If you use Anaconda or Miniconda, you can install Tianshou from conda-forge through the following command:

```bash
$ conda install tianshou -c conda-forge
```

You can also install the newest version through GitHub:

```bash
$ pip install git+https://github.com/thu-ml/tianshou.git@master --upgrade
```

After installation, open your Python console and type

```python
import tianshou
print(tianshou.__version__)
```

If no error occurs, you have successfully installed Tianshou.
Documentation
The tutorials and API documentation are hosted on tianshou.readthedocs.io.
The example scripts are under the test/ and examples/ folders.
The Chinese documentation is available at https://tianshou.readthedocs.io/zh/master/.
Why Tianshou?
Comprehensive Functionality
RL Platform | # of Alg. (1) | Custom Env | Batch Training | RNN Support | Nested Observation | Backend
---|---|---|---|---|---|---
Baselines | 9 | ✔️ (gym) | ➖ (2) | ✔️ | ❌ | TF1
Stable-Baselines | 11 | ✔️ (gym) | ➖ (2) | ✔️ | ❌ | TF1
Stable-Baselines3 | 7 (3) | ✔️ (gym) | ➖ (2) | ❌ | ✔️ | PyTorch
Ray/RLlib | 16 | ✔️ | ✔️ | ✔️ | ✔️ | TF/PyTorch
SpinningUp | 6 | ✔️ (gym) | ➖ (2) | ❌ | ❌ | PyTorch
Dopamine | 7 | ❌ | ❌ | ❌ | ❌ | TF/JAX
ACME | 14 | ✔️ (dm_env) | ✔️ | ✔️ | ✔️ | TF/JAX
keras-rl | 7 | ✔️ (gym) | ❌ | ❌ | ❌ | Keras
rlpyt | 11 | ❌ | ✔️ | ✔️ | ✔️ | PyTorch
ChainerRL | 18 | ✔️ (gym) | ✔️ | ✔️ | ❌ | Chainer
Sample Factory | 1 (4) | ✔️ (gym) | ✔️ | ✔️ | ✔️ | PyTorch
Tianshou | 20 | ✔️ (Gymnasium) | ✔️ | ✔️ | ✔️ | PyTorch
(1): access date: 2021-08-08
(2): not all algorithms support this feature
(3): TQC and QR-DQN in sb3-contrib instead of main repo
(4): super fast APPO!
High quality software engineering standard
RL Platform | Documentation | Code Coverage | Type Hints | Last Update |
---|---|---|---|---|
Baselines | ❌ | ❌ | ❌ | |
Stable-Baselines | ❌ | |||
Stable-Baselines3 | ✔️ | |||
Ray/RLlib | ➖(1) | ✔️ | ||
SpinningUp | ❌ | ❌ | ||
Dopamine | ❌ | ❌ | ||
ACME | ➖(1) | ✔️ | ||
keras-rl | ➖(1) | ❌ | ||
rlpyt | ❌ | |||
ChainerRL | ❌ | |||
Sample Factory | ➖ | ❌ | ||
Tianshou | ✔️ |
(1): it has continuous integration but the coverage rate is not available
Reproducible and High Quality Result
Tianshou has a thorough test suite. Unlike other platforms, the tests include the full agent training procedure for all of the implemented algorithms: a test fails if it cannot train an agent to perform well enough within a limited number of epochs on toy scenarios. The tests thereby safeguard the reproducibility of our platform. Check out the GitHub Actions page for more detail.
The Atari/MuJoCo benchmark results are under the examples/atari/ and examples/mujoco/ folders. Our MuJoCo results beat most existing benchmarks.
Modularized Policy
We decouple all algorithms roughly into the following parts:

- `__init__`: initialize the policy;
- `forward`: compute actions over given observations;
- `process_buffer`: process the initial buffer, useful for some offline learning algorithms;
- `process_fn`: preprocess data from the replay buffer (since we have reformulated all algorithms as replay-buffer-based algorithms);
- `learn`: learn from a given batch of data;
- `post_process_fn`: update the replay buffer from the learning process (e.g., a prioritized replay buffer needs to update the weights);
- `update`: the main interface for training, i.e., `process_fn -> learn -> post_process_fn`.
Within this API, we can interact with different policies conveniently.
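To make the call order of `update` concrete, here is a hypothetical toy class; it is not Tianshou's `BasePolicy`, whose hooks have richer signatures and operate on `Batch` objects and replay buffers. As a simplification, `update` here receives an already-sampled list of numbers instead of sampling from a buffer, and each hook records its name.

```python
class ToyPolicy:
    """Hypothetical stand-in that only demonstrates the update() pipeline."""

    def __init__(self) -> None:
        self.calls: list[str] = []

    def process_fn(self, batch: list[float]) -> list[float]:
        # e.g. compute returns/advantages from raw transitions
        self.calls.append("process_fn")
        return [x * 2 for x in batch]

    def learn(self, batch: list[float]) -> dict[str, float]:
        # e.g. one or more gradient steps; returns training statistics
        self.calls.append("learn")
        return {"loss": sum(batch) / len(batch)}

    def post_process_fn(self, batch: list[float]) -> None:
        # e.g. update priorities in a prioritized replay buffer
        self.calls.append("post_process_fn")

    def update(self, sample: list[float]) -> dict[str, float]:
        # Main training entry point: process_fn -> learn -> post_process_fn
        batch = self.process_fn(sample)
        stats = self.learn(batch)
        self.post_process_fn(batch)
        return stats


demo = ToyPolicy()
print(demo.update([1.0, 2.0, 3.0]))  # {'loss': 4.0}
print(demo.calls)  # ['process_fn', 'learn', 'post_process_fn']
```

Because a trainer only ever calls `update`, any policy that fills in these hooks can be swapped in without changing the training loop.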
Quick Start
This is an example of Deep Q Network. You can also run the full script at test/discrete/test_dqn.py.
First, import some relevant packages:
```python
import gymnasium as gym
import torch, numpy as np, torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
import tianshou as ts
```
Define some hyper-parameters:
```python
task = 'CartPole-v0'
lr, epoch, batch_size = 1e-3, 10, 64
train_num, test_num = 10, 100
gamma, n_step, target_freq = 0.9, 3, 320
buffer_size = 20000
eps_train, eps_test = 0.1, 0.05
step_per_epoch, step_per_collect = 10000, 10
logger = ts.utils.TensorboardLogger(SummaryWriter('log/dqn'))  # TensorBoard is supported!
# For other loggers: https://tianshou.readthedocs.io/en/master/tutorials/logger.html
```
Make environments:
```python
# you can also try with SubprocVectorEnv
train_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(train_num)])
test_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(test_num)])
```
Define the network:
```python
from tianshou.utils.net.common import Net
# you can define other net by following the API:
# https://tianshou.readthedocs.io/en/master/tutorials/dqn.html#build-the-network
env = gym.make(task)
state_shape = env.observation_space.shape or env.observation_space.n
action_shape = env.action_space.shape or env.action_space.n
net = Net(state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128, 128, 128])
optim = torch.optim.Adam(net.parameters(), lr=lr)
```
Setup policy and collectors:
```python
policy = ts.policy.DQNPolicy(
    model=net,
    optim=optim,
    discount_factor=gamma,
    action_space=env.action_space,
    estimation_step=n_step,
    target_update_freq=target_freq,
)
train_collector = ts.data.Collector(policy, train_envs, ts.data.VectorReplayBuffer(buffer_size, train_num), exploration_noise=True)
test_collector = ts.data.Collector(policy, test_envs, exploration_noise=True)  # because DQN uses epsilon-greedy method
```
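Both collectors pass `exploration_noise=True` because DQN explores with an epsilon-greedy rule, whose epsilon is switched between `eps_train` and `eps_test` through `policy.set_eps` in the training call. The rule itself is simple enough to sketch in plain Python; this is an illustration only, independent of Tianshou's internal implementation, and `epsilon_greedy` is a hypothetical helper name.

```python
import random


def epsilon_greedy(q_values: list[float], eps: float, rng: random.Random) -> int:
    """With probability eps pick a uniformly random action, else the argmax of q_values."""
    if rng.random() < eps:
        return rng.randrange(len(q_values))
    return max(range(len(q_values)), key=lambda a: q_values[a])


rng = random.Random(0)
q = [0.1, 0.7, 0.2]
print([epsilon_greedy(q, eps=0.0, rng=rng) for _ in range(5)])  # [1, 1, 1, 1, 1]

# With eps = 0.05 (eps_test above), the vast majority of actions are greedy;
# random picks can also land on the greedy action by chance.
greedy_share = sum(epsilon_greedy(q, eps=0.05, rng=rng) == 1 for _ in range(10_000)) / 10_000
print(greedy_share)
```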
Let's train it:
```python
result = ts.trainer.OffpolicyTrainer(
    policy=policy,
    train_collector=train_collector,
    test_collector=test_collector,
    max_epoch=epoch,
    step_per_epoch=step_per_epoch,
    step_per_collect=step_per_collect,
    episode_per_test=test_num,
    batch_size=batch_size,
    update_per_step=1 / step_per_collect,
    train_fn=lambda epoch, env_step: policy.set_eps(eps_train),
    test_fn=lambda epoch, env_step: policy.set_eps(eps_test),
    stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold,
    logger=logger,
).run()
print(f'Finished training! Use {result["duration"]}')
```
Save / load the trained policy (it's exactly the same as a PyTorch `nn.Module`):

```python
torch.save(policy.state_dict(), 'dqn.pth')
policy.load_state_dict(torch.load('dqn.pth'))
```
Watch the performance with 35 FPS:
```python
policy.eval()
policy.set_eps(eps_test)
collector = ts.data.Collector(policy, env, exploration_noise=True)
collector.collect(n_episode=1, render=1 / 35)
```
Look at the results saved in TensorBoard (run this in your terminal):

```bash
$ tensorboard --logdir log/dqn
```
You can check out the documentation for advanced usage.
It's worth a try: here is a test on a laptop (i7-8750H + GTX1060). It takes only 3 seconds to train an agent based on vanilla policy gradient on the CartPole-v0 task (results for the same seed may differ across platforms and devices):

```bash
$ python3 test/discrete/test_pg.py --seed 0 --render 0.03
```

Contributing
Tianshou is still under development. More algorithms and features are going to be added and we always welcome contributions to help make Tianshou better. If you would like to contribute, please check out this link.
Citing Tianshou
If you find Tianshou useful, please cite it in your publications.
```bibtex
@article{tianshou,
  author  = {Jiayi Weng and Huayu Chen and Dong Yan and Kaichao You and Alexis Duburcq and Minghao Zhang and Yi Su and Hang Su and Jun Zhu},
  title   = {Tianshou: A Highly Modularized Deep Reinforcement Learning Library},
  journal = {Journal of Machine Learning Research},
  year    = {2022},
  volume  = {23},
  number  = {267},
  pages   = {1--6},
  url     = {http://jmlr.org/papers/v23/21-1127.html}
}
```
Acknowledgment
Tianshou is supported by appliedAI Institute for Europe, who is committed to providing long-term support and development.
Tianshou was previously a reinforcement learning platform based on TensorFlow. You can check out the branch `priv` for more detail. Many thanks to Haosheng Zou's pioneering work for Tianshou before version 0.1.1.
We would like to thank TSAIL and Institute for Artificial Intelligence, Tsinghua University for providing such an excellent AI research platform.