shorten quick start

Trinkle23897 2020-03-28 22:40:47 +08:00
parent 57735ce1b5
commit a326d30739


@@ -115,7 +115,7 @@ You can check out the [documentation](https://tianshou.readthedocs.io) for furth
## Quick Start
This is an example of Policy Gradient. You can also run the full script under [test/discrete/test_pg.py](/test/discrete/test_pg.py).
This is an example of Deep Q Network. You can also run the full script under [test/discrete/test_dqn.py](/test/discrete/test_dqn.py).
First, import the relevant packages:
@@ -123,9 +123,9 @@ First, import the relevant packages:
import gym, torch, numpy as np, torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import PGPolicy
from tianshou.policy import DQNPolicy
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import onpolicy_trainer
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
```
@@ -133,46 +133,21 @@ Define some hyper-parameters:
```python
task = 'CartPole-v0'
seed = 1626
lr = 3e-4
lr = 1e-3
gamma = 0.9
n_step = 3
eps_train, eps_test = 0.1, 0.05
epoch = 10
step_per_epoch = 1000
collect_per_step = 10
repeat_per_collect = 2
target_freq = 320
batch_size = 64
train_num = 8
test_num = 100
device = 'cuda' if torch.cuda.is_available() else 'cpu'
writer = SummaryWriter('log/pg') # tensorboard is also supported!
train_num, test_num = 8, 100
buffer_size = 20000
writer = SummaryWriter('log/dqn') # tensorboard is also supported!
```
Define the network:
```python
class Net(nn.Module):
def __init__(self, layer_num, state_shape, action_shape=0, device='cpu'):
super().__init__()
self.device = device
self.model = [
nn.Linear(np.prod(state_shape), 128),
nn.ReLU(inplace=True)]
for i in range(layer_num):
self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
if action_shape:
self.model += [nn.Linear(128, np.prod(action_shape))]
self.model = nn.Sequential(*self.model)
def forward(self, s, state=None, info={}):
if not isinstance(s, torch.Tensor):
s = torch.tensor(s, device=self.device, dtype=torch.float)
batch = s.shape[0]
s = s.view(batch, -1)
logits = self.model(s)
return logits, state
```
Make envs and fix seed:
Make envs:
```python
env = gym.make(task)
@@ -180,46 +155,68 @@ state_shape = env.observation_space.shape or env.observation_space.n
action_shape = env.action_space.shape or env.action_space.n
train_envs = SubprocVectorEnv([lambda: gym.make(task) for _ in range(train_num)])
test_envs = SubprocVectorEnv([lambda: gym.make(task) for _ in range(test_num)])
np.random.seed(seed)
torch.manual_seed(seed)
train_envs.seed(seed)
test_envs.seed(seed)
```
Setup policy and collector:
Define the network:
```python
net = Net(3, state_shape, action_shape, device).to(device)
class Net(nn.Module):
def __init__(self, state_shape, action_shape):
super().__init__()
self.model = nn.Sequential(*[
nn.Linear(np.prod(state_shape), 128), nn.ReLU(inplace=True),
nn.Linear(128, 128), nn.ReLU(inplace=True),
nn.Linear(128, 128), nn.ReLU(inplace=True),
nn.Linear(128, np.prod(action_shape))
])
def forward(self, s, state=None, info={}):
if not isinstance(s, torch.Tensor):
s = torch.tensor(s, dtype=torch.float)
batch = s.shape[0]
logits = self.model(s.view(batch, -1))
return logits, state
net = Net(state_shape, action_shape)
optim = torch.optim.Adam(net.parameters(), lr=lr)
policy = PGPolicy(net, optim, torch.distributions.Categorical, gamma)
train_collector = Collector(policy, train_envs, ReplayBuffer(20000))
```
Set up the policy and collectors:
```python
policy = DQNPolicy(net, optim, gamma, n_step,
use_target_network=True, target_update_freq=target_freq)
train_collector = Collector(policy, train_envs, ReplayBuffer(buffer_size))
test_collector = Collector(policy, test_envs)
```
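Off-policy methods like DQN learn from transitions stored in the replay buffer, so it helps to put some experience into the buffer before the first update. The full script linked above performs a short warm-up collection; a minimal sketch of that step (not taken verbatim from the script, and the exact call signature may differ between Tianshou versions):
```python
# Warm-up sketch (assumption: Collector.collect(n_step=...) is available as in
# this version's test scripts). With eps=1 the policy acts fully at random.
policy.set_eps(1)
train_collector.collect(n_step=batch_size)
```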
Let's train it:
```python
result = onpolicy_trainer(policy, train_collector, test_collector, epoch, step_per_epoch, collect_per_step, repeat_per_collect, test_num, batch_size, stop_fn=lambda x: x >= env.spec.reward_threshold, writer=writer)
result = offpolicy_trainer(
policy, train_collector, test_collector, epoch, step_per_epoch, collect_per_step,
test_num, batch_size, train_fn=lambda e: policy.set_eps(eps_train),
test_fn=lambda e: policy.set_eps(eps_test),
stop_fn=lambda x: x >= env.spec.reward_threshold, writer=writer, task=task)
```
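The trainer returns a summary of the whole run. A minimal sketch of inspecting it, assuming `result` is a dict with entries such as `best_reward` and `duration` (the key names are an assumption and can vary between versions):
```python
# Hypothetical inspection of the training summary returned by the trainer;
# the exact keys ('best_reward', 'duration', ...) are assumed here.
if result['best_reward'] >= env.spec.reward_threshold:
    print(f"Solved! best_reward={result['best_reward']}, duration={result['duration']}")
```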
Save / load the trained policy (it works exactly like any other PyTorch nn.Module):
```python
torch.save(policy.state_dict(), 'pg.pth')
policy.load_state_dict(torch.load('pg.pth', map_location=device))
torch.save(policy.state_dict(), 'dqn.pth')
policy.load_state_dict(torch.load('dqn.pth'))
```
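Before evaluating a loaded policy, it is usually worth switching it to evaluation mode and using the test-time exploration rate. This step is a suggestion rather than part of the original snippet, but it only uses calls already shown above:
```python
# Suggested extra step (not in the original snippet): disable training-time exploration.
policy.eval()             # standard torch.nn.Module eval mode
policy.set_eps(eps_test)  # same method used in train_fn/test_fn above
```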
Watch the agent's performance at 35 FPS:
```python
collecter = Collector(policy, env)
collecter.collect(n_episode=1, render=1/35)
collector = Collector(policy, env)
collector.collect(n_episode=1, render=1/35)
```
Look at the results saved in TensorBoard (from the command line):
```bash
tensorboard --logdir log/pg
tensorboard --logdir log/dqn
```
## Citing Tianshou