shorten quick start

parent 57735ce1b5, commit a326d30739, changed file: README.md
## Quick Start

This is an example of Deep Q Network. You can also run the full script under [test/discrete/test_dqn.py](/test/discrete/test_dqn.py).

First, import the relevant packages:
```python
import gym, torch, numpy as np, torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import DQNPolicy
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
```
Define some hyper-parameters:

```python
task = 'CartPole-v0'
lr = 1e-3
gamma = 0.9
n_step = 3
eps_train, eps_test = 0.1, 0.05
epoch = 10
step_per_epoch = 1000
collect_per_step = 10
target_freq = 320
batch_size = 64
train_num, test_num = 8, 100
buffer_size = 20000
writer = SummaryWriter('log/dqn')  # tensorboard is also supported!
```
Make envs:

```python
env = gym.make(task)
state_shape = env.observation_space.shape or env.observation_space.n
action_shape = env.action_space.shape or env.action_space.n
train_envs = SubprocVectorEnv([lambda: gym.make(task) for _ in range(train_num)])
test_envs = SubprocVectorEnv([lambda: gym.make(task) for _ in range(test_num)])
```
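If you want reproducible runs, you can also seed NumPy, PyTorch, and the vectorized envs, as the longer version of this example did; the `seed` value below is only an illustration:

```python
# Optional seeding for reproducibility (illustrative value, taken from the
# longer version of this quick start).
seed = 1626
np.random.seed(seed)
torch.manual_seed(seed)
train_envs.seed(seed)
test_envs.seed(seed)
```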
Define the network:

```python
class Net(nn.Module):
    def __init__(self, state_shape, action_shape):
        super().__init__()
        self.model = nn.Sequential(*[
            nn.Linear(np.prod(state_shape), 128), nn.ReLU(inplace=True),
            nn.Linear(128, 128), nn.ReLU(inplace=True),
            nn.Linear(128, 128), nn.ReLU(inplace=True),
            nn.Linear(128, np.prod(action_shape))
        ])

    def forward(self, s, state=None, info={}):
        if not isinstance(s, torch.Tensor):
            s = torch.tensor(s, dtype=torch.float)
        batch = s.shape[0]
        logits = self.model(s.view(batch, -1))
        return logits, state

net = Net(state_shape, action_shape)
optim = torch.optim.Adam(net.parameters(), lr=lr)
```
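As a quick sanity check (not part of the original example), you can pass a small batch of observations through the network and confirm that it returns one value per action:

```python
# Illustrative check: the network maps a batch of observations to
# (batch_size, n_actions) logits, which DQN interprets as Q-values.
obs = np.array([env.reset() for _ in range(4)])
q_values, _ = net(obs)
print(q_values.shape)  # e.g. torch.Size([4, 2]) for CartPole-v0
```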
Setup policy and collectors:

```python
policy = DQNPolicy(net, optim, gamma, n_step,
                   use_target_network=True, target_update_freq=target_freq)
train_collector = Collector(policy, train_envs, ReplayBuffer(buffer_size))
test_collector = Collector(policy, test_envs)
```
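Before training, it is common to pre-fill the replay buffer with a few transitions gathered by the initial policy. A minimal sketch, assuming your Tianshou version's `Collector.collect` accepts an `n_step` argument:

```python
# Optional warm-up (assumes Collector.collect(n_step=...) is available):
# collect some initial experience with a high exploration rate.
policy.set_eps(eps_train)
train_collector.collect(n_step=batch_size)
```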
Let's train it:

```python
result = offpolicy_trainer(
    policy, train_collector, test_collector, epoch, step_per_epoch, collect_per_step,
    test_num, batch_size, train_fn=lambda e: policy.set_eps(eps_train),
    test_fn=lambda e: policy.set_eps(eps_test),
    stop_fn=lambda x: x >= env.spec.reward_threshold, writer=writer, task=task)
```
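The trainer returns a dict of training statistics; its exact keys vary across Tianshou versions, so the simplest check is to print it and then shut down the vectorized envs:

```python
# `result` holds training statistics (exact keys depend on the Tianshou version).
print(result)
train_envs.close()
test_envs.close()
```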
Save / load the trained policy (it works exactly like a PyTorch nn.Module):

```python
torch.save(policy.state_dict(), 'dqn.pth')
policy.load_state_dict(torch.load('dqn.pth'))
```
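If the checkpoint was saved from a GPU run and you want to load it on a different device, `torch.load` accepts a `map_location` argument (the earlier, longer version of this example used it):

```python
# Load the checkpoint on CPU regardless of where it was saved.
policy.load_state_dict(torch.load('dqn.pth', map_location='cpu'))
```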
Watch the agent's performance at 35 FPS:

```python
collector = Collector(policy, env)
collector.collect(n_episode=1, render=1/35)
```
Look at the results saved in tensorboard (from a bash shell):

```bash
tensorboard --logdir log/dqn
```
## Citing Tianshou