Make OptimizerFactory more flexible by adding a second method which
allows the creation of an optimizer given arbitrary parameters
(rather than a module)
Dominik Jain 2024-02-14 20:42:06 +01:00
parent bf391853dc
commit 76cbd7efc2

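A rough usage sketch of what the new entry point enables (not part of the diff below; the import path, the no-argument construction of OptimizerFactoryAdam, and the two toy modules are assumptions): an optimizer can now be created for any iterable of parameters or parameter-group dicts, e.g. the combined parameters of two modules, instead of requiring a single torch.nn.Module.

# Usage sketch only; import path and constructor defaults are assumed.
from itertools import chain

import torch

from tianshou.highlevel.optim import OptimizerFactoryAdam  # assumed path

actor = torch.nn.Linear(4, 2)
critic = torch.nn.Linear(4, 1)

factory = OptimizerFactoryAdam()

# Any iterable of parameters is accepted, e.g. the joint parameters of two modules ...
optim = factory.create_optimizer_for_params(
    chain(actor.parameters(), critic.parameters()),
    lr=1e-3,
)

# ... or a list of parameter-group dicts with per-group overrides.
optim_grouped = factory.create_optimizer_for_params(
    [
        {"params": actor.parameters()},
        {"params": critic.parameters(), "lr": 3e-4},
    ],
    lr=1e-3,
)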

@@ -1,11 +1,14 @@
 from abc import ABC, abstractmethod
-from typing import Any, Protocol
+from collections.abc import Iterable
+from typing import Any, Protocol, TypeAlias
 
 import torch
 from torch.optim import Adam, RMSprop
 
 from tianshou.utils.string import ToStringMixin
 
+TParams: TypeAlias = Iterable[torch.Tensor] | Iterable[dict[str, Any]]
+
 
 class OptimizerWithLearningRateProtocol(Protocol):
     def __call__(self, parameters: Any, lr: float, **kwargs: Any) -> torch.optim.Optimizer:
@@ -13,8 +16,15 @@ class OptimizerWithLearningRateProtocol(Protocol):
 
 
 class OptimizerFactory(ABC, ToStringMixin):
+    def create_optimizer(
+        self,
+        module: torch.nn.Module,
+        lr: float,
+    ) -> torch.optim.Optimizer:
+        return self.create_optimizer_for_params(module.parameters(), lr)
+
     @abstractmethod
-    def create_optimizer(self, module: torch.nn.Module, lr: float) -> torch.optim.Optimizer:
+    def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim.Optimizer:
         pass
 
 
@@ -30,8 +40,8 @@ class OptimizerFactoryTorch(OptimizerFactory):
         self.optim_class = optim_class
         self.kwargs = kwargs
 
-    def create_optimizer(self, module: torch.nn.Module, lr: float) -> torch.optim.Optimizer:
-        return self.optim_class(module.parameters(), lr=lr, **self.kwargs)
+    def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim.Optimizer:
+        return self.optim_class(params, lr=lr, **self.kwargs)
 
 
 class OptimizerFactoryAdam(OptimizerFactory):
@@ -45,9 +55,9 @@ class OptimizerFactoryAdam(OptimizerFactory):
         self.eps = eps
         self.betas = betas
 
-    def create_optimizer(self, module: torch.nn.Module, lr: float) -> Adam:
+    def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim.Optimizer:
         return Adam(
-            module.parameters(),
+            params,
             lr=lr,
             betas=self.betas,
             eps=self.eps,
@@ -70,9 +80,9 @@ class OptimizerFactoryRMSprop(OptimizerFactory):
         self.weight_decay = weight_decay
         self.eps = eps
 
-    def create_optimizer(self, module: torch.nn.Module, lr: float) -> RMSprop:
+    def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim.Optimizer:
         return RMSprop(
-            module.parameters(),
+            params,
             lr=lr,
             alpha=self.alpha,
             eps=self.eps,
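The module-based create_optimizer is kept and delegates to the new abstract method, so existing call sites keep working while a custom factory only needs to implement the params-based variant. A minimal sketch, assuming the import path shown and using an illustrative SGD-based subclass that is not part of this commit:

# Sketch only; the SGD subclass and import path are assumptions.
import torch

from tianshou.highlevel.optim import OptimizerFactory, TParams  # assumed path

class OptimizerFactorySGD(OptimizerFactory):  # hypothetical subclass
    def __init__(self, momentum: float = 0.9):
        self.momentum = momentum

    def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim.Optimizer:
        return torch.optim.SGD(params, lr=lr, momentum=self.momentum)

module = torch.nn.Linear(8, 2)
factory = OptimizerFactorySGD()

# Existing call sites are unaffected: the module's parameters are forwarded
# to create_optimizer_for_params by the base class.
optim_a = factory.create_optimizer(module, lr=1e-2)

# New call sites can pass an arbitrary parameter collection directly.
optim_b = factory.create_optimizer_for_params(module.parameters(), lr=1e-2)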