Fix Atari PPO example (#780)

- [x] I have marked all applicable categories:
    + [ ] exception-raising fix
    + [x] algorithm implementation fix
    + [ ] documentation modification
    + [ ] new feature
- [x] I have reformatted the code using `make format` (**required**)
- [x] I have checked the code using `make commit-checks` (**required**)
- [x] If applicable, I have mentioned the relevant/related issue(s)
- [x] If applicable, I have listed every item in this Pull Request below

While trying to debug Atari PPO+LSTM, I found a significant gap between
our Atari PPO example and [CleanRL's Atari PPO w/
EnvPool](https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_atari_envpoolpy).
I tried to align our implementation with CleanRL's version, mostly in
hyperparameter choices, and got significant gains in Breakout, Qbert, and
SpaceInvaders while staying on par in the other games. After this fix, I
would suggest updating the PPO experiments in our [Atari
Benchmark](https://tianshou.readthedocs.io/en/master/tutorials/benchmark.html).
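
For quick reference, the alignment is mostly a handful of new defaults in `atari_ppo.py` (plus `eps=1e-5` for Adam and orthogonal layer initialization). A condensed, standalone sketch of just the changed argparse defaults, taken from the diff below:

```python
import argparse

# Condensed view of the defaults this PR changes in atari_ppo.py
# (old values noted in the trailing comments; see the full diff below).
parser = argparse.ArgumentParser()
parser.add_argument("--scale-obs", type=int, default=1)     # was 0
parser.add_argument("--lr", type=float, default=2.5e-4)     # was 5e-5
parser.add_argument("--vf-coef", type=float, default=0.25)  # was 0.5
parser.add_argument("--eps-clip", type=float, default=0.1)  # was 0.2
parser.add_argument("--value-clip", type=int, default=1)    # was 0
args = parser.parse_args([])  # pick up the defaults only
print(args)
```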

A few interesting findings:

- Layer initialization helps stabilize training and enables the use of
larger learning rates; without it, larger learning rates trigger NaN
gradients very quickly (see the sketch after this list);
- ppo.py#L97-L101: this change helps training stability for reasons I do
not understand; it also increases GPU utilization.
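
For context, the initialization in question is CleanRL-style orthogonal weights with constant (zero) bias. Below is a minimal standalone sketch of the `layer_init` helper this PR adds to `atari_network.py`, applied to an example feature extractor (the 4-channel conv layer is just illustrative):

```python
import numpy as np
import torch
from torch import nn


def layer_init(
    layer: nn.Module, std: float = np.sqrt(2), bias_const: float = 0.0
) -> nn.Module:
    # Orthogonal weight init with gain `std`, constant (zero) bias.
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer


# Wrap each layer when building the network, e.g.:
feature_net = nn.Sequential(
    layer_init(nn.Conv2d(4, 32, kernel_size=8, stride=4)),
    nn.ReLU(inplace=True),
    nn.Flatten(),
)
```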

Shoutout to [CleanRL](https://github.com/vwxyzjn/cleanrl) for a
well-tuned Atari PPO reference implementation!
Yi Su authored on 2022-12-04 12:23:18 -08:00; committed by GitHub
parent 929508ba77
commit 662af52820
12 changed files with 61 additions and 34 deletions

examples/atari/README.md

@@ -114,13 +114,13 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M.

 | task | best reward | reward curve | parameters |
 | --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ |
-| PongNoFrameskip-v4 | 20.1 | ![](results/ppo/Pong_rew.png) | `python3 atari_ppo.py --task "PongNoFrameskip-v4"` |
-| BreakoutNoFrameskip-v4 | 438.5 | ![](results/ppo/Breakout_rew.png) | `python3 atari_ppo.py --task "BreakoutNoFrameskip-v4"` |
-| EnduroNoFrameskip-v4 | 1304.8 | ![](results/ppo/Enduro_rew.png) | `python3 atari_ppo.py --task "EnduroNoFrameskip-v4"` |
-| QbertNoFrameskip-v4 | 13640 | ![](results/ppo/Qbert_rew.png) | `python3 atari_ppo.py --task "QbertNoFrameskip-v4"` |
-| MsPacmanNoFrameskip-v4 | 1930 | ![](results/ppo/MsPacman_rew.png) | `python3 atari_ppo.py --task "MsPacmanNoFrameskip-v4"` |
-| SeaquestNoFrameskip-v4 | 904 | ![](results/ppo/Seaquest_rew.png) | `python3 atari_ppo.py --task "SeaquestNoFrameskip-v4" --lr 2.5e-5` |
-| SpaceInvadersNoFrameskip-v4 | 843 | ![](results/ppo/SpaceInvaders_rew.png) | `python3 atari_ppo.py --task "SpaceInvadersNoFrameskip-v4"` |
+| PongNoFrameskip-v4 | 20.2 | ![](results/ppo/Pong_rew.png) | `python3 atari_ppo.py --task "PongNoFrameskip-v4"` |
+| BreakoutNoFrameskip-v4 | 441.8 | ![](results/ppo/Breakout_rew.png) | `python3 atari_ppo.py --task "BreakoutNoFrameskip-v4"` |
+| EnduroNoFrameskip-v4 | 1245.4 | ![](results/ppo/Enduro_rew.png) | `python3 atari_ppo.py --task "EnduroNoFrameskip-v4"` |
+| QbertNoFrameskip-v4 | 17395 | ![](results/ppo/Qbert_rew.png) | `python3 atari_ppo.py --task "QbertNoFrameskip-v4"` |
+| MsPacmanNoFrameskip-v4 | 2098 | ![](results/ppo/MsPacman_rew.png) | `python3 atari_ppo.py --task "MsPacmanNoFrameskip-v4"` |
+| SeaquestNoFrameskip-v4 | 882 | ![](results/ppo/Seaquest_rew.png) | `python3 atari_ppo.py --task "SeaquestNoFrameskip-v4" --lr 1e-4` |
+| SpaceInvadersNoFrameskip-v4 | 1340.5 | ![](results/ppo/SpaceInvaders_rew.png) | `python3 atari_ppo.py --task "SpaceInvadersNoFrameskip-v4"` |

 # SAC (single run)

examples/atari/atari_network.py

@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union

 import numpy as np
 import torch
@@ -7,6 +7,29 @@ from torch import nn
 from tianshou.utils.net.discrete import NoisyLinear


+def layer_init(
+    layer: nn.Module, std: float = np.sqrt(2), bias_const: float = 0.0
+) -> nn.Module:
+    torch.nn.init.orthogonal_(layer.weight, std)
+    torch.nn.init.constant_(layer.bias, bias_const)
+    return layer
+
+
+def scale_obs(module: Type[nn.Module], denom: float = 255.0) -> Type[nn.Module]:
+
+    class scaled_module(module):
+
+        def forward(
+            self,
+            obs: Union[np.ndarray, torch.Tensor],
+            state: Optional[Any] = None,
+            info: Dict[str, Any] = {}
+        ) -> Tuple[torch.Tensor, Any]:
+            return super().forward(obs / denom, state, info)
+
+    return scaled_module
+
+
 class DQN(nn.Module):
     """Reference: Human-level control through deep reinforcement learning.
@@ -23,26 +46,30 @@ class DQN(nn.Module):
         device: Union[str, int, torch.device] = "cpu",
         features_only: bool = False,
         output_dim: Optional[int] = None,
+        layer_init: Callable[[nn.Module], nn.Module] = lambda x: x,
     ) -> None:
         super().__init__()
         self.device = device
         self.net = nn.Sequential(
-            nn.Conv2d(c, 32, kernel_size=8, stride=4), nn.ReLU(inplace=True),
-            nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(inplace=True),
-            nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(inplace=True),
-            nn.Flatten()
+            layer_init(nn.Conv2d(c, 32, kernel_size=8, stride=4)),
+            nn.ReLU(inplace=True),
+            layer_init(nn.Conv2d(32, 64, kernel_size=4, stride=2)),
+            nn.ReLU(inplace=True),
+            layer_init(nn.Conv2d(64, 64, kernel_size=3, stride=1)),
+            nn.ReLU(inplace=True), nn.Flatten()
         )
         with torch.no_grad():
             self.output_dim = np.prod(self.net(torch.zeros(1, c, h, w)).shape[1:])
         if not features_only:
             self.net = nn.Sequential(
-                self.net, nn.Linear(self.output_dim, 512), nn.ReLU(inplace=True),
-                nn.Linear(512, np.prod(action_shape))
+                self.net, layer_init(nn.Linear(self.output_dim, 512)),
+                nn.ReLU(inplace=True),
+                layer_init(nn.Linear(512, np.prod(action_shape)))
             )
             self.output_dim = np.prod(action_shape)
         elif output_dim is not None:
             self.net = nn.Sequential(
-                self.net, nn.Linear(self.output_dim, output_dim),
+                self.net, layer_init(nn.Linear(self.output_dim, output_dim)),
                 nn.ReLU(inplace=True)
             )
             self.output_dim = output_dim

examples/atari/atari_ppo.py

@@ -5,7 +5,7 @@ import pprint

 import numpy as np
 import torch
-from atari_network import DQN
+from atari_network import DQN, layer_init, scale_obs
 from atari_wrapper import make_atari_env
 from torch.optim.lr_scheduler import LambdaLR
 from torch.utils.tensorboard import SummaryWriter
@@ -22,9 +22,9 @@ def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--task", type=str, default="PongNoFrameskip-v4")
     parser.add_argument("--seed", type=int, default=4213)
-    parser.add_argument("--scale-obs", type=int, default=0)
+    parser.add_argument("--scale-obs", type=int, default=1)
     parser.add_argument("--buffer-size", type=int, default=100000)
-    parser.add_argument("--lr", type=float, default=5e-5)
+    parser.add_argument("--lr", type=float, default=2.5e-4)
     parser.add_argument("--gamma", type=float, default=0.99)
     parser.add_argument("--epoch", type=int, default=100)
     parser.add_argument("--step-per-epoch", type=int, default=100000)
@@ -35,14 +35,14 @@ def get_args():
     parser.add_argument("--training-num", type=int, default=10)
     parser.add_argument("--test-num", type=int, default=10)
     parser.add_argument("--rew-norm", type=int, default=False)
-    parser.add_argument("--vf-coef", type=float, default=0.5)
+    parser.add_argument("--vf-coef", type=float, default=0.25)
     parser.add_argument("--ent-coef", type=float, default=0.01)
     parser.add_argument("--gae-lambda", type=float, default=0.95)
     parser.add_argument("--lr-decay", type=int, default=True)
     parser.add_argument("--max-grad-norm", type=float, default=0.5)
-    parser.add_argument("--eps-clip", type=float, default=0.2)
+    parser.add_argument("--eps-clip", type=float, default=0.1)
     parser.add_argument("--dual-clip", type=float, default=None)
-    parser.add_argument("--value-clip", type=int, default=0)
+    parser.add_argument("--value-clip", type=int, default=1)
     parser.add_argument("--norm-adv", type=int, default=1)
     parser.add_argument("--recompute-adv", type=int, default=0)
     parser.add_argument("--logdir", type=str, default="log")
@@ -94,7 +94,7 @@ def test_ppo(args=get_args()):
         args.seed,
         args.training_num,
         args.test_num,
-        scale=args.scale_obs,
+        scale=0,
         frame_stack=args.frames_stack,
     )
     args.state_shape = env.observation_space.shape or env.observation_space.n
@@ -106,16 +106,20 @@ def test_ppo(args=get_args()):
     np.random.seed(args.seed)
     torch.manual_seed(args.seed)
     # define model
-    net = DQN(
+    net_cls = scale_obs(DQN) if args.scale_obs else DQN
+    net = net_cls(
         *args.state_shape,
         args.action_shape,
         device=args.device,
         features_only=True,
-        output_dim=args.hidden_size
+        output_dim=args.hidden_size,
+        layer_init=layer_init,
     )
     actor = Actor(net, args.action_shape, device=args.device, softmax_output=False)
     critic = Critic(net, device=args.device)
-    optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=args.lr)
+    optim = torch.optim.Adam(
+        ActorCritic(actor, critic).parameters(), lr=args.lr, eps=1e-5
+    )
     lr_scheduler = None
     if args.lr_decay:

(7 binary files changed: the updated reward-curve plots referenced in the README table above; image diffs not shown.)

tianshou/policy/modelfree/ppo.py

@@ -79,9 +79,6 @@ class PPOPolicy(A2CPolicy):
             "Dual-clip PPO parameter should greater than 1.0."
         self._dual_clip = dual_clip
         self._value_clip = value_clip
-        if not self._rew_norm:
-            assert not self._value_clip, \
-                "value clip is available only when `reward_normalization` is True"
         self._norm_adv = advantage_normalization
         self._recompute_adv = recompute_advantage
         self._actor_critic: ActorCritic
@@ -94,11 +91,8 @@ class PPOPolicy(A2CPolicy):
             self._buffer, self._indices = buffer, indices
         batch = self._compute_returns(batch, buffer, indices)
         batch.act = to_torch_as(batch.act, batch.v_s)
-        old_log_prob = []
         with torch.no_grad():
-            for minibatch in batch.split(self._batch, shuffle=False, merge_last=True):
-                old_log_prob.append(self(minibatch).dist.log_prob(minibatch.act))
-        batch.logp_old = torch.cat(old_log_prob, dim=0)
+            batch.logp_old = self(batch).dist.log_prob(batch.act)
         return batch

     def learn(  # type: ignore
@@ -113,7 +107,8 @@
                 dist = self(minibatch).dist
                 if self._norm_adv:
                     mean, std = minibatch.adv.mean(), minibatch.adv.std()
-                    minibatch.adv = (minibatch.adv - mean) / std  # per-batch norm
+                    minibatch.adv = (minibatch.adv -
+                                     mean) / (std + self._eps)  # per-batch norm
                 ratio = (dist.log_prob(minibatch.act) -
                          minibatch.logp_old).exp().float()
                 ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)

tianshou/trainer/base.py

@@ -356,7 +356,8 @@ class BaseTrainer(ABC):
             print(
                 f"Epoch #{self.epoch}: test_reward: {rew:.6f} ± {rew_std:.6f},"
                 f" best_reward: {self.best_reward:.6f} ± "
-                f"{self.best_reward_std:.6f} in #{self.best_epoch}"
+                f"{self.best_reward_std:.6f} in #{self.best_epoch}",
+                flush=True
             )
         if not self.is_run:
             test_stat = {