Fix Atari PPO example (#780)
- [x] I have marked all applicable categories:
  + [ ] exception-raising fix
  + [x] algorithm implementation fix
  + [ ] documentation modification
  + [ ] new feature
- [x] I have reformatted the code using `make format` (**required**)
- [x] I have checked the code using `make commit-checks` (**required**)
- [x] If applicable, I have mentioned the relevant/related issue(s)
- [x] If applicable, I have listed every item in this Pull Request below

While trying to debug Atari PPO+LSTM, I found a significant gap between our Atari PPO example and [CleanRL's Atari PPO w/ EnvPool](https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_atari_envpoolpy). I aligned our implementation with CleanRL's version, mostly in the hyperparameter choices, and got significant gains in Breakout, Qbert, and SpaceInvaders while staying on par in the other games. After this fix, I would suggest updating the PPO experiments in our [Atari Benchmark](https://tianshou.readthedocs.io/en/master/tutorials/benchmark.html).

A few interesting findings:

- Layer initialization helps stabilize training and enables the use of larger learning rates; without it, larger learning rates trigger NaN gradients very quickly (see the sketch right after this description).
- ppo.py#L97-L101: this change helps training stability for reasons I do not understand; it also raises GPU utilization.

Shoutout to [CleanRL](https://github.com/vwxyzjn/cleanrl) for a well-tuned Atari PPO reference implementation!
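For readers skimming the diff: the "layer initialization" referred to above is CleanRL's orthogonal scheme, added below as `layer_init` in `atari_network.py`. A minimal sketch of the pattern (same helper as in the diff; the `feature_net` usage line is only illustrative):

```python
import numpy as np
import torch
from torch import nn


def layer_init(
    layer: nn.Module, std: float = np.sqrt(2), bias_const: float = 0.0
) -> nn.Module:
    # Orthogonal weights (gain sqrt(2) for ReLU nets) and zero bias, applied in
    # place; returning the layer lets it be wrapped inline inside nn.Sequential.
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer


# illustrative usage: wrap each Conv/Linear layer at construction time
feature_net = nn.Sequential(
    layer_init(nn.Conv2d(4, 32, kernel_size=8, stride=4)),
    nn.ReLU(inplace=True),
    nn.Flatten(),
)
```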
PPO benchmark table (Atari example README):

```diff
@@ -114,13 +114,13 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M.
 
 | task                        | best reward | reward curve | parameters                                                         |
 | --------------------------- | ----------- | ------------ | ------------------------------------------------------------------ |
-| PongNoFrameskip-v4          | 20.1        |              | `python3 atari_ppo.py --task "PongNoFrameskip-v4"`                  |
-| BreakoutNoFrameskip-v4      | 438.5       |              | `python3 atari_ppo.py --task "BreakoutNoFrameskip-v4"`              |
-| EnduroNoFrameskip-v4        | 1304.8      |              | `python3 atari_ppo.py --task "EnduroNoFrameskip-v4"`                |
-| QbertNoFrameskip-v4         | 13640       |              | `python3 atari_ppo.py --task "QbertNoFrameskip-v4"`                 |
-| MsPacmanNoFrameskip-v4      | 1930        |              | `python3 atari_ppo.py --task "MsPacmanNoFrameskip-v4"`              |
-| SeaquestNoFrameskip-v4      | 904         |              | `python3 atari_ppo.py --task "SeaquestNoFrameskip-v4" --lr 2.5e-5`  |
-| SpaceInvadersNoFrameskip-v4 | 843         |              | `python3 atari_ppo.py --task "SpaceInvadersNoFrameskip-v4"`         |
+| PongNoFrameskip-v4          | 20.2        |              | `python3 atari_ppo.py --task "PongNoFrameskip-v4"`                  |
+| BreakoutNoFrameskip-v4      | 441.8       |              | `python3 atari_ppo.py --task "BreakoutNoFrameskip-v4"`              |
+| EnduroNoFrameskip-v4        | 1245.4      |              | `python3 atari_ppo.py --task "EnduroNoFrameskip-v4"`                |
+| QbertNoFrameskip-v4         | 17395       |              | `python3 atari_ppo.py --task "QbertNoFrameskip-v4"`                 |
+| MsPacmanNoFrameskip-v4      | 2098        |              | `python3 atari_ppo.py --task "MsPacmanNoFrameskip-v4"`              |
+| SeaquestNoFrameskip-v4      | 882         |              | `python3 atari_ppo.py --task "SeaquestNoFrameskip-v4" --lr 1e-4`    |
+| SpaceInvadersNoFrameskip-v4 | 1340.5      |              | `python3 atari_ppo.py --task "SpaceInvadersNoFrameskip-v4"`         |
 
 # SAC (single run)
```
`atari_network.py`:

```diff
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union
 
 import numpy as np
 import torch
@@ -7,6 +7,29 @@ from torch import nn
 from tianshou.utils.net.discrete import NoisyLinear
 
 
+def layer_init(
+    layer: nn.Module, std: float = np.sqrt(2), bias_const: float = 0.0
+) -> nn.Module:
+    torch.nn.init.orthogonal_(layer.weight, std)
+    torch.nn.init.constant_(layer.bias, bias_const)
+    return layer
+
+
+def scale_obs(module: Type[nn.Module], denom: float = 255.0) -> Type[nn.Module]:
+
+    class scaled_module(module):
+
+        def forward(
+            self,
+            obs: Union[np.ndarray, torch.Tensor],
+            state: Optional[Any] = None,
+            info: Dict[str, Any] = {}
+        ) -> Tuple[torch.Tensor, Any]:
+            return super().forward(obs / denom, state, info)
+
+    return scaled_module
+
+
 class DQN(nn.Module):
     """Reference: Human-level control through deep reinforcement learning.
 
@@ -23,26 +46,30 @@ class DQN(nn.Module):
         device: Union[str, int, torch.device] = "cpu",
         features_only: bool = False,
         output_dim: Optional[int] = None,
+        layer_init: Callable[[nn.Module], nn.Module] = lambda x: x,
     ) -> None:
         super().__init__()
         self.device = device
         self.net = nn.Sequential(
-            nn.Conv2d(c, 32, kernel_size=8, stride=4), nn.ReLU(inplace=True),
-            nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(inplace=True),
-            nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(inplace=True),
-            nn.Flatten()
+            layer_init(nn.Conv2d(c, 32, kernel_size=8, stride=4)),
+            nn.ReLU(inplace=True),
+            layer_init(nn.Conv2d(32, 64, kernel_size=4, stride=2)),
+            nn.ReLU(inplace=True),
+            layer_init(nn.Conv2d(64, 64, kernel_size=3, stride=1)),
+            nn.ReLU(inplace=True), nn.Flatten()
         )
         with torch.no_grad():
             self.output_dim = np.prod(self.net(torch.zeros(1, c, h, w)).shape[1:])
         if not features_only:
             self.net = nn.Sequential(
-                self.net, nn.Linear(self.output_dim, 512), nn.ReLU(inplace=True),
-                nn.Linear(512, np.prod(action_shape))
+                self.net, layer_init(nn.Linear(self.output_dim, 512)),
+                nn.ReLU(inplace=True),
+                layer_init(nn.Linear(512, np.prod(action_shape)))
             )
             self.output_dim = np.prod(action_shape)
         elif output_dim is not None:
             self.net = nn.Sequential(
-                self.net, nn.Linear(self.output_dim, output_dim),
+                self.net, layer_init(nn.Linear(self.output_dim, output_dim)),
                 nn.ReLU(inplace=True)
             )
             self.output_dim = output_dim
```
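How the two new helpers combine is wired up concretely in the `atari_ppo.py` hunk below; as a quick orientation, a hedged usage sketch (the frame shape and action count here are illustrative, not taken from the diff):

```python
from atari_network import DQN, layer_init, scale_obs

# scale_obs(DQN) returns a DQN subclass whose forward() divides raw uint8
# frames by 255 before delegating to DQN.forward().
net_cls = scale_obs(DQN)
net = net_cls(
    4, 84, 84,              # stacked frames c, h, w (illustrative Atari shape)
    6,                      # action count, e.g. Pong (illustrative)
    features_only=True,
    output_dim=512,
    layer_init=layer_init,  # orthogonal init applied to every Conv/Linear layer
)
```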
`atari_ppo.py`:

```diff
@@ -5,7 +5,7 @@ import pprint
 
 import numpy as np
 import torch
-from atari_network import DQN
+from atari_network import DQN, layer_init, scale_obs
 from atari_wrapper import make_atari_env
 from torch.optim.lr_scheduler import LambdaLR
 from torch.utils.tensorboard import SummaryWriter
@@ -22,9 +22,9 @@ def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--task", type=str, default="PongNoFrameskip-v4")
     parser.add_argument("--seed", type=int, default=4213)
-    parser.add_argument("--scale-obs", type=int, default=0)
+    parser.add_argument("--scale-obs", type=int, default=1)
     parser.add_argument("--buffer-size", type=int, default=100000)
-    parser.add_argument("--lr", type=float, default=5e-5)
+    parser.add_argument("--lr", type=float, default=2.5e-4)
     parser.add_argument("--gamma", type=float, default=0.99)
     parser.add_argument("--epoch", type=int, default=100)
     parser.add_argument("--step-per-epoch", type=int, default=100000)
@@ -35,14 +35,14 @@ def get_args():
     parser.add_argument("--training-num", type=int, default=10)
     parser.add_argument("--test-num", type=int, default=10)
     parser.add_argument("--rew-norm", type=int, default=False)
-    parser.add_argument("--vf-coef", type=float, default=0.5)
+    parser.add_argument("--vf-coef", type=float, default=0.25)
     parser.add_argument("--ent-coef", type=float, default=0.01)
     parser.add_argument("--gae-lambda", type=float, default=0.95)
     parser.add_argument("--lr-decay", type=int, default=True)
     parser.add_argument("--max-grad-norm", type=float, default=0.5)
-    parser.add_argument("--eps-clip", type=float, default=0.2)
+    parser.add_argument("--eps-clip", type=float, default=0.1)
     parser.add_argument("--dual-clip", type=float, default=None)
-    parser.add_argument("--value-clip", type=int, default=0)
+    parser.add_argument("--value-clip", type=int, default=1)
     parser.add_argument("--norm-adv", type=int, default=1)
     parser.add_argument("--recompute-adv", type=int, default=0)
     parser.add_argument("--logdir", type=str, default="log")
@@ -94,7 +94,7 @@ def test_ppo(args=get_args()):
         args.seed,
         args.training_num,
         args.test_num,
-        scale=args.scale_obs,
+        scale=0,
         frame_stack=args.frames_stack,
     )
     args.state_shape = env.observation_space.shape or env.observation_space.n
@@ -106,16 +106,20 @@ def test_ppo(args=get_args()):
     np.random.seed(args.seed)
     torch.manual_seed(args.seed)
     # define model
-    net = DQN(
+    net_cls = scale_obs(DQN) if args.scale_obs else DQN
+    net = net_cls(
         *args.state_shape,
         args.action_shape,
         device=args.device,
         features_only=True,
-        output_dim=args.hidden_size
+        output_dim=args.hidden_size,
+        layer_init=layer_init,
     )
     actor = Actor(net, args.action_shape, device=args.device, softmax_output=False)
     critic = Critic(net, device=args.device)
-    optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=args.lr)
+    optim = torch.optim.Adam(
+        ActorCritic(actor, critic).parameters(), lr=args.lr, eps=1e-5
+    )
 
     lr_scheduler = None
     if args.lr_decay:
```
(Seven reward-curve images are also replaced in this PR; the text diff records only their changed file sizes, e.g. 142 KiB → 189 KiB.)
`ppo.py` (`PPOPolicy`):

```diff
@@ -79,9 +79,6 @@ class PPOPolicy(A2CPolicy):
             "Dual-clip PPO parameter should greater than 1.0."
         self._dual_clip = dual_clip
         self._value_clip = value_clip
-        if not self._rew_norm:
-            assert not self._value_clip, \
-                "value clip is available only when `reward_normalization` is True"
         self._norm_adv = advantage_normalization
         self._recompute_adv = recompute_advantage
         self._actor_critic: ActorCritic
@@ -94,11 +91,8 @@ class PPOPolicy(A2CPolicy):
         self._buffer, self._indices = buffer, indices
         batch = self._compute_returns(batch, buffer, indices)
         batch.act = to_torch_as(batch.act, batch.v_s)
-        old_log_prob = []
         with torch.no_grad():
-            for minibatch in batch.split(self._batch, shuffle=False, merge_last=True):
-                old_log_prob.append(self(minibatch).dist.log_prob(minibatch.act))
-        batch.logp_old = torch.cat(old_log_prob, dim=0)
+            batch.logp_old = self(batch).dist.log_prob(batch.act)
         return batch
 
     def learn(  # type: ignore
@@ -113,7 +107,8 @@ class PPOPolicy(A2CPolicy):
                 dist = self(minibatch).dist
                 if self._norm_adv:
                     mean, std = minibatch.adv.mean(), minibatch.adv.std()
-                    minibatch.adv = (minibatch.adv - mean) / std  # per-batch norm
+                    minibatch.adv = (minibatch.adv -
+                                     mean) / (std + self._eps)  # per-batch norm
                 ratio = (dist.log_prob(minibatch.act) -
                          minibatch.logp_old).exp().float()
                 ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
```
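A note on the `(std + self._eps)` change in the last hunk above: if a minibatch happens to contain (near-)constant advantages, their standard deviation is zero and the unregularized division produces NaNs, which then propagate through the ratio term and the gradients. A tiny, self-contained illustration of the failure mode (plain PyTorch, hypothetical values):

```python
import torch

adv = torch.full((5,), 0.3)         # degenerate minibatch: identical advantages
mean, std = adv.mean(), adv.std()   # std == 0 here

print((adv - mean) / std)           # tensor([nan, nan, nan, nan, nan])
print((adv - mean) / (std + 1e-8))  # tensor([0., 0., 0., 0., 0.]) -- finite
```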
Trainer (`BaseTrainer`):

```diff
@@ -356,7 +356,8 @@ class BaseTrainer(ABC):
             print(
                 f"Epoch #{self.epoch}: test_reward: {rew:.6f} ± {rew_std:.6f},"
                 f" best_reward: {self.best_reward:.6f} ± "
-                f"{self.best_reward_std:.6f} in #{self.best_epoch}"
+                f"{self.best_reward_std:.6f} in #{self.best_epoch}",
+                flush=True
             )
             if not self.is_run:
                 test_stat = {
```