shorten quick start

parent 57735ce1b5
commit a326d30739
README.md: +97 / -97
You can check out the [documentation](https://tianshou.readthedocs.io) for further usage.

## Quick Start

This is an example of Deep Q Network. You can also run the full script under [test/discrete/test_dqn.py](/test/discrete/test_dqn.py).

First, import the relevant packages:

```python
import gym, torch, numpy as np, torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import DQNPolicy
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
```

Define some hyper-parameters:

```python
task = 'CartPole-v0'
seed = 1626
lr = 1e-3
gamma = 0.9
n_step = 3
eps_train, eps_test = 0.1, 0.05
epoch = 10
step_per_epoch = 1000
collect_per_step = 10
target_freq = 320
batch_size = 64
train_num, test_num = 8, 100
buffer_size = 20000
writer = SummaryWriter('log/dqn')  # tensorboard is also supported!
```

Make environments and fix the random seed:

```python
env = gym.make(task)
state_shape = env.observation_space.shape or env.observation_space.n
action_shape = env.action_space.shape or env.action_space.n
train_envs = SubprocVectorEnv([lambda: gym.make(task) for _ in range(train_num)])
test_envs = SubprocVectorEnv([lambda: gym.make(task) for _ in range(test_num)])
np.random.seed(seed)
torch.manual_seed(seed)
train_envs.seed(seed)
test_envs.seed(seed)
```
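A side note (not in the original quick start): the vectorized environment behaves like a batch of `train_num` CartPole instances, so `reset()` and `step()` return stacked arrays with one row per worker. A minimal sketch, assuming the setup code above has already run:

```python
obs = train_envs.reset()                      # stacked observations, shape (train_num, 4) for CartPole-v0
act = np.zeros(train_num, dtype=int)          # one action per worker (here: always push left)
obs, rew, done, info = train_envs.step(act)   # rewards and done flags are stacked the same way
```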

Define the network:

```python
class Net(nn.Module):
    def __init__(self, state_shape, action_shape):
        super().__init__()
        # a simple MLP: flatten the observation and output one Q-value per action
        self.model = nn.Sequential(*[
            nn.Linear(np.prod(state_shape), 128), nn.ReLU(inplace=True),
            nn.Linear(128, 128), nn.ReLU(inplace=True),
            nn.Linear(128, 128), nn.ReLU(inplace=True),
            nn.Linear(128, np.prod(action_shape))
        ])

    def forward(self, s, state=None, info={}):
        if not isinstance(s, torch.Tensor):
            s = torch.tensor(s, dtype=torch.float)
        batch = s.shape[0]
        logits = self.model(s.view(batch, -1))
        return logits, state

net = Net(state_shape, action_shape)
optim = torch.optim.Adam(net.parameters(), lr=lr)
```
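As a quick sanity check (an illustrative addition, not part of the original script), the network accepts a batch of observations, even as a raw numpy array, and returns one Q-value per action:

```python
obs = np.random.rand(4, *state_shape)   # a fake batch of 4 CartPole observations
q_values, _ = net(obs)                  # forward() converts numpy input to a float tensor itself
print(q_values.shape)                   # torch.Size([4, 2]) for CartPole-v0
```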

Set up the policy and the collectors:

```python
policy = DQNPolicy(net, optim, gamma, n_step,
                   use_target_network=True, target_update_freq=target_freq)
train_collector = Collector(policy, train_envs, ReplayBuffer(buffer_size))
test_collector = Collector(policy, test_envs)
```
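Optionally, you can pre-fill the replay buffer with a bit of experience before training starts. This is a sketch based on the `Collector` interface used above, not a required step from the original example:

```python
policy.set_eps(eps_train)                   # explore epsilon-greedily while collecting
train_collector.collect(n_step=batch_size)  # gather a first batch of transitions into the buffer
```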

Let's train it:

```python
result = offpolicy_trainer(
    policy, train_collector, test_collector, epoch, step_per_epoch, collect_per_step,
    test_num, batch_size, train_fn=lambda e: policy.set_eps(eps_train),
    test_fn=lambda e: policy.set_eps(eps_test),
    stop_fn=lambda x: x >= env.spec.reward_threshold, writer=writer, task=task)
```
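The trainer returns a dictionary of training statistics. The key names below are an assumption about this version's return format, so adjust them to whatever `result` actually contains:

```python
# hypothetical keys: inspect result.keys() to see what is available
print(f"best reward: {result['best_reward']}, training took {result['duration']}")
```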

Saving / loading the trained policy works exactly the same as for a PyTorch nn.Module:

```python
torch.save(policy.state_dict(), 'dqn.pth')
policy.load_state_dict(torch.load('dqn.pth'))
```
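If the policy was trained on a GPU but is reloaded on a CPU-only machine, the standard PyTorch `map_location` argument applies, since the policy is saved and loaded like any nn.Module:

```python
# remap CUDA tensors to the CPU when loading on a machine without a GPU
policy.load_state_dict(torch.load('dqn.pth', map_location='cpu'))
```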

Watch the agent's performance at 35 FPS:

```python
collector = Collector(policy, env)
collector.collect(n_episode=1, render=1/35)
```
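Before the `collector.collect` call above, it can also help to switch the policy to evaluation mode and to the test-time exploration rate; this is an optional extra rather than part of the original snippet:

```python
policy.eval()             # standard PyTorch eval mode (the policy is an nn.Module)
policy.set_eps(eps_test)  # small epsilon, so the agent acts almost greedily
```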

Look at the results saved in TensorBoard (from the command line):

```bash
tensorboard --logdir log/dqn
```
## Citing Tianshou