Fix models using scale_obs not being persistable (due to locally defined class)
This commit is contained in:
parent
7fa588309b
commit
19a98c3b2a
@ -27,7 +27,7 @@ from tianshou.utils.logging import datetime_tag
|
|||||||
def main(
|
def main(
|
||||||
experiment_config: ExperimentConfig,
|
experiment_config: ExperimentConfig,
|
||||||
task: str = "PongNoFrameskip-v4",
|
task: str = "PongNoFrameskip-v4",
|
||||||
scale_obs: int = 0,
|
scale_obs: bool = False,
|
||||||
eps_test: float = 0.005,
|
eps_test: float = 0.005,
|
||||||
eps_train: float = 1.0,
|
eps_train: float = 1.0,
|
||||||
eps_train_final: float = 0.05,
|
eps_train_final: float = 0.05,
|
||||||
|
@ -24,7 +24,7 @@ from tianshou.utils.logging import datetime_tag
|
|||||||
def main(
|
def main(
|
||||||
experiment_config: ExperimentConfig,
|
experiment_config: ExperimentConfig,
|
||||||
task: str = "PongNoFrameskip-v4",
|
task: str = "PongNoFrameskip-v4",
|
||||||
scale_obs: int = 0,
|
scale_obs: bool = False,
|
||||||
eps_test: float = 0.005,
|
eps_test: float = 0.005,
|
||||||
eps_train: float = 1.0,
|
eps_train: float = 1.0,
|
||||||
eps_train_final: float = 0.05,
|
eps_train_final: float = 0.05,
|
||||||
|
@ -23,19 +23,26 @@ def layer_init(layer: nn.Module, std: float = np.sqrt(2), bias_const: float = 0.
|
|||||||
return layer
|
return layer
|
||||||
|
|
||||||
|
|
||||||
def scale_obs(module: type[nn.Module], denom: float = 255.0) -> type[nn.Module]:
|
class ScaledObsInputModule(torch.nn.Module):
|
||||||
class scaled_module(module):
|
def __init__(self, module: torch.nn.Module, denom: float = 255.0):
|
||||||
def forward(
|
super().__init__()
|
||||||
self,
|
self.module = module
|
||||||
obs: np.ndarray | torch.Tensor,
|
self.denom = denom
|
||||||
state: Any | None = None,
|
self.output_dim = module.output_dim
|
||||||
info: dict[str, Any] | None = None,
|
|
||||||
) -> tuple[torch.Tensor, Any]:
|
|
||||||
if info is None:
|
|
||||||
info = {}
|
|
||||||
return super().forward(obs / denom, state, info)
|
|
||||||
|
|
||||||
return scaled_module
|
def forward(
|
||||||
|
self,
|
||||||
|
obs: np.ndarray | torch.Tensor,
|
||||||
|
state: Any | None = None,
|
||||||
|
info: dict[str, Any] | None = None,
|
||||||
|
) -> tuple[torch.Tensor, Any]:
|
||||||
|
if info is None:
|
||||||
|
info = {}
|
||||||
|
return self.module.forward(obs / self.denom, state, info)
|
||||||
|
|
||||||
|
|
||||||
|
def scale_obs(module: nn.Module, denom: float = 255.0) -> nn.Module:
|
||||||
|
return ScaledObsInputModule(module, denom=denom)
|
||||||
|
|
||||||
|
|
||||||
class DQN(nn.Module):
|
class DQN(nn.Module):
|
||||||
@ -238,8 +245,7 @@ class ActorFactoryAtariDQN(ActorFactory):
|
|||||||
self.features_only = features_only
|
self.features_only = features_only
|
||||||
|
|
||||||
def create_module(self, envs: Environments, device: TDevice) -> Actor:
|
def create_module(self, envs: Environments, device: TDevice) -> Actor:
|
||||||
net_cls = scale_obs(DQN) if self.scale_obs else DQN
|
net = DQN(
|
||||||
net = net_cls(
|
|
||||||
*envs.get_observation_shape(),
|
*envs.get_observation_shape(),
|
||||||
envs.get_action_shape(),
|
envs.get_action_shape(),
|
||||||
device=device,
|
device=device,
|
||||||
@ -247,6 +253,8 @@ class ActorFactoryAtariDQN(ActorFactory):
|
|||||||
output_dim=self.hidden_size,
|
output_dim=self.hidden_size,
|
||||||
layer_init=layer_init,
|
layer_init=layer_init,
|
||||||
)
|
)
|
||||||
|
if self.scale_obs:
|
||||||
|
net = scale_obs(net)
|
||||||
return Actor(net, envs.get_action_shape(), device=device, softmax_output=False).to(device)
|
return Actor(net, envs.get_action_shape(), device=device, softmax_output=False).to(device)
|
||||||
|
|
||||||
|
|
||||||
|
@ -109,8 +109,7 @@ def test_ppo(args=get_args()):
|
|||||||
np.random.seed(args.seed)
|
np.random.seed(args.seed)
|
||||||
torch.manual_seed(args.seed)
|
torch.manual_seed(args.seed)
|
||||||
# define model
|
# define model
|
||||||
net_cls = scale_obs(DQN) if args.scale_obs else DQN
|
net = DQN(
|
||||||
net = net_cls(
|
|
||||||
*args.state_shape,
|
*args.state_shape,
|
||||||
args.action_shape,
|
args.action_shape,
|
||||||
device=args.device,
|
device=args.device,
|
||||||
@ -118,6 +117,8 @@ def test_ppo(args=get_args()):
|
|||||||
output_dim=args.hidden_size,
|
output_dim=args.hidden_size,
|
||||||
layer_init=layer_init,
|
layer_init=layer_init,
|
||||||
)
|
)
|
||||||
|
if args.scale_obs:
|
||||||
|
net = scale_obs(net)
|
||||||
actor = Actor(net, args.action_shape, device=args.device, softmax_output=False)
|
actor = Actor(net, args.action_shape, device=args.device, softmax_output=False)
|
||||||
critic = Critic(net, device=args.device)
|
critic = Critic(net, device=args.device)
|
||||||
optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=args.lr, eps=1e-5)
|
optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=args.lr, eps=1e-5)
|
||||||
|
@ -24,7 +24,7 @@ from tianshou.utils.logging import datetime_tag
|
|||||||
def main(
|
def main(
|
||||||
experiment_config: ExperimentConfig,
|
experiment_config: ExperimentConfig,
|
||||||
task: str = "PongNoFrameskip-v4",
|
task: str = "PongNoFrameskip-v4",
|
||||||
scale_obs: int = 0,
|
scale_obs: bool = False,
|
||||||
buffer_size: int = 100000,
|
buffer_size: int = 100000,
|
||||||
actor_lr: float = 1e-5,
|
actor_lr: float = 1e-5,
|
||||||
critic_lr: float = 1e-5,
|
critic_lr: float = 1e-5,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user