Fix Atari PPO example (#780)
- [x] I have marked all applicable categories:
  + [ ] exception-raising fix
  + [x] algorithm implementation fix
  + [ ] documentation modification
  + [ ] new feature
- [x] I have reformatted the code using `make format` (**required**)
- [x] I have checked the code using `make commit-checks` (**required**)
- [x] If applicable, I have mentioned the relevant/related issue(s)
- [x] If applicable, I have listed every item in this Pull Request below

While debugging Atari PPO+LSTM, I found a significant gap between our Atari PPO example and [CleanRL's Atari PPO w/ EnvPool](https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_atari_envpoolpy). I aligned our implementation with CleanRL's version, mostly in the hyperparameter choices, and got significant gains on Breakout, Qbert, and SpaceInvaders while staying on par in the other games. After this fix, I suggest rerunning the PPO experiments in our [Atari Benchmark](https://tianshou.readthedocs.io/en/master/tutorials/benchmark.html).

A few interesting findings:

- Layer initialization helps stabilize training and enables the use of larger learning rates; without it, larger learning rates trigger NaN gradients very quickly (a small sketch of this initialization follows below).
- ppo.py#L97-L101: this change improves training stability for reasons I do not understand; it also raises GPU utilization.

Shoutout to [CleanRL](https://github.com/vwxyzjn/cleanrl) for a well-tuned Atari PPO reference implementation!
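As a quick reference for the point about layer initialization: the helper added in `atari_network.py` (full diff below) is the CleanRL-style orthogonal init. The sketch restates it in standalone form; the example `Conv2d`/`Linear` layers and the smaller `std` for the output head are illustrative choices, not part of this PR.

```python
import numpy as np
import torch
from torch import nn


def layer_init(
    layer: nn.Module, std: float = np.sqrt(2), bias_const: float = 0.0
) -> nn.Module:
    # Orthogonal weights with gain `std`, constant bias; returns the layer
    # so it can be wrapped inline when building nn.Sequential blocks.
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer


# Illustrative usage (layer shapes are arbitrary here):
conv = layer_init(nn.Conv2d(4, 32, kernel_size=8, stride=4))
head = layer_init(nn.Linear(512, 6), std=0.01)  # a smaller gain is often used for output heads
```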
@@ -114,13 +114,13 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M.

 | task                        | best reward | reward curve | parameters                                                         |
 | --------------------------- | ----------- | ------------ | ------------------------------------------------------------------ |
-| PongNoFrameskip-v4          | 20.1        |              | `python3 atari_ppo.py --task "PongNoFrameskip-v4"`                  |
+| PongNoFrameskip-v4          | 20.2        |              | `python3 atari_ppo.py --task "PongNoFrameskip-v4"`                  |
-| BreakoutNoFrameskip-v4      | 438.5       |              | `python3 atari_ppo.py --task "BreakoutNoFrameskip-v4"`              |
+| BreakoutNoFrameskip-v4      | 441.8       |              | `python3 atari_ppo.py --task "BreakoutNoFrameskip-v4"`              |
-| EnduroNoFrameskip-v4        | 1304.8      |              | `python3 atari_ppo.py --task "EnduroNoFrameskip-v4"`                |
+| EnduroNoFrameskip-v4        | 1245.4      |              | `python3 atari_ppo.py --task "EnduroNoFrameskip-v4"`                |
-| QbertNoFrameskip-v4         | 13640       |              | `python3 atari_ppo.py --task "QbertNoFrameskip-v4"`                 |
+| QbertNoFrameskip-v4         | 17395       |              | `python3 atari_ppo.py --task "QbertNoFrameskip-v4"`                 |
-| MsPacmanNoFrameskip-v4      | 1930        |              | `python3 atari_ppo.py --task "MsPacmanNoFrameskip-v4"`              |
+| MsPacmanNoFrameskip-v4      | 2098        |              | `python3 atari_ppo.py --task "MsPacmanNoFrameskip-v4"`              |
-| SeaquestNoFrameskip-v4      | 904         |              | `python3 atari_ppo.py --task "SeaquestNoFrameskip-v4" --lr 2.5e-5`  |
+| SeaquestNoFrameskip-v4      | 882         |              | `python3 atari_ppo.py --task "SeaquestNoFrameskip-v4" --lr 1e-4`    |
-| SpaceInvadersNoFrameskip-v4 | 843         |              | `python3 atari_ppo.py --task "SpaceInvadersNoFrameskip-v4"`         |
+| SpaceInvadersNoFrameskip-v4 | 1340.5      |              | `python3 atari_ppo.py --task "SpaceInvadersNoFrameskip-v4"`         |

 # SAC (single run)
atari_ppo.py:

@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union

 import numpy as np
 import torch
@@ -7,6 +7,29 @@ from torch import nn
 from tianshou.utils.net.discrete import NoisyLinear


+def layer_init(
+    layer: nn.Module, std: float = np.sqrt(2), bias_const: float = 0.0
+) -> nn.Module:
+    torch.nn.init.orthogonal_(layer.weight, std)
+    torch.nn.init.constant_(layer.bias, bias_const)
+    return layer
+
+
+def scale_obs(module: Type[nn.Module], denom: float = 255.0) -> Type[nn.Module]:
+
+    class scaled_module(module):
+
+        def forward(
+            self,
+            obs: Union[np.ndarray, torch.Tensor],
+            state: Optional[Any] = None,
+            info: Dict[str, Any] = {}
+        ) -> Tuple[torch.Tensor, Any]:
+            return super().forward(obs / denom, state, info)
+
+    return scaled_module
+
+
 class DQN(nn.Module):
     """Reference: Human-level control through deep reinforcement learning.
@@ -23,26 +46,30 @@ class DQN(nn.Module):
         device: Union[str, int, torch.device] = "cpu",
         features_only: bool = False,
         output_dim: Optional[int] = None,
+        layer_init: Callable[[nn.Module], nn.Module] = lambda x: x,
     ) -> None:
         super().__init__()
         self.device = device
         self.net = nn.Sequential(
-            nn.Conv2d(c, 32, kernel_size=8, stride=4), nn.ReLU(inplace=True),
-            nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(inplace=True),
-            nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(inplace=True),
-            nn.Flatten()
+            layer_init(nn.Conv2d(c, 32, kernel_size=8, stride=4)),
+            nn.ReLU(inplace=True),
+            layer_init(nn.Conv2d(32, 64, kernel_size=4, stride=2)),
+            nn.ReLU(inplace=True),
+            layer_init(nn.Conv2d(64, 64, kernel_size=3, stride=1)),
+            nn.ReLU(inplace=True), nn.Flatten()
         )
         with torch.no_grad():
            self.output_dim = np.prod(self.net(torch.zeros(1, c, h, w)).shape[1:])
         if not features_only:
             self.net = nn.Sequential(
-                self.net, nn.Linear(self.output_dim, 512), nn.ReLU(inplace=True),
-                nn.Linear(512, np.prod(action_shape))
+                self.net, layer_init(nn.Linear(self.output_dim, 512)),
+                nn.ReLU(inplace=True),
+                layer_init(nn.Linear(512, np.prod(action_shape)))
             )
             self.output_dim = np.prod(action_shape)
         elif output_dim is not None:
             self.net = nn.Sequential(
-                self.net, nn.Linear(self.output_dim, output_dim),
+                self.net, layer_init(nn.Linear(self.output_dim, output_dim)),
                 nn.ReLU(inplace=True)
             )
             self.output_dim = output_dim
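A usage sketch for the `scale_obs`/`DQN` pieces above, assuming both are imported from the patched `atari_network.py`; the 4×84×84 frame stack and the 6-action space are illustrative values, not something this PR changes.

```python
import numpy as np

from atari_network import DQN, scale_obs

# scale_obs(DQN) returns a DQN subclass whose forward() divides obs by 255.0,
# so uint8 frames can stay unscaled in the replay buffer.
ScaledDQN = scale_obs(DQN)
net = ScaledDQN(4, 84, 84, action_shape=[6], features_only=True, output_dim=512)

obs = np.random.randint(0, 256, size=(1, 4, 84, 84), dtype=np.uint8)
features, state = net(obs)   # the network sees obs / 255.0 internally
print(features.shape)        # torch.Size([1, 512])
```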
atari_ppo.py:

@@ -5,7 +5,7 @@ import pprint

 import numpy as np
 import torch
-from atari_network import DQN
+from atari_network import DQN, layer_init, scale_obs
 from atari_wrapper import make_atari_env
 from torch.optim.lr_scheduler import LambdaLR
 from torch.utils.tensorboard import SummaryWriter
@@ -22,9 +22,9 @@ def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--task", type=str, default="PongNoFrameskip-v4")
     parser.add_argument("--seed", type=int, default=4213)
-    parser.add_argument("--scale-obs", type=int, default=0)
+    parser.add_argument("--scale-obs", type=int, default=1)
     parser.add_argument("--buffer-size", type=int, default=100000)
-    parser.add_argument("--lr", type=float, default=5e-5)
+    parser.add_argument("--lr", type=float, default=2.5e-4)
     parser.add_argument("--gamma", type=float, default=0.99)
     parser.add_argument("--epoch", type=int, default=100)
     parser.add_argument("--step-per-epoch", type=int, default=100000)
@@ -35,14 +35,14 @@ def get_args():
     parser.add_argument("--training-num", type=int, default=10)
     parser.add_argument("--test-num", type=int, default=10)
     parser.add_argument("--rew-norm", type=int, default=False)
-    parser.add_argument("--vf-coef", type=float, default=0.5)
+    parser.add_argument("--vf-coef", type=float, default=0.25)
     parser.add_argument("--ent-coef", type=float, default=0.01)
     parser.add_argument("--gae-lambda", type=float, default=0.95)
     parser.add_argument("--lr-decay", type=int, default=True)
     parser.add_argument("--max-grad-norm", type=float, default=0.5)
-    parser.add_argument("--eps-clip", type=float, default=0.2)
+    parser.add_argument("--eps-clip", type=float, default=0.1)
     parser.add_argument("--dual-clip", type=float, default=None)
-    parser.add_argument("--value-clip", type=int, default=0)
+    parser.add_argument("--value-clip", type=int, default=1)
     parser.add_argument("--norm-adv", type=int, default=1)
     parser.add_argument("--recompute-adv", type=int, default=0)
     parser.add_argument("--logdir", type=str, default="log")
@@ -94,7 +94,7 @@ def test_ppo(args=get_args()):
         args.seed,
         args.training_num,
         args.test_num,
-        scale=args.scale_obs,
+        scale=0,
         frame_stack=args.frames_stack,
     )
     args.state_shape = env.observation_space.shape or env.observation_space.n
@@ -106,16 +106,20 @@ def test_ppo(args=get_args()):
     np.random.seed(args.seed)
     torch.manual_seed(args.seed)
     # define model
-    net = DQN(
+    net_cls = scale_obs(DQN) if args.scale_obs else DQN
+    net = net_cls(
         *args.state_shape,
         args.action_shape,
         device=args.device,
         features_only=True,
-        output_dim=args.hidden_size
+        output_dim=args.hidden_size,
+        layer_init=layer_init,
     )
     actor = Actor(net, args.action_shape, device=args.device, softmax_output=False)
     critic = Critic(net, device=args.device)
-    optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=args.lr)
+    optim = torch.optim.Adam(
+        ActorCritic(actor, critic).parameters(), lr=args.lr, eps=1e-5
+    )

     lr_scheduler = None
     if args.lr_decay:
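For completeness, `--lr-decay` (unchanged here) relies on the `LambdaLR` import shown above to anneal the learning rate linearly to zero over all policy updates. The sketch below only illustrates that idea; the dummy model and the numeric values stand in for the real `ActorCritic` parameters and `args`, and the exact expression in `atari_ppo.py` may differ.

```python
import numpy as np
import torch
from torch.optim.lr_scheduler import LambdaLR

# Illustrative values standing in for args.step_per_epoch, args.step_per_collect, args.epoch.
step_per_epoch, step_per_collect, epoch = 100_000, 1_000, 100
model = torch.nn.Linear(512, 6)  # placeholder for ActorCritic(actor, critic)

optim = torch.optim.Adam(model.parameters(), lr=2.5e-4, eps=1e-5)

# One policy update per collect step, repeated for `epoch` epochs.
max_update_num = np.ceil(step_per_epoch / step_per_collect) * epoch

# LambdaLR multiplies the base lr by the returned factor at each scheduler step,
# so the lr decays linearly from its initial value to 0 over training.
lr_scheduler = LambdaLR(optim, lr_lambda=lambda n: 1 - n / max_update_num)
```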
|
Before ![]() (image error) Size: 142 KiB After ![]() (image error) Size: 189 KiB ![]() ![]() |
Before ![]() (image error) Size: 146 KiB After ![]() (image error) Size: 173 KiB ![]() ![]() |
Before ![]() (image error) Size: 152 KiB After ![]() (image error) Size: 223 KiB ![]() ![]() |
Before ![]() (image error) Size: 115 KiB After ![]() (image error) Size: 111 KiB ![]() ![]() |
Before ![]() (image error) Size: 145 KiB After ![]() (image error) Size: 196 KiB ![]() ![]() |
Before ![]() (image error) Size: 137 KiB After ![]() (image error) Size: 154 KiB ![]() ![]() |
Before ![]() (image error) Size: 159 KiB After ![]() (image error) Size: 181 KiB ![]() ![]() |
ppo.py:

@@ -79,9 +79,6 @@ class PPOPolicy(A2CPolicy):
             "Dual-clip PPO parameter should greater than 1.0."
         self._dual_clip = dual_clip
         self._value_clip = value_clip
-        if not self._rew_norm:
-            assert not self._value_clip, \
-                "value clip is available only when `reward_normalization` is True"
         self._norm_adv = advantage_normalization
         self._recompute_adv = recompute_advantage
         self._actor_critic: ActorCritic
@@ -94,11 +91,8 @@ class PPOPolicy(A2CPolicy):
         self._buffer, self._indices = buffer, indices
         batch = self._compute_returns(batch, buffer, indices)
         batch.act = to_torch_as(batch.act, batch.v_s)
-        old_log_prob = []
         with torch.no_grad():
-            for minibatch in batch.split(self._batch, shuffle=False, merge_last=True):
-                old_log_prob.append(self(minibatch).dist.log_prob(minibatch.act))
-        batch.logp_old = torch.cat(old_log_prob, dim=0)
+            batch.logp_old = self(batch).dist.log_prob(batch.act)
         return batch

     def learn(  # type: ignore
@@ -113,7 +107,8 @@ class PPOPolicy(A2CPolicy):
                 dist = self(minibatch).dist
                 if self._norm_adv:
                     mean, std = minibatch.adv.mean(), minibatch.adv.std()
-                    minibatch.adv = (minibatch.adv - mean) / std  # per-batch norm
+                    minibatch.adv = (minibatch.adv -
+                                     mean) / (std + self._eps)  # per-batch norm
                 ratio = (dist.log_prob(minibatch.act) -
                          minibatch.logp_old).exp().float()
                 ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
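A tiny standalone illustration of why the added `+ self._eps` in the per-minibatch advantage normalization matters (the `1e-8` below is an assumed value for `self._eps`): when a minibatch happens to have near-constant advantages, the standard deviation is zero and the unregularized division produces NaNs that poison the update.

```python
import torch

adv = torch.full((64,), 0.5)          # degenerate minibatch: constant advantages
mean, std = adv.mean(), adv.std()     # std == 0.0
eps = 1e-8                            # assumed value of the policy's self._eps

bad = (adv - mean) / std              # 0 / 0 -> NaN everywhere
good = (adv - mean) / (std + eps)     # finite (all zeros), safe to optimize on
print(bad[0].item(), good[0].item())  # nan 0.0
```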
@@ -356,7 +356,8 @@ class BaseTrainer(ABC):
            print(
                f"Epoch #{self.epoch}: test_reward: {rew:.6f} ± {rew_std:.6f},"
                f" best_reward: {self.best_reward:.6f} ± "
-                f"{self.best_reward_std:.6f} in #{self.best_epoch}"
+                f"{self.best_reward_std:.6f} in #{self.best_epoch}",
+                flush=True
            )
        if not self.is_run:
            test_stat = {