Fix models using scale_obs not being persistable (due to locally defined class)
This commit is contained in:
parent
7fa588309b
commit
19a98c3b2a
@ -27,7 +27,7 @@ from tianshou.utils.logging import datetime_tag
|
||||
def main(
|
||||
experiment_config: ExperimentConfig,
|
||||
task: str = "PongNoFrameskip-v4",
|
||||
scale_obs: int = 0,
|
||||
scale_obs: bool = False,
|
||||
eps_test: float = 0.005,
|
||||
eps_train: float = 1.0,
|
||||
eps_train_final: float = 0.05,
|
||||
|
@ -24,7 +24,7 @@ from tianshou.utils.logging import datetime_tag
|
||||
def main(
|
||||
experiment_config: ExperimentConfig,
|
||||
task: str = "PongNoFrameskip-v4",
|
||||
scale_obs: int = 0,
|
||||
scale_obs: bool = False,
|
||||
eps_test: float = 0.005,
|
||||
eps_train: float = 1.0,
|
||||
eps_train_final: float = 0.05,
|
||||
|
@ -23,19 +23,26 @@ def layer_init(layer: nn.Module, std: float = np.sqrt(2), bias_const: float = 0.
|
||||
return layer
|
||||
|
||||
|
||||
def scale_obs(module: type[nn.Module], denom: float = 255.0) -> type[nn.Module]:
|
||||
class scaled_module(module):
|
||||
def forward(
|
||||
self,
|
||||
obs: np.ndarray | torch.Tensor,
|
||||
state: Any | None = None,
|
||||
info: dict[str, Any] | None = None,
|
||||
) -> tuple[torch.Tensor, Any]:
|
||||
if info is None:
|
||||
info = {}
|
||||
return super().forward(obs / denom, state, info)
|
||||
class ScaledObsInputModule(torch.nn.Module):
|
||||
def __init__(self, module: torch.nn.Module, denom: float = 255.0):
|
||||
super().__init__()
|
||||
self.module = module
|
||||
self.denom = denom
|
||||
self.output_dim = module.output_dim
|
||||
|
||||
return scaled_module
|
||||
def forward(
|
||||
self,
|
||||
obs: np.ndarray | torch.Tensor,
|
||||
state: Any | None = None,
|
||||
info: dict[str, Any] | None = None,
|
||||
) -> tuple[torch.Tensor, Any]:
|
||||
if info is None:
|
||||
info = {}
|
||||
return self.module.forward(obs / self.denom, state, info)
|
||||
|
||||
|
||||
def scale_obs(module: nn.Module, denom: float = 255.0) -> nn.Module:
|
||||
return ScaledObsInputModule(module, denom=denom)
|
||||
|
||||
|
||||
class DQN(nn.Module):
|
||||
@ -238,8 +245,7 @@ class ActorFactoryAtariDQN(ActorFactory):
|
||||
self.features_only = features_only
|
||||
|
||||
def create_module(self, envs: Environments, device: TDevice) -> Actor:
|
||||
net_cls = scale_obs(DQN) if self.scale_obs else DQN
|
||||
net = net_cls(
|
||||
net = DQN(
|
||||
*envs.get_observation_shape(),
|
||||
envs.get_action_shape(),
|
||||
device=device,
|
||||
@ -247,6 +253,8 @@ class ActorFactoryAtariDQN(ActorFactory):
|
||||
output_dim=self.hidden_size,
|
||||
layer_init=layer_init,
|
||||
)
|
||||
if self.scale_obs:
|
||||
net = scale_obs(net)
|
||||
return Actor(net, envs.get_action_shape(), device=device, softmax_output=False).to(device)
|
||||
|
||||
|
||||
|
@ -109,8 +109,7 @@ def test_ppo(args=get_args()):
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
# define model
|
||||
net_cls = scale_obs(DQN) if args.scale_obs else DQN
|
||||
net = net_cls(
|
||||
net = DQN(
|
||||
*args.state_shape,
|
||||
args.action_shape,
|
||||
device=args.device,
|
||||
@ -118,6 +117,8 @@ def test_ppo(args=get_args()):
|
||||
output_dim=args.hidden_size,
|
||||
layer_init=layer_init,
|
||||
)
|
||||
if args.scale_obs:
|
||||
net = scale_obs(net)
|
||||
actor = Actor(net, args.action_shape, device=args.device, softmax_output=False)
|
||||
critic = Critic(net, device=args.device)
|
||||
optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=args.lr, eps=1e-5)
|
||||
|
@ -24,7 +24,7 @@ from tianshou.utils.logging import datetime_tag
|
||||
def main(
|
||||
experiment_config: ExperimentConfig,
|
||||
task: str = "PongNoFrameskip-v4",
|
||||
scale_obs: int = 0,
|
||||
scale_obs: bool = False,
|
||||
buffer_size: int = 100000,
|
||||
actor_lr: float = 1e-5,
|
||||
critic_lr: float = 1e-5,
|
||||
|
Loading…
x
Reference in New Issue
Block a user