Fix models using scale_obs not being persistable (due to locally defined class)

This commit is contained in:
Dominik Jain 2024-01-11 12:34:26 +01:00
parent 7fa588309b
commit 19a98c3b2a
5 changed files with 28 additions and 19 deletions

View File

@ -27,7 +27,7 @@ from tianshou.utils.logging import datetime_tag
def main( def main(
experiment_config: ExperimentConfig, experiment_config: ExperimentConfig,
task: str = "PongNoFrameskip-v4", task: str = "PongNoFrameskip-v4",
scale_obs: int = 0, scale_obs: bool = False,
eps_test: float = 0.005, eps_test: float = 0.005,
eps_train: float = 1.0, eps_train: float = 1.0,
eps_train_final: float = 0.05, eps_train_final: float = 0.05,

View File

@ -24,7 +24,7 @@ from tianshou.utils.logging import datetime_tag
def main( def main(
experiment_config: ExperimentConfig, experiment_config: ExperimentConfig,
task: str = "PongNoFrameskip-v4", task: str = "PongNoFrameskip-v4",
scale_obs: int = 0, scale_obs: bool = False,
eps_test: float = 0.005, eps_test: float = 0.005,
eps_train: float = 1.0, eps_train: float = 1.0,
eps_train_final: float = 0.05, eps_train_final: float = 0.05,

View File

@ -23,19 +23,26 @@ def layer_init(layer: nn.Module, std: float = np.sqrt(2), bias_const: float = 0.
return layer return layer
def scale_obs(module: type[nn.Module], denom: float = 255.0) -> type[nn.Module]: class ScaledObsInputModule(torch.nn.Module):
class scaled_module(module): def __init__(self, module: torch.nn.Module, denom: float = 255.0):
def forward( super().__init__()
self, self.module = module
obs: np.ndarray | torch.Tensor, self.denom = denom
state: Any | None = None, self.output_dim = module.output_dim
info: dict[str, Any] | None = None,
) -> tuple[torch.Tensor, Any]:
if info is None:
info = {}
return super().forward(obs / denom, state, info)
return scaled_module def forward(
self,
obs: np.ndarray | torch.Tensor,
state: Any | None = None,
info: dict[str, Any] | None = None,
) -> tuple[torch.Tensor, Any]:
if info is None:
info = {}
return self.module.forward(obs / self.denom, state, info)
def scale_obs(module: nn.Module, denom: float = 255.0) -> nn.Module:
return ScaledObsInputModule(module, denom=denom)
class DQN(nn.Module): class DQN(nn.Module):
@ -238,8 +245,7 @@ class ActorFactoryAtariDQN(ActorFactory):
self.features_only = features_only self.features_only = features_only
def create_module(self, envs: Environments, device: TDevice) -> Actor: def create_module(self, envs: Environments, device: TDevice) -> Actor:
net_cls = scale_obs(DQN) if self.scale_obs else DQN net = DQN(
net = net_cls(
*envs.get_observation_shape(), *envs.get_observation_shape(),
envs.get_action_shape(), envs.get_action_shape(),
device=device, device=device,
@ -247,6 +253,8 @@ class ActorFactoryAtariDQN(ActorFactory):
output_dim=self.hidden_size, output_dim=self.hidden_size,
layer_init=layer_init, layer_init=layer_init,
) )
if self.scale_obs:
net = scale_obs(net)
return Actor(net, envs.get_action_shape(), device=device, softmax_output=False).to(device) return Actor(net, envs.get_action_shape(), device=device, softmax_output=False).to(device)

View File

@ -109,8 +109,7 @@ def test_ppo(args=get_args()):
np.random.seed(args.seed) np.random.seed(args.seed)
torch.manual_seed(args.seed) torch.manual_seed(args.seed)
# define model # define model
net_cls = scale_obs(DQN) if args.scale_obs else DQN net = DQN(
net = net_cls(
*args.state_shape, *args.state_shape,
args.action_shape, args.action_shape,
device=args.device, device=args.device,
@ -118,6 +117,8 @@ def test_ppo(args=get_args()):
output_dim=args.hidden_size, output_dim=args.hidden_size,
layer_init=layer_init, layer_init=layer_init,
) )
if args.scale_obs:
net = scale_obs(net)
actor = Actor(net, args.action_shape, device=args.device, softmax_output=False) actor = Actor(net, args.action_shape, device=args.device, softmax_output=False)
critic = Critic(net, device=args.device) critic = Critic(net, device=args.device)
optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=args.lr, eps=1e-5) optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=args.lr, eps=1e-5)

View File

@ -24,7 +24,7 @@ from tianshou.utils.logging import datetime_tag
def main( def main(
experiment_config: ExperimentConfig, experiment_config: ExperimentConfig,
task: str = "PongNoFrameskip-v4", task: str = "PongNoFrameskip-v4",
scale_obs: int = 0, scale_obs: bool = False,
buffer_size: int = 100000, buffer_size: int = 100000,
actor_lr: float = 1e-5, actor_lr: float = 1e-5,
critic_lr: float = 1e-5, critic_lr: float = 1e-5,