Improve README, minor changes in procedural example (#1068)
Commit fdb69f1273

README.md (44 changed lines)

@@ -6,10 +6,10 @@
 [](https://pypi.org/project/tianshou/) [](https://github.com/conda-forge/tianshou-feedstock) [](https://tianshou.readthedocs.io/en/master) [](https://tianshou.readthedocs.io/zh/master/) [](https://github.com/thu-ml/tianshou/actions) [](https://codecov.io/gh/thu-ml/tianshou) [](https://github.com/thu-ml/tianshou/issues) [](https://github.com/thu-ml/tianshou/stargazers) [](https://github.com/thu-ml/tianshou/network) [](https://github.com/thu-ml/tianshou/blob/master/LICENSE)

-> ⚠️️ **Dropped support of Gym**:
-> Tianshou no longer supports `gym`, and we recommend that you transition to
+> ⚠️️ **Dropped support for Gym**:
+> Tianshou no longer supports Gym, and we recommend that you transition to
 > [Gymnasium](http://github.com/Farama-Foundation/Gymnasium).
-> If you absolutely have to use gym, you can try using [Shimmy](https://github.com/Farama-Foundation/Shimmy)
+> If you absolutely have to use Gym, you can try using [Shimmy](https://github.com/Farama-Foundation/Shimmy)
 > (the compatibility layer), but Tianshou provides no guarantees that things will work then.

 > ⚠️️ **Current Status**: the Tianshou master branch is currently under heavy development,
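The Shimmy route mentioned in that note can be made concrete. A minimal sketch, assuming Shimmy's documented `GymV26Environment-v0` compatibility entry point (the wrapped env id below is purely illustrative):

```python
import gymnasium as gym

# Requires `shimmy` to be installed; Shimmy registers compatibility entry
# points with Gymnasium so a legacy Gym environment can be used through
# the Gymnasium API:
env = gym.make("GymV26Environment-v0", env_id="CartPole-v1")
obs, info = env.reset()
```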
@@ -179,7 +179,7 @@ Find example scripts in the [test/](https://github.com/thu-ml/tianshou/blob/mast

 <sup>(4): super fast APPO!</sup>

-### High quality software engineering standard
+### High Software Engineering Standards

 | RL Platform | Documentation | Code Coverage | Type Hints | Last Update |
 | ----------- | ------------- | ------------- | ---------- | ----------- |
@@ -233,8 +233,6 @@ We shall apply the deep Q network (DQN) learning algorithm using both APIs.

 ### High-Level API

-The high-level API requires the extra package `argparse` (by adding
-`--extras argparse`) to be installed.
 To get started, we need some imports.

 ```python
@@ -333,11 +331,15 @@ Here's a run (with the training time cut short):
   <img src="docs/_static/images/discrete_dqn_hl.gif">
 </p>

+Find many further applications of the high-level API in the `examples/` folder;
+look for scripts ending with `_hl.py`.
+Note that most of these examples require the extra package `argparse`
+(install it by adding `--extras argparse` when invoking poetry).
+
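To make that installation step concrete, a sketch of the workflow (the example script path is illustrative, not a guaranteed file name):

```bash
# From a clone of the repository, install with the argparse extra:
poetry install --extras argparse
# Then run one of the high-level example scripts (illustrative path):
python examples/discrete/discrete_dqn_hl.py
```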
 ### Procedural API

 Let us now consider an analogous example in the procedural API.
-Find the full script from which the snippets below were derived at [test/discrete/test_dqn.py](https://github.com/thu-ml/tianshou/blob/master/test/discrete/test_dqn.py).
+Find the full script in [examples/discrete/discrete_dqn.py](https://github.com/thu-ml/tianshou/blob/master/examples/discrete/discrete_dqn.py).

 First, import some relevant packages:
@@ -358,24 +360,30 @@ gamma, n_step, target_freq = 0.9, 3, 320
 buffer_size = 20000
 eps_train, eps_test = 0.1, 0.05
 step_per_epoch, step_per_collect = 10000, 10
-logger = ts.utils.TensorboardLogger(SummaryWriter('log/dqn'))  # TensorBoard is supported!
-# For other loggers: https://tianshou.readthedocs.io/en/master/01_tutorials/05_logger.html
+```
+
+Initialize the logger:
+
+```python
+logger = ts.utils.TensorboardLogger(SummaryWriter('log/dqn'))
+# For other loggers, see https://tianshou.readthedocs.io/en/master/01_tutorials/05_logger.html
 ```
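The linked tutorial covers alternative loggers. As one hedged sketch, assuming the `WandbLogger` described there (which, per that tutorial, is attached to a TensorBoard writer via `load()`):

```python
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import WandbLogger

# Assumption: WandbLogger wraps a SummaryWriter, as in the logger tutorial.
logger = WandbLogger(project="tianshou-dqn")  # project name is illustrative
logger.load(SummaryWriter('log/dqn'))
```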

 Make environments:

 ```python
-# you can also try with SubprocVectorEnv
+# You can also try SubprocVectorEnv, which will use parallelization
 train_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(train_num)])
 test_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(test_num)])
 ```
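Since the new comment advertises `SubprocVectorEnv`, the corresponding swap is a one-liner (a sketch; `task`, `train_num`, and `test_num` are the variables defined earlier in the example):

```python
# Same constructor signature as DummyVectorEnv, but each environment
# runs in its own worker subprocess:
train_envs = ts.env.SubprocVectorEnv([lambda: gym.make(task) for _ in range(train_num)])
test_envs = ts.env.SubprocVectorEnv([lambda: gym.make(task) for _ in range(test_num)])
```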

-Define the network:
+Create the network as well as its optimizer:

 ```python
 from tianshou.utils.net.common import Net
-# you can define other net by following the API:
-# https://tianshou.readthedocs.io/en/master/01_tutorials/00_dqn.html#build-the-network
+
+# Note: You can easily define other networks.
+# See https://tianshou.readthedocs.io/en/master/01_tutorials/00_dqn.html#build-the-network
 env = gym.make(task, render_mode="human")
 state_shape = env.observation_space.shape or env.observation_space.n
 action_shape = env.action_space.shape or env.action_space.n
@@ -383,7 +391,7 @@ net = Net(state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128,
 optim = torch.optim.Adam(net.parameters(), lr=lr)
 ```

-Setup policy and collectors:
+Set up the policy and collectors:

 ```python
 policy = ts.policy.DQNPolicy(
@@ -419,14 +427,14 @@ result = ts.trainer.OffpolicyTrainer(
 print(f"Finished training in {result.timing.total_time} seconds")
 ```

-Save / load the trained policy (it's exactly the same as PyTorch `nn.module`):
+Save/load the trained policy (it's exactly the same as loading a `torch.nn.module`):

 ```python
 torch.save(policy.state_dict(), 'dqn.pth')
 policy.load_state_dict(torch.load('dqn.pth'))
 ```
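Note that this snippet persists only the policy weights. To resume training, rather than just evaluate, the optimizer state would also be needed; a plain-PyTorch sketch (the checkpoint file name is illustrative):

```python
# Save policy and optimizer state together in one resumable checkpoint:
torch.save({"model": policy.state_dict(), "optim": optim.state_dict()}, 'dqn_checkpoint.pth')

ckpt = torch.load('dqn_checkpoint.pth')
policy.load_state_dict(ckpt["model"])
optim.load_state_dict(ckpt["optim"])
```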

-Watch the performance with 35 FPS:
+Watch the agent with 35 FPS:

 ```python
 policy.eval()
@@ -435,13 +443,13 @@ collector = ts.data.Collector(policy, env, exploration_noise=True)
 collector.collect(n_episode=1, render=1 / 35)
 ```

-Look at the result saved in tensorboard: (with bash script in your terminal)
+Inspect the data saved in TensorBoard:

 ```bash
 $ tensorboard --logdir log/dqn
 ```

-You can check out the [documentation](https://tianshou.readthedocs.io) for advanced usage.
+Please read the [documentation](https://tianshou.readthedocs.io) for advanced usage.

 ## Contributing
Procedural example script:

@@ -1,11 +1,8 @@
-from typing import cast
-
 import gymnasium as gym
 import torch
 from torch.utils.tensorboard import SummaryWriter

 import tianshou as ts
-from tianshou.utils.space_info import SpaceInfo


 def main() -> None:
@@ -16,22 +13,21 @@ def main() -> None:
     buffer_size = 20000
     eps_train, eps_test = 0.1, 0.05
     step_per_epoch, step_per_collect = 10000, 10
-    logger = ts.utils.TensorboardLogger(SummaryWriter("log/dqn"))  # TensorBoard is supported!
-    # For other loggers: https://tianshou.readthedocs.io/en/master/tutorials/logger.html

-    # you can also try with SubprocVectorEnv
+    logger = ts.utils.TensorboardLogger(SummaryWriter("log/dqn"))  # TensorBoard is supported!
+    # For other loggers, see https://tianshou.readthedocs.io/en/master/tutorials/logger.html
+
+    # You can also try SubprocVectorEnv, which will use parallelization
     train_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(train_num)])
     test_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(test_num)])

     from tianshou.utils.net.common import Net

-    # you can define other net by following the API:
-    # https://tianshou.readthedocs.io/en/master/tutorials/dqn.html#build-the-network
+    # Note: You can easily define other networks.
+    # See https://tianshou.readthedocs.io/en/master/01_tutorials/00_dqn.html#build-the-network
     env = gym.make(task, render_mode="human")
-    env.action_space = cast(gym.spaces.Discrete, env.action_space)
-    space_info = SpaceInfo.from_env(env)
-    state_shape = space_info.observation_info.obs_shape
-    action_shape = space_info.action_info.action_shape
+    state_shape = env.observation_space.shape or env.observation_space.n
+    action_shape = env.action_space.shape or env.action_space.n
     net = Net(state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128, 128, 128])
     optim = torch.optim.Adam(net.parameters(), lr=lr)