Make OptimizerFactory more flexible by adding a second method which
allows the creation of an optimizer given arbitrary parameters (rather than a module)
This commit is contained in:
parent
bf391853dc
commit
76cbd7efc2
@ -1,11 +1,14 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Protocol
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, Protocol, TypeAlias
|
||||
|
||||
import torch
|
||||
from torch.optim import Adam, RMSprop
|
||||
|
||||
from tianshou.utils.string import ToStringMixin
|
||||
|
||||
TParams: TypeAlias = Iterable[torch.Tensor] | Iterable[dict[str, Any]]
|
||||
|
||||
|
||||
class OptimizerWithLearningRateProtocol(Protocol):
|
||||
def __call__(self, parameters: Any, lr: float, **kwargs: Any) -> torch.optim.Optimizer:
|
||||
@ -13,8 +16,15 @@ class OptimizerWithLearningRateProtocol(Protocol):
|
||||
|
||||
|
||||
class OptimizerFactory(ABC, ToStringMixin):
|
||||
def create_optimizer(
|
||||
self,
|
||||
module: torch.nn.Module,
|
||||
lr: float,
|
||||
) -> torch.optim.Optimizer:
|
||||
return self.create_optimizer_for_params(module.parameters(), lr)
|
||||
|
||||
@abstractmethod
|
||||
def create_optimizer(self, module: torch.nn.Module, lr: float) -> torch.optim.Optimizer:
|
||||
def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim.Optimizer:
|
||||
pass
|
||||
|
||||
|
||||
@ -30,8 +40,8 @@ class OptimizerFactoryTorch(OptimizerFactory):
|
||||
self.optim_class = optim_class
|
||||
self.kwargs = kwargs
|
||||
|
||||
def create_optimizer(self, module: torch.nn.Module, lr: float) -> torch.optim.Optimizer:
|
||||
return self.optim_class(module.parameters(), lr=lr, **self.kwargs)
|
||||
def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim.Optimizer:
|
||||
return self.optim_class(params, lr=lr, **self.kwargs)
|
||||
|
||||
|
||||
class OptimizerFactoryAdam(OptimizerFactory):
|
||||
@ -45,9 +55,9 @@ class OptimizerFactoryAdam(OptimizerFactory):
|
||||
self.eps = eps
|
||||
self.betas = betas
|
||||
|
||||
def create_optimizer(self, module: torch.nn.Module, lr: float) -> Adam:
|
||||
def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim.Optimizer:
|
||||
return Adam(
|
||||
module.parameters(),
|
||||
params,
|
||||
lr=lr,
|
||||
betas=self.betas,
|
||||
eps=self.eps,
|
||||
@ -70,9 +80,9 @@ class OptimizerFactoryRMSprop(OptimizerFactory):
|
||||
self.weight_decay = weight_decay
|
||||
self.eps = eps
|
||||
|
||||
def create_optimizer(self, module: torch.nn.Module, lr: float) -> RMSprop:
|
||||
def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim.Optimizer:
|
||||
return RMSprop(
|
||||
module.parameters(),
|
||||
params,
|
||||
lr=lr,
|
||||
alpha=self.alpha,
|
||||
eps=self.eps,
|
||||
|
Loading…
x
Reference in New Issue
Block a user