diff --git a/README.md b/README.md
index 4771378..ad7b5e0 100644
--- a/README.md
+++ b/README.md
@@ -115,7 +115,7 @@ You can check out the [documentation](https://tianshou.readthedocs.io) for furth

## Quick Start

-This is an example of Policy Gradient. You can also run the full script under [test/discrete/test_pg.py](/test/discrete/test_pg.py).
+This is an example of a Deep Q Network (DQN). You can also run the full script under [test/discrete/test_dqn.py](/test/discrete/test_dqn.py).

First, import the relevant packages:

@@ -123,56 +123,31 @@ First, import the relevant packages:
import gym, torch, numpy as np, torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

-from tianshou.policy import PGPolicy
+from tianshou.policy import DQNPolicy
from tianshou.env import SubprocVectorEnv
-from tianshou.trainer import onpolicy_trainer
+from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
```

Define some hyper-parameters:

```python
-task = 'CartPole-v0'
-seed = 1626
-lr = 3e-4
-gamma = 0.9
-epoch = 10
-step_per_epoch = 1000
-collect_per_step = 10
-repeat_per_collect = 2
-batch_size = 64
-train_num = 8
-test_num = 100
-device = 'cuda' if torch.cuda.is_available() else 'cpu'
-writer = SummaryWriter('log/pg')  # tensorboard is also supported!
+task = 'CartPole-v0'
+lr = 1e-3
+gamma = 0.9
+n_step = 3
+eps_train, eps_test = 0.1, 0.05
+epoch = 10
+step_per_epoch = 1000
+collect_per_step = 10
+target_freq = 320
+batch_size = 64
+train_num, test_num = 8, 100
+buffer_size = 20000
+writer = SummaryWriter('log/dqn')  # tensorboard is also supported!
```

-Define the network:
-
-```python
-class Net(nn.Module):
-    def __init__(self, layer_num, state_shape, action_shape=0, device='cpu'):
-        super().__init__()
-        self.device = device
-        self.model = [
-            nn.Linear(np.prod(state_shape), 128),
-            nn.ReLU(inplace=True)]
-        for i in range(layer_num):
-            self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
-        if action_shape:
-            self.model += [nn.Linear(128, np.prod(action_shape))]
-        self.model = nn.Sequential(*self.model)
-
-    def forward(self, s, state=None, info={}):
-        if not isinstance(s, torch.Tensor):
-            s = torch.tensor(s, device=self.device, dtype=torch.float)
-        batch = s.shape[0]
-        s = s.view(batch, -1)
-        logits = self.model(s)
-        return logits, state
-```
-
-Make envs and fix seed:
+Make envs:

```python
env = gym.make(task)
@@ -180,46 +155,68 @@ state_shape = env.observation_space.shape or env.observation_space.n
action_shape = env.action_space.shape or env.action_space.n
train_envs = SubprocVectorEnv([lambda: gym.make(task) for _ in range(train_num)])
test_envs = SubprocVectorEnv([lambda: gym.make(task) for _ in range(test_num)])
-np.random.seed(seed)
-torch.manual_seed(seed)
-train_envs.seed(seed)
-test_envs.seed(seed)
```

-Setup policy and collector:
+Define the network:

```python
-net = Net(3, state_shape, action_shape, device).to(device)
+class Net(nn.Module):
+    def __init__(self, state_shape, action_shape):
+        super().__init__()
+        self.model = nn.Sequential(*[
+            nn.Linear(np.prod(state_shape), 128), nn.ReLU(inplace=True),
+            nn.Linear(128, 128), nn.ReLU(inplace=True),
+            nn.Linear(128, 128), nn.ReLU(inplace=True),
+            nn.Linear(128, np.prod(action_shape))
+        ])
+
+    def forward(self, s, state=None, info={}):
+        if not isinstance(s, torch.Tensor):
+            s = torch.tensor(s, dtype=torch.float)
+        batch = s.shape[0]
+        logits = self.model(s.view(batch, -1))
+        return logits, state
+
+net = Net(state_shape, action_shape)
optim = torch.optim.Adam(net.parameters(), lr=lr)
-policy = PGPolicy(net, optim, torch.distributions.Categorical, gamma)
-train_collector = Collector(policy, train_envs, ReplayBuffer(20000))
+```
+
+Set up policy and collectors:
+
+```python
+policy = DQNPolicy(net, optim, gamma, n_step,
+    use_target_network=True, target_update_freq=target_freq)
+train_collector = Collector(policy, train_envs, ReplayBuffer(buffer_size))
test_collector = Collector(policy, test_envs)
```
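+Before training, you can optionally pre-collect a few transitions so the replay buffer is not empty when the first gradient step runs. This warm-up is a sketch rather than part of the original script: it assumes `Collector.collect` also accepts an `n_step` argument.
+
+```python
+policy.set_eps(eps_train)                   # epsilon-greedy exploration during warm-up
+train_collector.collect(n_step=batch_size)  # assumed API: put n_step transitions into the buffer
+```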

Let's train it:

```python
-result = onpolicy_trainer(policy, train_collector, test_collector, epoch, step_per_epoch, collect_per_step, repeat_per_collect, test_num, batch_size, stop_fn=lambda x: x >= env.spec.reward_threshold, writer=writer)
+result = offpolicy_trainer(
+    policy, train_collector, test_collector, epoch, step_per_epoch, collect_per_step,
+    test_num, batch_size, train_fn=lambda e: policy.set_eps(eps_train),
+    test_fn=lambda e: policy.set_eps(eps_test),
+    stop_fn=lambda x: x >= env.spec.reward_threshold, writer=writer, task=task)
```

Saving / loading the trained policy (it's exactly the same as saving / loading a PyTorch nn.Module):

```python
-torch.save(policy.state_dict(), 'pg.pth')
-policy.load_state_dict(torch.load('pg.pth', map_location=device))
+torch.save(policy.state_dict(), 'dqn.pth')
+policy.load_state_dict(torch.load('dqn.pth'))
```

Watch the performance at 35 FPS:

```python
-collecter = Collector(policy, env)
-collecter.collect(n_episode=1, render=1/35)
+collector = Collector(policy, env)
+collector.collect(n_episode=1, render=1/35)
```

Look at the results saved in TensorBoard (from a bash shell):

```bash
-tensorboard --logdir log/pg
+tensorboard --logdir log/dqn
```

## Citing Tianshou