Update procedural example in README

This commit is contained in:
Dominik Jain 2024-01-12 13:47:12 +01:00
parent 62d58faa02
commit 2c72171fca

View File

@ -305,7 +305,7 @@ First, import some relevant packages:
```python ```python
import gymnasium as gym import gymnasium as gym
import torch, numpy as np, torch.nn as nn import torch
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
import tianshou as ts import tianshou as ts
``` ```
@ -313,7 +313,7 @@ import tianshou as ts
Define some hyper-parameters: Define some hyper-parameters:
```python ```python
task = 'CartPole-v0' task = 'CartPole-v1'
lr, epoch, batch_size = 1e-3, 10, 64 lr, epoch, batch_size = 1e-3, 10, 64
train_num, test_num = 10, 100 train_num, test_num = 10, 100
gamma, n_step, target_freq = 0.9, 3, 320 gamma, n_step, target_freq = 0.9, 3, 320
@ -338,7 +338,7 @@ Define the network:
from tianshou.utils.net.common import Net from tianshou.utils.net.common import Net
# you can define other net by following the API: # you can define other net by following the API:
# https://tianshou.readthedocs.io/en/master/tutorials/dqn.html#build-the-network # https://tianshou.readthedocs.io/en/master/tutorials/dqn.html#build-the-network
env = gym.make(task) env = gym.make(task, render_mode="human")
state_shape = env.observation_space.shape or env.observation_space.n state_shape = env.observation_space.shape or env.observation_space.n
action_shape = env.action_space.shape or env.action_space.n action_shape = env.action_space.shape or env.action_space.n
net = Net(state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128, 128, 128]) net = Net(state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128, 128, 128])
@ -378,7 +378,7 @@ result = ts.trainer.OffpolicyTrainer(
stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold, stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold,
logger=logger, logger=logger,
).run() ).run()
print(f'Finished training! Use {result["duration"]}') print(f"Finished training in {result.timing.total_time} seconds")
``` ```
Save / load the trained policy (it's exactly the same as PyTorch `nn.module`): Save / load the trained policy (it's exactly the same as PyTorch `nn.module`):