Improve README, minor changes in procedural example

Dominik Jain 2024-03-02 13:17:15 +01:00 committed by Michael Panchenko
parent 1aee41fa9c
commit b6b2c95ac7
2 changed files with 34 additions and 30 deletions

README.md

@@ -6,10 +6,10 @@
[![PyPI](https://img.shields.io/pypi/v/tianshou)](https://pypi.org/project/tianshou/) [![Conda](https://img.shields.io/conda/vn/conda-forge/tianshou)](https://github.com/conda-forge/tianshou-feedstock) [![Read the Docs](https://img.shields.io/readthedocs/tianshou)](https://tianshou.readthedocs.io/en/master) [![Read the Docs](https://img.shields.io/readthedocs/tianshou-docs-zh-cn?label=%E4%B8%AD%E6%96%87%E6%96%87%E6%A1%A3)](https://tianshou.readthedocs.io/zh/master/) [![Unittest](https://github.com/thu-ml/tianshou/actions/workflows/pytest.yml/badge.svg)](https://github.com/thu-ml/tianshou/actions) [![codecov](https://img.shields.io/codecov/c/gh/thu-ml/tianshou)](https://codecov.io/gh/thu-ml/tianshou) [![GitHub issues](https://img.shields.io/github/issues/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/issues) [![GitHub stars](https://img.shields.io/github/stars/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/stargazers) [![GitHub forks](https://img.shields.io/github/forks/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/network) [![GitHub license](https://img.shields.io/github/license/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/blob/master/LICENSE)

> ⚠️️ **Dropped support for Gym**:
> Tianshou no longer supports Gym, and we recommend that you transition to
> [Gymnasium](http://github.com/Farama-Foundation/Gymnasium).
> If you absolutely have to use Gym, you can try using [Shimmy](https://github.com/Farama-Foundation/Shimmy)
> (the compatibility layer), but Tianshou provides no guarantees that things will work then.

> ⚠️️ **Current Status**: the Tianshou master branch is currently under heavy development,
@@ -179,7 +179,7 @@ Find example scripts in the [test/](https://github.com/thu-ml/tianshou/blob/mast
<sup>(4): super fast APPO!</sup>

### High Software Engineering Standards

| RL Platform | Documentation | Code Coverage | Type Hints | Last Update |
| ----------- | ------------- | ------------- | ---------- | ----------- |
@@ -233,8 +233,6 @@ We shall apply the deep Q network (DQN) learning algorithm using both APIs.
### High-Level API

To get started, we need some imports.
@@ -333,11 +331,15 @@ Here's a run (with the training time cut short):
<img src="docs/_static/images/discrete_dqn_hl.gif">
</p>
Find many further applications of the high-level API in the `examples/` folder;
look for scripts ending with `_hl.py`.
Note that most of these examples require the extra package `argparse`
(install it by adding `--extras argparse` when invoking poetry).
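For illustration, installing that extra from a source checkout with Poetry would look roughly like this (assuming you are in the repository root):

```bash
poetry install --extras argparse
```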
### Procedural API

Let us now consider an analogous example in the procedural API.
Find the full script in [examples/discrete/discrete_dqn.py](https://github.com/thu-ml/tianshou/blob/master/examples/discrete/discrete_dqn.py).

First, import some relevant packages:
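These are the same imports that appear at the top of the example file shown further below:

```python
import gymnasium as gym
import torch
from torch.utils.tensorboard import SummaryWriter

import tianshou as ts
```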
@@ -358,24 +360,30 @@ gamma, n_step, target_freq = 0.9, 3, 320
buffer_size = 20000
eps_train, eps_test = 0.1, 0.05
step_per_epoch, step_per_collect = 10000, 10
```

Initialize the logger:
```python
logger = ts.utils.TensorboardLogger(SummaryWriter('log/dqn'))
# For other loggers, see https://tianshou.readthedocs.io/en/master/01_tutorials/05_logger.html
```
Make environments:

```python
# You can also try SubprocVectorEnv, which will use parallelization
train_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(train_num)])
test_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(test_num)])
```
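If you want the parallel variant mentioned in the comment above, only the vector-env class changes; a minimal sketch with otherwise identical arguments:

```python
# Each environment runs in its own subprocess; the constructor takes the same list of factories.
train_envs = ts.env.SubprocVectorEnv([lambda: gym.make(task) for _ in range(train_num)])
test_envs = ts.env.SubprocVectorEnv([lambda: gym.make(task) for _ in range(test_num)])
```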
Create the network as well as its optimizer:

```python
from tianshou.utils.net.common import Net

# Note: You can easily define other networks.
# See https://tianshou.readthedocs.io/en/master/01_tutorials/00_dqn.html#build-the-network
env = gym.make(task, render_mode="human")
state_shape = env.observation_space.shape or env.observation_space.n
action_shape = env.action_space.shape or env.action_space.n
@@ -383,7 +391,7 @@ net = Net(state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128,
optim = torch.optim.Adam(net.parameters(), lr=lr)
```
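As the comment above notes, any `torch.nn.Module` whose `forward` returns a `(logits, state)` pair can be plugged in here. A minimal sketch in the spirit of the linked tutorial (the class name `MyNet` is just an illustration):

```python
import numpy as np
import torch
from torch import nn


class MyNet(nn.Module):
    """A plain MLP Q-network following Tianshou's (logits, state) forward convention."""

    def __init__(self, state_shape, action_shape):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(int(np.prod(state_shape)), 128), nn.ReLU(inplace=True),
            nn.Linear(128, 128), nn.ReLU(inplace=True),
            nn.Linear(128, int(np.prod(action_shape))),
        )

    def forward(self, obs, state=None, info={}):
        # Observations arrive as numpy arrays; convert to a tensor before the forward pass.
        if not isinstance(obs, torch.Tensor):
            obs = torch.tensor(obs, dtype=torch.float)
        batch = obs.shape[0]
        logits = self.model(obs.view(batch, -1))
        return logits, state
```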
Set up the policy and collectors:

```python
policy = ts.policy.DQNPolicy(
@@ -419,14 +427,14 @@ result = ts.trainer.OffpolicyTrainer(
print(f"Finished training in {result.timing.total_time} seconds")
```
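The trainer's constructor arguments are elided in this excerpt. As a rough illustration of how such a call is typically assembled (argument names assumed from Tianshou's off-policy trainer API; `train_collector`, `test_collector`, `epoch` and `batch_size` are assumed to be defined in the elided parts of the script, so check the linked file for the exact call):

```python
result = ts.trainer.OffpolicyTrainer(
    policy=policy,
    train_collector=train_collector,
    test_collector=test_collector,
    max_epoch=epoch,
    step_per_epoch=step_per_epoch,
    step_per_collect=step_per_collect,
    episode_per_test=test_num,
    batch_size=batch_size,
    train_fn=lambda epoch, env_step: policy.set_eps(eps_train),  # exploration while training
    test_fn=lambda epoch, env_step: policy.set_eps(eps_test),    # near-greedy while testing
    stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold,
    logger=logger,
).run()
```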
Save/load the trained policy (it's exactly the same as loading a `torch.nn.Module`):
```python
torch.save(policy.state_dict(), 'dqn.pth')
policy.load_state_dict(torch.load('dqn.pth'))
```
Watch the agent's performance at 35 FPS:
```python
policy.eval()
@@ -435,13 +443,13 @@ collector = ts.data.Collector(policy, env, exploration_noise=True)
collector.collect(n_episode=1, render=1 / 35)
```
Inspect the data saved in TensorBoard:

```bash
$ tensorboard --logdir log/dqn
```
Please read the [documentation](https://tianshou.readthedocs.io) for advanced usage.

## Contributing

examples/discrete/discrete_dqn.py

@@ -1,11 +1,8 @@
import gymnasium as gym
import torch
from torch.utils.tensorboard import SummaryWriter

import tianshou as ts


def main() -> None:
@@ -16,22 +13,21 @@ def main() -> None:
    buffer_size = 20000
    eps_train, eps_test = 0.1, 0.05
    step_per_epoch, step_per_collect = 10000, 10

    logger = ts.utils.TensorboardLogger(SummaryWriter("log/dqn"))  # TensorBoard is supported!
    # For other loggers, see https://tianshou.readthedocs.io/en/master/tutorials/logger.html

    # You can also try SubprocVectorEnv, which will use parallelization
    train_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(train_num)])
    test_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(test_num)])
    from tianshou.utils.net.common import Net

    # Note: You can easily define other networks.
    # See https://tianshou.readthedocs.io/en/master/01_tutorials/00_dqn.html#build-the-network
    env = gym.make(task, render_mode="human")
    state_shape = env.observation_space.shape or env.observation_space.n
    action_shape = env.action_space.shape or env.action_space.n
    net = Net(state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128, 128, 128])
    optim = torch.optim.Adam(net.parameters(), lr=lr)
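The remainder of `main()` (policy, collectors, trainer) is not part of this diff; to run the script standalone it presumably ends with the usual entry-point guard:

```python
if __name__ == "__main__":
    main()
```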