Make OptimizerFactory more flexible by adding a second method which

allows the creation of an optimizer given arbitrary parameters
(rather than a module)
This commit is contained in:
Dominik Jain 2024-02-14 20:42:06 +01:00
parent bf391853dc
commit 76cbd7efc2

View File

@ -1,11 +1,14 @@
from abc import ABC, abstractmethod
from typing import Any, Protocol
from collections.abc import Iterable
from typing import Any, Protocol, TypeAlias
import torch
from torch.optim import Adam, RMSprop
from tianshou.utils.string import ToStringMixin
TParams: TypeAlias = Iterable[torch.Tensor] | Iterable[dict[str, Any]]
class OptimizerWithLearningRateProtocol(Protocol):
def __call__(self, parameters: Any, lr: float, **kwargs: Any) -> torch.optim.Optimizer:
@ -13,8 +16,15 @@ class OptimizerWithLearningRateProtocol(Protocol):
class OptimizerFactory(ABC, ToStringMixin):
def create_optimizer(
self,
module: torch.nn.Module,
lr: float,
) -> torch.optim.Optimizer:
return self.create_optimizer_for_params(module.parameters(), lr)
@abstractmethod
def create_optimizer(self, module: torch.nn.Module, lr: float) -> torch.optim.Optimizer:
def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim.Optimizer:
pass
@ -30,8 +40,8 @@ class OptimizerFactoryTorch(OptimizerFactory):
self.optim_class = optim_class
self.kwargs = kwargs
def create_optimizer(self, module: torch.nn.Module, lr: float) -> torch.optim.Optimizer:
return self.optim_class(module.parameters(), lr=lr, **self.kwargs)
def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim.Optimizer:
return self.optim_class(params, lr=lr, **self.kwargs)
class OptimizerFactoryAdam(OptimizerFactory):
@ -45,9 +55,9 @@ class OptimizerFactoryAdam(OptimizerFactory):
self.eps = eps
self.betas = betas
def create_optimizer(self, module: torch.nn.Module, lr: float) -> Adam:
def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim.Optimizer:
return Adam(
module.parameters(),
params,
lr=lr,
betas=self.betas,
eps=self.eps,
@ -70,9 +80,9 @@ class OptimizerFactoryRMSprop(OptimizerFactory):
self.weight_decay = weight_decay
self.eps = eps
def create_optimizer(self, module: torch.nn.Module, lr: float) -> RMSprop:
def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim.Optimizer:
return RMSprop(
module.parameters(),
params,
lr=lr,
alpha=self.alpha,
eps=self.eps,