From 19a98c3b2affe9a3ee38f365aeb50c199e876863 Mon Sep 17 00:00:00 2001
From: Dominik Jain
Date: Thu, 11 Jan 2024 12:34:26 +0100
Subject: [PATCH] Fix models using scale_obs not being persistable (due to locally defined class)

---
 examples/atari/atari_dqn_hl.py  |  2 +-
 examples/atari/atari_iqn_hl.py  |  2 +-
 examples/atari/atari_network.py | 36 ++++++++++++++++++++++--------------
 examples/atari/atari_ppo.py     |  5 +++--
 examples/atari/atari_sac_hl.py  |  2 +-
 5 files changed, 28 insertions(+), 19 deletions(-)

diff --git a/examples/atari/atari_dqn_hl.py b/examples/atari/atari_dqn_hl.py
index f7aa331..0c435ee 100644
--- a/examples/atari/atari_dqn_hl.py
+++ b/examples/atari/atari_dqn_hl.py
@@ -27,7 +27,7 @@ from tianshou.utils.logging import datetime_tag
 def main(
     experiment_config: ExperimentConfig,
     task: str = "PongNoFrameskip-v4",
-    scale_obs: int = 0,
+    scale_obs: bool = False,
     eps_test: float = 0.005,
     eps_train: float = 1.0,
     eps_train_final: float = 0.05,
diff --git a/examples/atari/atari_iqn_hl.py b/examples/atari/atari_iqn_hl.py
index 2cd709c..8f94ece 100644
--- a/examples/atari/atari_iqn_hl.py
+++ b/examples/atari/atari_iqn_hl.py
@@ -24,7 +24,7 @@ from tianshou.utils.logging import datetime_tag
 def main(
     experiment_config: ExperimentConfig,
     task: str = "PongNoFrameskip-v4",
-    scale_obs: int = 0,
+    scale_obs: bool = False,
     eps_test: float = 0.005,
     eps_train: float = 1.0,
     eps_train_final: float = 0.05,
diff --git a/examples/atari/atari_network.py b/examples/atari/atari_network.py
index 0ff9b50..2e49478 100644
--- a/examples/atari/atari_network.py
+++ b/examples/atari/atari_network.py
@@ -23,19 +23,26 @@ def layer_init(layer: nn.Module, std: float = np.sqrt(2), bias_const: float = 0.
     return layer
 
 
-def scale_obs(module: type[nn.Module], denom: float = 255.0) -> type[nn.Module]:
-    class scaled_module(module):
-        def forward(
-            self,
-            obs: np.ndarray | torch.Tensor,
-            state: Any | None = None,
-            info: dict[str, Any] | None = None,
-        ) -> tuple[torch.Tensor, Any]:
-            if info is None:
-                info = {}
-            return super().forward(obs / denom, state, info)
+class ScaledObsInputModule(torch.nn.Module):
+    def __init__(self, module: torch.nn.Module, denom: float = 255.0):
+        super().__init__()
+        self.module = module
+        self.denom = denom
+        self.output_dim = module.output_dim
 
-    return scaled_module
+    def forward(
+        self,
+        obs: np.ndarray | torch.Tensor,
+        state: Any | None = None,
+        info: dict[str, Any] | None = None,
+    ) -> tuple[torch.Tensor, Any]:
+        if info is None:
+            info = {}
+        return self.module.forward(obs / self.denom, state, info)
+
+
+def scale_obs(module: nn.Module, denom: float = 255.0) -> nn.Module:
+    return ScaledObsInputModule(module, denom=denom)
 
 
 class DQN(nn.Module):
@@ -238,8 +245,7 @@ class ActorFactoryAtariDQN(ActorFactory):
         self.features_only = features_only
 
     def create_module(self, envs: Environments, device: TDevice) -> Actor:
-        net_cls = scale_obs(DQN) if self.scale_obs else DQN
-        net = net_cls(
+        net = DQN(
             *envs.get_observation_shape(),
             envs.get_action_shape(),
             device=device,
@@ -247,6 +253,8 @@ class ActorFactoryAtariDQN(ActorFactory):
             output_dim=self.hidden_size,
             layer_init=layer_init,
         )
+        if self.scale_obs:
+            net = scale_obs(net)
         return Actor(net, envs.get_action_shape(), device=device, softmax_output=False).to(device)
 
 
diff --git a/examples/atari/atari_ppo.py b/examples/atari/atari_ppo.py
index 1dd6956..63da7a2 100644
--- a/examples/atari/atari_ppo.py
+++ b/examples/atari/atari_ppo.py
@@ -109,8 +109,7 @@ def test_ppo(args=get_args()):
     np.random.seed(args.seed)
     torch.manual_seed(args.seed)
     # define model
-    net_cls = scale_obs(DQN) if args.scale_obs else DQN
-    net = net_cls(
+    net = DQN(
         *args.state_shape,
         args.action_shape,
         device=args.device,
@@ -118,6 +117,8 @@ def test_ppo(args=get_args()):
         output_dim=args.hidden_size,
         layer_init=layer_init,
     )
+    if args.scale_obs:
+        net = scale_obs(net)
     actor = Actor(net, args.action_shape, device=args.device, softmax_output=False)
     critic = Critic(net, device=args.device)
     optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=args.lr, eps=1e-5)
diff --git a/examples/atari/atari_sac_hl.py b/examples/atari/atari_sac_hl.py
index 8314bee..f1fd8c4 100644
--- a/examples/atari/atari_sac_hl.py
+++ b/examples/atari/atari_sac_hl.py
@@ -24,7 +24,7 @@ from tianshou.utils.logging import datetime_tag
 def main(
     experiment_config: ExperimentConfig,
     task: str = "PongNoFrameskip-v4",
-    scale_obs: int = 0,
+    scale_obs: bool = False,
     buffer_size: int = 100000,
     actor_lr: float = 1e-5,
     critic_lr: float = 1e-5,