Fix Atari PPO example ()

- [x] I have marked all applicable categories:
    + [ ] exception-raising fix
    + [x] algorithm implementation fix
    + [ ] documentation modification
    + [ ] new feature
- [x] I have reformatted the code using `make format` (**required**)
- [x] I have checked the code using `make commit-checks` (**required**)
- [x] If applicable, I have mentioned the relevant/related issue(s)
- [x] If applicable, I have listed every item in this Pull Request below

While trying to debug Atari PPO+LSTM, I found a significant gap between our Atari PPO example and [CleanRL's Atari PPO w/ EnvPool](https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_atari_envpoolpy). I tried to align our implementation with CleanRL's version, mostly in hyperparameter choices, and got significant gains in Breakout, Qbert, and SpaceInvaders while staying on par in the other games. After this fix, I would suggest updating the PPO experiments in our [Atari Benchmark](https://tianshou.readthedocs.io/en/master/tutorials/benchmark.html).

A few interesting findings:

- Layer initialization helps stabilize training and enables the use of larger learning rates; without it, larger learning rates trigger NaN gradients very quickly (see the sketch after this list);
- ppo.py#L97-L101: this change helps training stability for reasons I do not understand; it also increases GPU usage.
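
For reference, the initialization is the CleanRL-style orthogonal-weight/zero-bias scheme applied to every conv and linear layer of the feature extractor. A minimal sketch mirroring the `layer_init` helper added in `atari_network.py` (full diff below); the standalone `feature_net` here is only for illustration:

```python
import numpy as np
import torch
from torch import nn


def layer_init(
    layer: nn.Module, std: float = np.sqrt(2), bias_const: float = 0.0
) -> nn.Module:
    # orthogonal weights with gain `std`, constant (zero) bias
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer


# the Atari feature net wraps each layer at construction time
feature_net = nn.Sequential(
    layer_init(nn.Conv2d(4, 32, kernel_size=8, stride=4)), nn.ReLU(inplace=True),
    layer_init(nn.Conv2d(32, 64, kernel_size=4, stride=2)), nn.ReLU(inplace=True),
    layer_init(nn.Conv2d(64, 64, kernel_size=3, stride=1)), nn.ReLU(inplace=True),
    nn.Flatten(),
)
```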

Shoutout to [CleanRL](https://github.com/vwxyzjn/cleanrl) for a
well-tuned Atari PPO reference implementation!
Commit 662af52820 (parent 929508ba77), authored by Yi Su on 2022-12-04 and committed via GitHub.
12 changed files with 61 additions and 34 deletions

examples/atari/README.md:

@@ -114,13 +114,13 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M.
 | task | best reward | reward curve | parameters |
 | --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ |
-| PongNoFrameskip-v4 | 20.1 | ![](results/ppo/Pong_rew.png) | `python3 atari_ppo.py --task "PongNoFrameskip-v4"` |
-| BreakoutNoFrameskip-v4 | 438.5 | ![](results/ppo/Breakout_rew.png) | `python3 atari_ppo.py --task "BreakoutNoFrameskip-v4"` |
-| EnduroNoFrameskip-v4 | 1304.8 | ![](results/ppo/Enduro_rew.png) | `python3 atari_ppo.py --task "EnduroNoFrameskip-v4"` |
-| QbertNoFrameskip-v4 | 13640 | ![](results/ppo/Qbert_rew.png) | `python3 atari_ppo.py --task "QbertNoFrameskip-v4"` |
-| MsPacmanNoFrameskip-v4 | 1930 | ![](results/ppo/MsPacman_rew.png) | `python3 atari_ppo.py --task "MsPacmanNoFrameskip-v4"` |
-| SeaquestNoFrameskip-v4 | 904 | ![](results/ppo/Seaquest_rew.png) | `python3 atari_ppo.py --task "SeaquestNoFrameskip-v4" --lr 2.5e-5` |
-| SpaceInvadersNoFrameskip-v4 | 843 | ![](results/ppo/SpaceInvaders_rew.png) | `python3 atari_ppo.py --task "SpaceInvadersNoFrameskip-v4"` |
+| PongNoFrameskip-v4 | 20.2 | ![](results/ppo/Pong_rew.png) | `python3 atari_ppo.py --task "PongNoFrameskip-v4"` |
+| BreakoutNoFrameskip-v4 | 441.8 | ![](results/ppo/Breakout_rew.png) | `python3 atari_ppo.py --task "BreakoutNoFrameskip-v4"` |
+| EnduroNoFrameskip-v4 | 1245.4 | ![](results/ppo/Enduro_rew.png) | `python3 atari_ppo.py --task "EnduroNoFrameskip-v4"` |
+| QbertNoFrameskip-v4 | 17395 | ![](results/ppo/Qbert_rew.png) | `python3 atari_ppo.py --task "QbertNoFrameskip-v4"` |
+| MsPacmanNoFrameskip-v4 | 2098 | ![](results/ppo/MsPacman_rew.png) | `python3 atari_ppo.py --task "MsPacmanNoFrameskip-v4"` |
+| SeaquestNoFrameskip-v4 | 882 | ![](results/ppo/Seaquest_rew.png) | `python3 atari_ppo.py --task "SeaquestNoFrameskip-v4" --lr 1e-4` |
+| SpaceInvadersNoFrameskip-v4 | 1340.5 | ![](results/ppo/SpaceInvaders_rew.png) | `python3 atari_ppo.py --task "SpaceInvadersNoFrameskip-v4"` |

 # SAC (single run)

examples/atari/atari_network.py:

@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union

 import numpy as np
 import torch

@@ -7,6 +7,29 @@ from torch import nn

 from tianshou.utils.net.discrete import NoisyLinear


+def layer_init(
+    layer: nn.Module, std: float = np.sqrt(2), bias_const: float = 0.0
+) -> nn.Module:
+    torch.nn.init.orthogonal_(layer.weight, std)
+    torch.nn.init.constant_(layer.bias, bias_const)
+    return layer
+
+
+def scale_obs(module: Type[nn.Module], denom: float = 255.0) -> Type[nn.Module]:
+
+    class scaled_module(module):
+
+        def forward(
+            self,
+            obs: Union[np.ndarray, torch.Tensor],
+            state: Optional[Any] = None,
+            info: Dict[str, Any] = {}
+        ) -> Tuple[torch.Tensor, Any]:
+            return super().forward(obs / denom, state, info)
+
+    return scaled_module
+
+
 class DQN(nn.Module):
     """Reference: Human-level control through deep reinforcement learning.

@@ -23,26 +46,30 @@ class DQN(nn.Module):
         device: Union[str, int, torch.device] = "cpu",
         features_only: bool = False,
         output_dim: Optional[int] = None,
+        layer_init: Callable[[nn.Module], nn.Module] = lambda x: x,
     ) -> None:
         super().__init__()
         self.device = device
         self.net = nn.Sequential(
-            nn.Conv2d(c, 32, kernel_size=8, stride=4), nn.ReLU(inplace=True),
-            nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(inplace=True),
-            nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(inplace=True),
-            nn.Flatten()
+            layer_init(nn.Conv2d(c, 32, kernel_size=8, stride=4)),
+            nn.ReLU(inplace=True),
+            layer_init(nn.Conv2d(32, 64, kernel_size=4, stride=2)),
+            nn.ReLU(inplace=True),
+            layer_init(nn.Conv2d(64, 64, kernel_size=3, stride=1)),
+            nn.ReLU(inplace=True), nn.Flatten()
         )
         with torch.no_grad():
             self.output_dim = np.prod(self.net(torch.zeros(1, c, h, w)).shape[1:])
         if not features_only:
             self.net = nn.Sequential(
-                self.net, nn.Linear(self.output_dim, 512), nn.ReLU(inplace=True),
-                nn.Linear(512, np.prod(action_shape))
+                self.net, layer_init(nn.Linear(self.output_dim, 512)),
+                nn.ReLU(inplace=True),
+                layer_init(nn.Linear(512, np.prod(action_shape)))
             )
             self.output_dim = np.prod(action_shape)
         elif output_dim is not None:
             self.net = nn.Sequential(
-                self.net, nn.Linear(self.output_dim, output_dim),
+                self.net, layer_init(nn.Linear(self.output_dim, output_dim)),
                 nn.ReLU(inplace=True)
             )
             self.output_dim = output_dim
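
With these helpers, observation scaling moves into the network itself: `scale_obs(DQN)` returns a `DQN` subclass whose `forward` divides `obs` by 255, which is why the training script below now passes `scale=0` to `make_atari_env`. A rough usage sketch (the concrete shapes and the 512-dim feature head are illustrative, assuming 4 stacked 84x84 frames and the script being run from `examples/atari`):

```python
from atari_network import DQN, layer_init, scale_obs

net_cls = scale_obs(DQN)    # DQN subclass that rescales observations to [0, 1]
net = net_cls(
    4, 84, 84,              # c, h, w: 4 stacked 84x84 grayscale frames
    6,                      # action_shape, e.g. 6 discrete actions in Pong
    features_only=True,     # expose the feature head for a shared actor/critic
    output_dim=512,
    layer_init=layer_init,  # orthogonal init for every conv/linear layer
)
```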

examples/atari/atari_ppo.py:

@@ -5,7 +5,7 @@ import pprint

 import numpy as np
 import torch
-from atari_network import DQN
+from atari_network import DQN, layer_init, scale_obs
 from atari_wrapper import make_atari_env
 from torch.optim.lr_scheduler import LambdaLR
 from torch.utils.tensorboard import SummaryWriter
@@ -22,9 +22,9 @@ def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--task", type=str, default="PongNoFrameskip-v4")
     parser.add_argument("--seed", type=int, default=4213)
-    parser.add_argument("--scale-obs", type=int, default=0)
+    parser.add_argument("--scale-obs", type=int, default=1)
     parser.add_argument("--buffer-size", type=int, default=100000)
-    parser.add_argument("--lr", type=float, default=5e-5)
+    parser.add_argument("--lr", type=float, default=2.5e-4)
     parser.add_argument("--gamma", type=float, default=0.99)
     parser.add_argument("--epoch", type=int, default=100)
     parser.add_argument("--step-per-epoch", type=int, default=100000)
@@ -35,14 +35,14 @@ def get_args():
     parser.add_argument("--training-num", type=int, default=10)
     parser.add_argument("--test-num", type=int, default=10)
     parser.add_argument("--rew-norm", type=int, default=False)
-    parser.add_argument("--vf-coef", type=float, default=0.5)
+    parser.add_argument("--vf-coef", type=float, default=0.25)
     parser.add_argument("--ent-coef", type=float, default=0.01)
     parser.add_argument("--gae-lambda", type=float, default=0.95)
     parser.add_argument("--lr-decay", type=int, default=True)
     parser.add_argument("--max-grad-norm", type=float, default=0.5)
-    parser.add_argument("--eps-clip", type=float, default=0.2)
+    parser.add_argument("--eps-clip", type=float, default=0.1)
     parser.add_argument("--dual-clip", type=float, default=None)
-    parser.add_argument("--value-clip", type=int, default=0)
+    parser.add_argument("--value-clip", type=int, default=1)
     parser.add_argument("--norm-adv", type=int, default=1)
     parser.add_argument("--recompute-adv", type=int, default=0)
     parser.add_argument("--logdir", type=str, default="log")
@@ -94,7 +94,7 @@ def test_ppo(args=get_args()):
         args.seed,
         args.training_num,
         args.test_num,
-        scale=args.scale_obs,
+        scale=0,
         frame_stack=args.frames_stack,
     )
     args.state_shape = env.observation_space.shape or env.observation_space.n
@@ -106,16 +106,20 @@ def test_ppo(args=get_args()):
     np.random.seed(args.seed)
     torch.manual_seed(args.seed)
     # define model
-    net = DQN(
+    net_cls = scale_obs(DQN) if args.scale_obs else DQN
+    net = net_cls(
         *args.state_shape,
         args.action_shape,
         device=args.device,
         features_only=True,
-        output_dim=args.hidden_size
+        output_dim=args.hidden_size,
+        layer_init=layer_init,
     )
     actor = Actor(net, args.action_shape, device=args.device, softmax_output=False)
     critic = Critic(net, device=args.device)
-    optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=args.lr)
+    optim = torch.optim.Adam(
+        ActorCritic(actor, critic).parameters(), lr=args.lr, eps=1e-5
+    )
     lr_scheduler = None
     if args.lr_decay:
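
The optimizer now uses Adam with `eps=1e-5` (CleanRL's value) instead of PyTorch's default `1e-8`. For context, the `lr_decay` branch cut off by this hunk builds a linear decay schedule roughly like the sketch below (not part of the diff; `--step-per-collect` is an existing argument of `get_args()` not shown above):

```python
# decay the learning rate linearly to 0 over the total number of policy updates
max_update_num = np.ceil(args.step_per_epoch / args.step_per_collect) * args.epoch
lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
```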

(Seven binary files changed: the reward-curve images under results/ppo/ referenced in the table above were regenerated; image previews are not shown here.)

tianshou/policy/modelfree/ppo.py:

@@ -79,9 +79,6 @@ class PPOPolicy(A2CPolicy):
             "Dual-clip PPO parameter should greater than 1.0."
         self._dual_clip = dual_clip
         self._value_clip = value_clip
-        if not self._rew_norm:
-            assert not self._value_clip, \
-                "value clip is available only when `reward_normalization` is True"
         self._norm_adv = advantage_normalization
         self._recompute_adv = recompute_advantage
         self._actor_critic: ActorCritic

@@ -94,11 +91,8 @@ class PPOPolicy(A2CPolicy):
         self._buffer, self._indices = buffer, indices
         batch = self._compute_returns(batch, buffer, indices)
         batch.act = to_torch_as(batch.act, batch.v_s)
-        old_log_prob = []
         with torch.no_grad():
-            for minibatch in batch.split(self._batch, shuffle=False, merge_last=True):
-                old_log_prob.append(self(minibatch).dist.log_prob(minibatch.act))
-        batch.logp_old = torch.cat(old_log_prob, dim=0)
+            batch.logp_old = self(batch).dist.log_prob(batch.act)
         return batch

     def learn(  # type: ignore

@@ -113,7 +107,8 @@ class PPOPolicy(A2CPolicy):
                 dist = self(minibatch).dist
                 if self._norm_adv:
                     mean, std = minibatch.adv.mean(), minibatch.adv.std()
-                    minibatch.adv = (minibatch.adv - mean) / std  # per-batch norm
+                    minibatch.adv = (minibatch.adv -
+                                     mean) / (std + self._eps)  # per-batch norm
                 ratio = (dist.log_prob(minibatch.act) -
                          minibatch.logp_old).exp().float()
                 ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
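
The `(std + self._eps)` denominator guards against a minibatch whose advantages are all (nearly) identical, where `std` is zero and the old expression produced NaNs. A standalone illustration (not Tianshou code; `1e-8` stands in for `self._eps`):

```python
import torch

adv = torch.full((64,), 0.5)                        # constant advantages -> std == 0
norm_old = (adv - adv.mean()) / adv.std()           # 0 / 0 -> all NaN
norm_new = (adv - adv.mean()) / (adv.std() + 1e-8)  # 0 / 1e-8 -> all 0
print(norm_old.isnan().all().item(), norm_new.isnan().any().item())  # True False
```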

tianshou/trainer/base.py:

@@ -356,7 +356,8 @@ class BaseTrainer(ABC):
             print(
                 f"Epoch #{self.epoch}: test_reward: {rew:.6f} ± {rew_std:.6f},"
                 f" best_reward: {self.best_reward:.6f} ± "
-                f"{self.best_reward_std:.6f} in #{self.best_epoch}"
+                f"{self.best_reward_std:.6f} in #{self.best_epoch}",
+                flush=True
             )
         if not self.is_run:
             test_stat = {