shorten quick start

Trinkle23897 2020-03-28 22:40:47 +08:00
parent 57735ce1b5
commit a326d30739

README.md (109 changed lines)

@@ -115,7 +115,7 @@ You can check out the [documentation](https://tianshou.readthedocs.io) for furth
 ## Quick Start
-This is an example of Policy Gradient. You can also run the full script under [test/discrete/test_pg.py](/test/discrete/test_pg.py).
+This is an example of Deep Q Network. You can also run the full script under [test/discrete/test_dqn.py](/test/discrete/test_dqn.py).
 First, import the relevant packages:
@@ -123,56 +123,31 @@ First, import the relevant packages:
 ```python
 import gym, torch, numpy as np, torch.nn as nn
 from torch.utils.tensorboard import SummaryWriter
-from tianshou.policy import PGPolicy
+from tianshou.policy import DQNPolicy
 from tianshou.env import SubprocVectorEnv
-from tianshou.trainer import onpolicy_trainer
+from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
 ```
 Define some hyper-parameters:
 ```python
 task = 'CartPole-v0'
-seed = 1626
-lr = 3e-4
+lr = 1e-3
 gamma = 0.9
+n_step = 3
+eps_train, eps_test = 0.1, 0.05
 epoch = 10
 step_per_epoch = 1000
 collect_per_step = 10
-repeat_per_collect = 2
+target_freq = 320
 batch_size = 64
-train_num = 8
-test_num = 100
-device = 'cuda' if torch.cuda.is_available() else 'cpu'
-writer = SummaryWriter('log/pg') # tensorboard is also supported!
+train_num, test_num = 8, 100
+buffer_size = 20000
+writer = SummaryWriter('log/dqn') # tensorboard is also supported!
 ```
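The DQN-specific additions here are `n_step` (the length of the multi-step return) and `target_freq` (how often the target network is synced). A minimal sketch of the 3-step TD target that `n_step = 3` and `gamma = 0.9` imply, using made-up reward and bootstrap values that are not part of the diff:

```python
# Illustrative sketch only: the 3-step TD target implied by n_step=3, gamma=0.9.
# The rewards and bootstrap value below are hypothetical.
gamma, n_step = 0.9, 3
rews = [1.0, 1.0, 1.0]      # r_t, r_{t+1}, r_{t+2} from a hypothetical rollout
q_next_max = 10.0           # max_a Q_target(s_{t+n}, a) from the target network

returns = sum(gamma ** i * rews[i] for i in range(n_step))
target = returns + gamma ** n_step * q_next_max  # bootstrap unless the episode ended
print(target)  # 1 + 0.9 + 0.81 + 0.729 * 10 ~= 10.0
```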
-Define the network:
-```python
-class Net(nn.Module):
-    def __init__(self, layer_num, state_shape, action_shape=0, device='cpu'):
-        super().__init__()
-        self.device = device
-        self.model = [
-            nn.Linear(np.prod(state_shape), 128),
-            nn.ReLU(inplace=True)]
-        for i in range(layer_num):
-            self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
-        if action_shape:
-            self.model += [nn.Linear(128, np.prod(action_shape))]
-        self.model = nn.Sequential(*self.model)
-    def forward(self, s, state=None, info={}):
-        if not isinstance(s, torch.Tensor):
-            s = torch.tensor(s, device=self.device, dtype=torch.float)
-        batch = s.shape[0]
-        s = s.view(batch, -1)
-        logits = self.model(s)
-        return logits, state
-```
-Make envs and fix seed:
+Make envs:
 ```python
 env = gym.make(task)
@@ -180,46 +155,68 @@ state_shape = env.observation_space.shape or env.observation_space.n
 action_shape = env.action_space.shape or env.action_space.n
 train_envs = SubprocVectorEnv([lambda: gym.make(task) for _ in range(train_num)])
 test_envs = SubprocVectorEnv([lambda: gym.make(task) for _ in range(test_num)])
-np.random.seed(seed)
-torch.manual_seed(seed)
-train_envs.seed(seed)
-test_envs.seed(seed)
 ```
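`SubprocVectorEnv` runs each environment in its own subprocess and returns batched results, one row per sub-environment. A small usage sketch (hypothetical, not part of the diff):

```python
# Sketch: a vectorized env returns one entry per sub-environment.
obs = train_envs.reset()   # stacked observations, shape (train_num, ...)
print(obs.shape)           # e.g. (8, 4) for CartPole-v0
```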
-Setup policy and collector:
+Define the network:
 ```python
-net = Net(3, state_shape, action_shape, device).to(device)
+class Net(nn.Module):
+    def __init__(self, state_shape, action_shape):
+        super().__init__()
+        self.model = nn.Sequential(*[
+            nn.Linear(np.prod(state_shape), 128), nn.ReLU(inplace=True),
+            nn.Linear(128, 128), nn.ReLU(inplace=True),
+            nn.Linear(128, 128), nn.ReLU(inplace=True),
+            nn.Linear(128, np.prod(action_shape))
+        ])
+    def forward(self, s, state=None, info={}):
+        if not isinstance(s, torch.Tensor):
+            s = torch.tensor(s, dtype=torch.float)
+        batch = s.shape[0]
+        logits = self.model(s.view(batch, -1))
+        return logits, state
+
+net = Net(state_shape, action_shape)
 optim = torch.optim.Adam(net.parameters(), lr=lr)
-policy = PGPolicy(net, optim, torch.distributions.Categorical, gamma)
-train_collector = Collector(policy, train_envs, ReplayBuffer(20000))
+```
+Setup policy and collectors:
+```python
+policy = DQNPolicy(net, optim, gamma, n_step,
+                   use_target_network=True, target_update_freq=target_freq)
+train_collector = Collector(policy, train_envs, ReplayBuffer(buffer_size))
 test_collector = Collector(policy, test_envs)
 ```
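Off-policy training needs transitions in the buffer before the first gradient step. If you want to pre-fill the replay buffer yourself, a warm-up collect like this sketch is a common pattern (the step count is hypothetical; the trainer also collects on its own):

```python
# Sketch: optionally seed the replay buffer before training begins.
policy.set_eps(1.0)                             # explore aggressively during warm-up
train_collector.collect(n_step=batch_size * 4)  # gather some initial transitions
policy.set_eps(eps_train)                       # restore the training epsilon
```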
 Let's train it:
 ```python
-result = onpolicy_trainer(policy, train_collector, test_collector, epoch, step_per_epoch, collect_per_step, repeat_per_collect, test_num, batch_size, stop_fn=lambda x: x >= env.spec.reward_threshold, writer=writer)
+result = offpolicy_trainer(
+    policy, train_collector, test_collector, epoch, step_per_epoch, collect_per_step,
+    test_num, batch_size, train_fn=lambda e: policy.set_eps(eps_train),
+    test_fn=lambda e: policy.set_eps(eps_test),
+    stop_fn=lambda x: x >= env.spec.reward_threshold, writer=writer, task=task)
 ```
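The trainer returns a dict of summary statistics. Assuming it carries keys such as `best_reward` and `duration` (as Tianshou's test scripts suggest), you can report the outcome like this sketch:

```python
# Sketch: inspect the summary dict returned by offpolicy_trainer.
# The exact key names are an assumption, not taken from the diff.
print(f'Finished training! Best reward: {result["best_reward"]}, '
      f'duration: {result["duration"]}')
```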
 Saving / loading a trained policy (it works exactly like a PyTorch nn.Module):
 ```python
-torch.save(policy.state_dict(), 'pg.pth')
-policy.load_state_dict(torch.load('pg.pth', map_location=device))
+torch.save(policy.state_dict(), 'dqn.pth')
+policy.load_state_dict(torch.load('dqn.pth'))
 ```
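The new load call drops `map_location` together with the removed `device` variable. If the policy was trained on a GPU and is reloaded on a CPU-only machine, passing `map_location` explicitly still works; this is standard PyTorch rather than anything specific to this change:

```python
# Sketch: reload weights saved on GPU onto a CPU-only machine.
policy.load_state_dict(torch.load('dqn.pth', map_location='cpu'))
```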
 Watch the performance at 35 FPS:
 ```python3
-collecter = Collector(policy, env)
-collecter.collect(n_episode=1, render=1/35)
+collector = Collector(policy, env)
+collector.collect(n_episode=1, render=1/35)
 ```
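Here `render` is the pause in seconds between rendered frames, so `1/35` gives roughly 35 FPS; a larger interval simply slows the replay (a sketch, same call as above):

```python
collector.collect(n_episode=1, render=1/10)  # ~10 FPS, easier to follow by eye
```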
 Looking at the results saved in tensorboard (from a bash shell):
 ```bash
-tensorboard --logdir log/pg
+tensorboard --logdir log/dqn
 ```
 ## Citing Tianshou