Fix models using scale_obs not being persistable (due to locally defined class)

This commit is contained in:
Dominik Jain 2024-01-11 12:34:26 +01:00
parent 7fa588309b
commit 19a98c3b2a
5 changed files with 28 additions and 19 deletions

View File

@ -27,7 +27,7 @@ from tianshou.utils.logging import datetime_tag
def main(
experiment_config: ExperimentConfig,
task: str = "PongNoFrameskip-v4",
scale_obs: int = 0,
scale_obs: bool = False,
eps_test: float = 0.005,
eps_train: float = 1.0,
eps_train_final: float = 0.05,

View File

@ -24,7 +24,7 @@ from tianshou.utils.logging import datetime_tag
def main(
experiment_config: ExperimentConfig,
task: str = "PongNoFrameskip-v4",
scale_obs: int = 0,
scale_obs: bool = False,
eps_test: float = 0.005,
eps_train: float = 1.0,
eps_train_final: float = 0.05,

View File

@ -23,19 +23,26 @@ def layer_init(layer: nn.Module, std: float = np.sqrt(2), bias_const: float = 0.
return layer
def scale_obs(module: type[nn.Module], denom: float = 255.0) -> type[nn.Module]:
class scaled_module(module):
def forward(
self,
obs: np.ndarray | torch.Tensor,
state: Any | None = None,
info: dict[str, Any] | None = None,
) -> tuple[torch.Tensor, Any]:
if info is None:
info = {}
return super().forward(obs / denom, state, info)
class ScaledObsInputModule(torch.nn.Module):
def __init__(self, module: torch.nn.Module, denom: float = 255.0):
super().__init__()
self.module = module
self.denom = denom
self.output_dim = module.output_dim
return scaled_module
def forward(
self,
obs: np.ndarray | torch.Tensor,
state: Any | None = None,
info: dict[str, Any] | None = None,
) -> tuple[torch.Tensor, Any]:
if info is None:
info = {}
return self.module.forward(obs / self.denom, state, info)
def scale_obs(module: nn.Module, denom: float = 255.0) -> nn.Module:
return ScaledObsInputModule(module, denom=denom)
class DQN(nn.Module):
@ -238,8 +245,7 @@ class ActorFactoryAtariDQN(ActorFactory):
self.features_only = features_only
def create_module(self, envs: Environments, device: TDevice) -> Actor:
net_cls = scale_obs(DQN) if self.scale_obs else DQN
net = net_cls(
net = DQN(
*envs.get_observation_shape(),
envs.get_action_shape(),
device=device,
@ -247,6 +253,8 @@ class ActorFactoryAtariDQN(ActorFactory):
output_dim=self.hidden_size,
layer_init=layer_init,
)
if self.scale_obs:
net = scale_obs(net)
return Actor(net, envs.get_action_shape(), device=device, softmax_output=False).to(device)

View File

@ -109,8 +109,7 @@ def test_ppo(args=get_args()):
np.random.seed(args.seed)
torch.manual_seed(args.seed)
# define model
net_cls = scale_obs(DQN) if args.scale_obs else DQN
net = net_cls(
net = DQN(
*args.state_shape,
args.action_shape,
device=args.device,
@ -118,6 +117,8 @@ def test_ppo(args=get_args()):
output_dim=args.hidden_size,
layer_init=layer_init,
)
if args.scale_obs:
net = scale_obs(net)
actor = Actor(net, args.action_shape, device=args.device, softmax_output=False)
critic = Critic(net, device=args.device)
optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=args.lr, eps=1e-5)

View File

@ -24,7 +24,7 @@ from tianshou.utils.logging import datetime_tag
def main(
experiment_config: ExperimentConfig,
task: str = "PongNoFrameskip-v4",
scale_obs: int = 0,
scale_obs: bool = False,
buffer_size: int = 100000,
actor_lr: float = 1e-5,
critic_lr: float = 1e-5,