diff --git a/tianshou/highlevel/optim.py b/tianshou/highlevel/optim.py
index 0e754b1..db5fd90 100644
--- a/tianshou/highlevel/optim.py
+++ b/tianshou/highlevel/optim.py
@@ -1,11 +1,14 @@
 from abc import ABC, abstractmethod
-from typing import Any, Protocol
+from collections.abc import Iterable
+from typing import Any, Protocol, TypeAlias
 
 import torch
 from torch.optim import Adam, RMSprop
 
 from tianshou.utils.string import ToStringMixin
 
+TParams: TypeAlias = Iterable[torch.Tensor] | Iterable[dict[str, Any]]
+
 
 class OptimizerWithLearningRateProtocol(Protocol):
     def __call__(self, parameters: Any, lr: float, **kwargs: Any) -> torch.optim.Optimizer:
@@ -13,8 +16,15 @@ class OptimizerWithLearningRateProtocol(Protocol):
 
 
 class OptimizerFactory(ABC, ToStringMixin):
+    def create_optimizer(
+        self,
+        module: torch.nn.Module,
+        lr: float,
+    ) -> torch.optim.Optimizer:
+        return self.create_optimizer_for_params(module.parameters(), lr)
+
     @abstractmethod
-    def create_optimizer(self, module: torch.nn.Module, lr: float) -> torch.optim.Optimizer:
+    def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim.Optimizer:
         pass
 
 
@@ -30,8 +40,8 @@ class OptimizerFactoryTorch(OptimizerFactory):
         self.optim_class = optim_class
         self.kwargs = kwargs
 
-    def create_optimizer(self, module: torch.nn.Module, lr: float) -> torch.optim.Optimizer:
-        return self.optim_class(module.parameters(), lr=lr, **self.kwargs)
+    def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim.Optimizer:
+        return self.optim_class(params, lr=lr, **self.kwargs)
 
 
 class OptimizerFactoryAdam(OptimizerFactory):
@@ -45,9 +55,9 @@ class OptimizerFactoryAdam(OptimizerFactory):
         self.eps = eps
         self.betas = betas
 
-    def create_optimizer(self, module: torch.nn.Module, lr: float) -> Adam:
+    def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim.Optimizer:
         return Adam(
-            module.parameters(),
+            params,
             lr=lr,
             betas=self.betas,
             eps=self.eps,
@@ -70,9 +80,9 @@ class OptimizerFactoryRMSprop(OptimizerFactory):
         self.weight_decay = weight_decay
         self.eps = eps
 
-    def create_optimizer(self, module: torch.nn.Module, lr: float) -> RMSprop:
+    def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim.Optimizer:
         return RMSprop(
-            module.parameters(),
+            params,
             lr=lr,
             alpha=self.alpha,
             eps=self.eps,
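
For context, a minimal sketch of what the new split enables: `create_optimizer` keeps the module-based path (the base class now delegates it to the new abstract primitive), while `create_optimizer_for_params` accepts either branch of `TParams` directly. The actor/critic modules, the `itertools.chain` usage, and the no-argument `OptimizerFactoryAdam()` construction below are illustrative assumptions, not part of this diff.

```python
from itertools import chain

import torch

from tianshou.highlevel.optim import OptimizerFactoryAdam

# Stand-in modules for illustration (e.g. an actor/critic pair).
actor = torch.nn.Linear(4, 2)
critic = torch.nn.Linear(4, 1)

# Assumes the factory is constructible with default Adam settings.
factory = OptimizerFactoryAdam()

# Unchanged module-based path: the base class forwards
# module.parameters() to create_optimizer_for_params.
actor_optim = factory.create_optimizer(actor, lr=1e-3)

# Iterable[torch.Tensor] branch of TParams: one optimizer over the
# chained parameters of several modules.
joint_optim = factory.create_optimizer_for_params(
    chain(actor.parameters(), critic.parameters()),
    lr=1e-3,
)

# Iterable[dict[str, Any]] branch of TParams: torch-style param groups
# with per-group overrides.
grouped_optim = factory.create_optimizer_for_params(
    [
        {"params": actor.parameters()},
        {"params": critic.parameters(), "weight_decay": 1e-4},
    ],
    lr=1e-3,
)
```

Keeping `create_optimizer` as a concrete base-class method preserves the existing call sites, while subclasses now only implement the params-based primitive.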