Update procedural example in README
This commit is contained in:
parent
62d58faa02
commit
2c72171fca
@ -305,7 +305,7 @@ First, import some relevant packages:
|
|||||||
|
|
||||||
```python
|
```python
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
import torch, numpy as np, torch.nn as nn
|
import torch
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
import tianshou as ts
|
import tianshou as ts
|
||||||
```
|
```
|
||||||
@ -313,7 +313,7 @@ import tianshou as ts
|
|||||||
Define some hyper-parameters:
|
Define some hyper-parameters:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
task = 'CartPole-v0'
|
task = 'CartPole-v1'
|
||||||
lr, epoch, batch_size = 1e-3, 10, 64
|
lr, epoch, batch_size = 1e-3, 10, 64
|
||||||
train_num, test_num = 10, 100
|
train_num, test_num = 10, 100
|
||||||
gamma, n_step, target_freq = 0.9, 3, 320
|
gamma, n_step, target_freq = 0.9, 3, 320
|
||||||
@ -338,7 +338,7 @@ Define the network:
|
|||||||
from tianshou.utils.net.common import Net
|
from tianshou.utils.net.common import Net
|
||||||
# you can define other net by following the API:
|
# you can define other net by following the API:
|
||||||
# https://tianshou.readthedocs.io/en/master/tutorials/dqn.html#build-the-network
|
# https://tianshou.readthedocs.io/en/master/tutorials/dqn.html#build-the-network
|
||||||
env = gym.make(task)
|
env = gym.make(task, render_mode="human")
|
||||||
state_shape = env.observation_space.shape or env.observation_space.n
|
state_shape = env.observation_space.shape or env.observation_space.n
|
||||||
action_shape = env.action_space.shape or env.action_space.n
|
action_shape = env.action_space.shape or env.action_space.n
|
||||||
net = Net(state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128, 128, 128])
|
net = Net(state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128, 128, 128])
|
||||||
@ -378,7 +378,7 @@ result = ts.trainer.OffpolicyTrainer(
|
|||||||
stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold,
|
stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold,
|
||||||
logger=logger,
|
logger=logger,
|
||||||
).run()
|
).run()
|
||||||
print(f'Finished training! Use {result["duration"]}')
|
print(f"Finished training in {result.timing.total_time} seconds")
|
||||||
```
|
```
|
||||||
|
|
||||||
Save / load the trained policy (it's exactly the same as PyTorch `nn.module`):
|
Save / load the trained policy (it's exactly the same as PyTorch `nn.module`):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user