diff --git a/examples/atari/README.md b/examples/atari/README.md
index 27b170c..62e5848 100644
--- a/examples/atari/README.md
+++ b/examples/atari/README.md
@@ -114,13 +114,13 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M.
 
 | task | best reward | reward curve | parameters |
 | --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ |
-| PongNoFrameskip-v4 | 20.1 | ![](results/ppo/Pong_rew.png) | `python3 atari_ppo.py --task "PongNoFrameskip-v4"` |
-| BreakoutNoFrameskip-v4 | 438.5 | ![](results/ppo/Breakout_rew.png) | `python3 atari_ppo.py --task "BreakoutNoFrameskip-v4"` |
-| EnduroNoFrameskip-v4 | 1304.8 | ![](results/ppo/Enduro_rew.png) | `python3 atari_ppo.py --task "EnduroNoFrameskip-v4"` |
-| QbertNoFrameskip-v4 | 13640 | ![](results/ppo/Qbert_rew.png) | `python3 atari_ppo.py --task "QbertNoFrameskip-v4"` |
-| MsPacmanNoFrameskip-v4 | 1930 | ![](results/ppo/MsPacman_rew.png) | `python3 atari_ppo.py --task "MsPacmanNoFrameskip-v4"` |
-| SeaquestNoFrameskip-v4 | 904 | ![](results/ppo/Seaquest_rew.png) | `python3 atari_ppo.py --task "SeaquestNoFrameskip-v4" --lr 2.5e-5` |
-| SpaceInvadersNoFrameskip-v4 | 843 | ![](results/ppo/SpaceInvaders_rew.png) | `python3 atari_ppo.py --task "SpaceInvadersNoFrameskip-v4"` |
+| PongNoFrameskip-v4 | 20.2 | ![](results/ppo/Pong_rew.png) | `python3 atari_ppo.py --task "PongNoFrameskip-v4"` |
+| BreakoutNoFrameskip-v4 | 441.8 | ![](results/ppo/Breakout_rew.png) | `python3 atari_ppo.py --task "BreakoutNoFrameskip-v4"` |
+| EnduroNoFrameskip-v4 | 1245.4 | ![](results/ppo/Enduro_rew.png) | `python3 atari_ppo.py --task "EnduroNoFrameskip-v4"` |
+| QbertNoFrameskip-v4 | 17395 | ![](results/ppo/Qbert_rew.png) | `python3 atari_ppo.py --task "QbertNoFrameskip-v4"` |
+| MsPacmanNoFrameskip-v4 | 2098 | ![](results/ppo/MsPacman_rew.png) | `python3 atari_ppo.py --task "MsPacmanNoFrameskip-v4"` |
+| SeaquestNoFrameskip-v4 | 882 | ![](results/ppo/Seaquest_rew.png) | `python3 atari_ppo.py --task "SeaquestNoFrameskip-v4" --lr 1e-4` |
+| SpaceInvadersNoFrameskip-v4 | 1340.5 | ![](results/ppo/SpaceInvaders_rew.png) | `python3 atari_ppo.py --task "SpaceInvadersNoFrameskip-v4"` |
 
 # SAC (single run)
 
diff --git a/examples/atari/atari_network.py b/examples/atari/atari_network.py
index ec705e7..293753d 100644
--- a/examples/atari/atari_network.py
+++ b/examples/atari/atari_network.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union
 
 import numpy as np
 import torch
@@ -7,6 +7,29 @@ from torch import nn
 from tianshou.utils.net.discrete import NoisyLinear
 
 
+def layer_init(
+    layer: nn.Module, std: float = np.sqrt(2), bias_const: float = 0.0
+) -> nn.Module:
+    torch.nn.init.orthogonal_(layer.weight, std)
+    torch.nn.init.constant_(layer.bias, bias_const)
+    return layer
+
+
+def scale_obs(module: Type[nn.Module], denom: float = 255.0) -> Type[nn.Module]:
+
+    class scaled_module(module):
+
+        def forward(
+            self,
+            obs: Union[np.ndarray, torch.Tensor],
+            state: Optional[Any] = None,
+            info: Dict[str, Any] = {}
+        ) -> Tuple[torch.Tensor, Any]:
+            return super().forward(obs / denom, state, info)
+
+    return scaled_module
+
+
 class DQN(nn.Module):
     """Reference: Human-level control through deep reinforcement learning.
@@ -23,26 +46,30 @@ class DQN(nn.Module):
         device: Union[str, int, torch.device] = "cpu",
         features_only: bool = False,
         output_dim: Optional[int] = None,
+        layer_init: Callable[[nn.Module], nn.Module] = lambda x: x,
     ) -> None:
         super().__init__()
         self.device = device
         self.net = nn.Sequential(
-            nn.Conv2d(c, 32, kernel_size=8, stride=4), nn.ReLU(inplace=True),
-            nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(inplace=True),
-            nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(inplace=True),
-            nn.Flatten()
+            layer_init(nn.Conv2d(c, 32, kernel_size=8, stride=4)),
+            nn.ReLU(inplace=True),
+            layer_init(nn.Conv2d(32, 64, kernel_size=4, stride=2)),
+            nn.ReLU(inplace=True),
+            layer_init(nn.Conv2d(64, 64, kernel_size=3, stride=1)),
+            nn.ReLU(inplace=True), nn.Flatten()
         )
         with torch.no_grad():
             self.output_dim = np.prod(self.net(torch.zeros(1, c, h, w)).shape[1:])
         if not features_only:
             self.net = nn.Sequential(
-                self.net, nn.Linear(self.output_dim, 512), nn.ReLU(inplace=True),
-                nn.Linear(512, np.prod(action_shape))
+                self.net, layer_init(nn.Linear(self.output_dim, 512)),
+                nn.ReLU(inplace=True),
+                layer_init(nn.Linear(512, np.prod(action_shape)))
             )
             self.output_dim = np.prod(action_shape)
         elif output_dim is not None:
             self.net = nn.Sequential(
-                self.net, nn.Linear(self.output_dim, output_dim),
+                self.net, layer_init(nn.Linear(self.output_dim, output_dim)),
                 nn.ReLU(inplace=True)
             )
             self.output_dim = output_dim
diff --git a/examples/atari/atari_ppo.py b/examples/atari/atari_ppo.py
index a9600b8..52744f5 100644
--- a/examples/atari/atari_ppo.py
+++ b/examples/atari/atari_ppo.py
@@ -5,7 +5,7 @@ import pprint
 
 import numpy as np
 import torch
-from atari_network import DQN
+from atari_network import DQN, layer_init, scale_obs
 from atari_wrapper import make_atari_env
 from torch.optim.lr_scheduler import LambdaLR
 from torch.utils.tensorboard import SummaryWriter
@@ -22,9 +22,9 @@ def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--task", type=str, default="PongNoFrameskip-v4")
     parser.add_argument("--seed", type=int, default=4213)
-    parser.add_argument("--scale-obs", type=int, default=0)
+    parser.add_argument("--scale-obs", type=int, default=1)
     parser.add_argument("--buffer-size", type=int, default=100000)
-    parser.add_argument("--lr", type=float, default=5e-5)
+    parser.add_argument("--lr", type=float, default=2.5e-4)
     parser.add_argument("--gamma", type=float, default=0.99)
     parser.add_argument("--epoch", type=int, default=100)
     parser.add_argument("--step-per-epoch", type=int, default=100000)
@@ -35,14 +35,14 @@ def get_args():
     parser.add_argument("--training-num", type=int, default=10)
     parser.add_argument("--test-num", type=int, default=10)
     parser.add_argument("--rew-norm", type=int, default=False)
-    parser.add_argument("--vf-coef", type=float, default=0.5)
+    parser.add_argument("--vf-coef", type=float, default=0.25)
     parser.add_argument("--ent-coef", type=float, default=0.01)
     parser.add_argument("--gae-lambda", type=float, default=0.95)
     parser.add_argument("--lr-decay", type=int, default=True)
     parser.add_argument("--max-grad-norm", type=float, default=0.5)
-    parser.add_argument("--eps-clip", type=float, default=0.2)
+    parser.add_argument("--eps-clip", type=float, default=0.1)
     parser.add_argument("--dual-clip", type=float, default=None)
-    parser.add_argument("--value-clip", type=int, default=0)
+    parser.add_argument("--value-clip", type=int, default=1)
     parser.add_argument("--norm-adv", type=int, default=1)
     parser.add_argument("--recompute-adv", type=int, default=0)
parser.add_argument("--logdir", type=str, default="log") @@ -94,7 +94,7 @@ def test_ppo(args=get_args()): args.seed, args.training_num, args.test_num, - scale=args.scale_obs, + scale=0, frame_stack=args.frames_stack, ) args.state_shape = env.observation_space.shape or env.observation_space.n @@ -106,16 +106,20 @@ def test_ppo(args=get_args()): np.random.seed(args.seed) torch.manual_seed(args.seed) # define model - net = DQN( + net_cls = scale_obs(DQN) if args.scale_obs else DQN + net = net_cls( *args.state_shape, args.action_shape, device=args.device, features_only=True, - output_dim=args.hidden_size + output_dim=args.hidden_size, + layer_init=layer_init, ) actor = Actor(net, args.action_shape, device=args.device, softmax_output=False) critic = Critic(net, device=args.device) - optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=args.lr) + optim = torch.optim.Adam( + ActorCritic(actor, critic).parameters(), lr=args.lr, eps=1e-5 + ) lr_scheduler = None if args.lr_decay: diff --git a/examples/atari/results/ppo/Breakout_rew.png b/examples/atari/results/ppo/Breakout_rew.png index 296bb37..79a6cca 100644 Binary files a/examples/atari/results/ppo/Breakout_rew.png and b/examples/atari/results/ppo/Breakout_rew.png differ diff --git a/examples/atari/results/ppo/Enduro_rew.png b/examples/atari/results/ppo/Enduro_rew.png index b445ba0..621527a 100644 Binary files a/examples/atari/results/ppo/Enduro_rew.png and b/examples/atari/results/ppo/Enduro_rew.png differ diff --git a/examples/atari/results/ppo/MsPacman_rew.png b/examples/atari/results/ppo/MsPacman_rew.png index c16089d..a5cc051 100644 Binary files a/examples/atari/results/ppo/MsPacman_rew.png and b/examples/atari/results/ppo/MsPacman_rew.png differ diff --git a/examples/atari/results/ppo/Pong_rew.png b/examples/atari/results/ppo/Pong_rew.png index 62d05b2..3d6523b 100644 Binary files a/examples/atari/results/ppo/Pong_rew.png and b/examples/atari/results/ppo/Pong_rew.png differ diff --git a/examples/atari/results/ppo/Qbert_rew.png b/examples/atari/results/ppo/Qbert_rew.png index 8db8b67..db0bfd2 100644 Binary files a/examples/atari/results/ppo/Qbert_rew.png and b/examples/atari/results/ppo/Qbert_rew.png differ diff --git a/examples/atari/results/ppo/Seaquest_rew.png b/examples/atari/results/ppo/Seaquest_rew.png index 200a68e..4896a9c 100644 Binary files a/examples/atari/results/ppo/Seaquest_rew.png and b/examples/atari/results/ppo/Seaquest_rew.png differ diff --git a/examples/atari/results/ppo/SpaceInvaders_rew.png b/examples/atari/results/ppo/SpaceInvaders_rew.png index 93a521e..fb67223 100644 Binary files a/examples/atari/results/ppo/SpaceInvaders_rew.png and b/examples/atari/results/ppo/SpaceInvaders_rew.png differ diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index 3c19daf..b8c1914 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -79,9 +79,6 @@ class PPOPolicy(A2CPolicy): "Dual-clip PPO parameter should greater than 1.0." 
         self._dual_clip = dual_clip
         self._value_clip = value_clip
-        if not self._rew_norm:
-            assert not self._value_clip, \
-                "value clip is available only when `reward_normalization` is True"
         self._norm_adv = advantage_normalization
         self._recompute_adv = recompute_advantage
         self._actor_critic: ActorCritic
@@ -94,11 +91,8 @@ class PPOPolicy(A2CPolicy):
             self._buffer, self._indices = buffer, indices
         batch = self._compute_returns(batch, buffer, indices)
         batch.act = to_torch_as(batch.act, batch.v_s)
-        old_log_prob = []
         with torch.no_grad():
-            for minibatch in batch.split(self._batch, shuffle=False, merge_last=True):
-                old_log_prob.append(self(minibatch).dist.log_prob(minibatch.act))
-        batch.logp_old = torch.cat(old_log_prob, dim=0)
+            batch.logp_old = self(batch).dist.log_prob(batch.act)
         return batch
 
     def learn(  # type: ignore
@@ -113,7 +107,8 @@ class PPOPolicy(A2CPolicy):
                 dist = self(minibatch).dist
                 if self._norm_adv:
                     mean, std = minibatch.adv.mean(), minibatch.adv.std()
-                    minibatch.adv = (minibatch.adv - mean) / std  # per-batch norm
+                    minibatch.adv = (minibatch.adv -
+                                     mean) / (std + self._eps)  # per-batch norm
                 ratio = (dist.log_prob(minibatch.act) -
                          minibatch.logp_old).exp().float()
                 ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py
index a9fd89e..cdab009 100644
--- a/tianshou/trainer/base.py
+++ b/tianshou/trainer/base.py
@@ -356,7 +356,8 @@ class BaseTrainer(ABC):
             print(
                 f"Epoch #{self.epoch}: test_reward: {rew:.6f} ± {rew_std:.6f},"
                 f" best_reward: {self.best_reward:.6f} ± "
-                f"{self.best_reward_std:.6f} in #{self.best_epoch}"
+                f"{self.best_reward_std:.6f} in #{self.best_epoch}",
+                flush=True
             )
             if not self.is_run:
                 test_stat = {
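
Below is a minimal usage sketch of the two helpers this patch adds to `atari_network.py` (`layer_init` and `scale_obs`), mirroring how the patched `atari_ppo.py` builds its feature network. It assumes `examples/atari` is importable; the frame layout (4 stacked 84x84 frames) and the 6-action space are illustrative values, not taken from the patch.

```python
# Hypothetical standalone sketch; assumes examples/atari is on PYTHONPATH
# so the patched atari_network module can be imported directly.
import torch
from atari_network import DQN, layer_init, scale_obs

# scale_obs(DQN) returns a DQN subclass whose forward() divides raw uint8
# frames by 255 before the usual forward pass, so observation scaling moves
# out of the env wrappers (atari_ppo.py now passes scale=0 to make_atari_env
# and toggles scaling here via --scale-obs).
net_cls = scale_obs(DQN)

# layer_init is applied to every conv/linear layer at construction time:
# orthogonal weights with gain sqrt(2) and zero bias.
net = net_cls(
    4, 84, 84,          # c, h, w: stacked 84x84 frames (illustrative)
    action_shape=6,     # e.g. a 6-action Atari game (illustrative)
    features_only=True,
    output_dim=512,
    layer_init=layer_init,
)

obs = torch.randint(0, 256, (1, 4, 84, 84), dtype=torch.uint8)
features, state = net(obs)
print(features.shape)  # torch.Size([1, 512])
```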