Make OptimizerFactory more flexible by adding a second method which
allows the creation of an optimizer given arbitrary parameters (rather than a module)
This commit is contained in:
parent
bf391853dc
commit
76cbd7efc2
@ -1,11 +1,14 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Protocol
|
from collections.abc import Iterable
|
||||||
|
from typing import Any, Protocol, TypeAlias
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.optim import Adam, RMSprop
|
from torch.optim import Adam, RMSprop
|
||||||
|
|
||||||
from tianshou.utils.string import ToStringMixin
|
from tianshou.utils.string import ToStringMixin
|
||||||
|
|
||||||
|
TParams: TypeAlias = Iterable[torch.Tensor] | Iterable[dict[str, Any]]
|
||||||
|
|
||||||
|
|
||||||
class OptimizerWithLearningRateProtocol(Protocol):
|
class OptimizerWithLearningRateProtocol(Protocol):
|
||||||
def __call__(self, parameters: Any, lr: float, **kwargs: Any) -> torch.optim.Optimizer:
|
def __call__(self, parameters: Any, lr: float, **kwargs: Any) -> torch.optim.Optimizer:
|
||||||
@ -13,8 +16,15 @@ class OptimizerWithLearningRateProtocol(Protocol):
|
|||||||
|
|
||||||
|
|
||||||
class OptimizerFactory(ABC, ToStringMixin):
|
class OptimizerFactory(ABC, ToStringMixin):
|
||||||
|
def create_optimizer(
|
||||||
|
self,
|
||||||
|
module: torch.nn.Module,
|
||||||
|
lr: float,
|
||||||
|
) -> torch.optim.Optimizer:
|
||||||
|
return self.create_optimizer_for_params(module.parameters(), lr)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def create_optimizer(self, module: torch.nn.Module, lr: float) -> torch.optim.Optimizer:
|
def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim.Optimizer:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@ -30,8 +40,8 @@ class OptimizerFactoryTorch(OptimizerFactory):
|
|||||||
self.optim_class = optim_class
|
self.optim_class = optim_class
|
||||||
self.kwargs = kwargs
|
self.kwargs = kwargs
|
||||||
|
|
||||||
def create_optimizer(self, module: torch.nn.Module, lr: float) -> torch.optim.Optimizer:
|
def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim.Optimizer:
|
||||||
return self.optim_class(module.parameters(), lr=lr, **self.kwargs)
|
return self.optim_class(params, lr=lr, **self.kwargs)
|
||||||
|
|
||||||
|
|
||||||
class OptimizerFactoryAdam(OptimizerFactory):
|
class OptimizerFactoryAdam(OptimizerFactory):
|
||||||
@ -45,9 +55,9 @@ class OptimizerFactoryAdam(OptimizerFactory):
|
|||||||
self.eps = eps
|
self.eps = eps
|
||||||
self.betas = betas
|
self.betas = betas
|
||||||
|
|
||||||
def create_optimizer(self, module: torch.nn.Module, lr: float) -> Adam:
|
def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim.Optimizer:
|
||||||
return Adam(
|
return Adam(
|
||||||
module.parameters(),
|
params,
|
||||||
lr=lr,
|
lr=lr,
|
||||||
betas=self.betas,
|
betas=self.betas,
|
||||||
eps=self.eps,
|
eps=self.eps,
|
||||||
@ -70,9 +80,9 @@ class OptimizerFactoryRMSprop(OptimizerFactory):
|
|||||||
self.weight_decay = weight_decay
|
self.weight_decay = weight_decay
|
||||||
self.eps = eps
|
self.eps = eps
|
||||||
|
|
||||||
def create_optimizer(self, module: torch.nn.Module, lr: float) -> RMSprop:
|
def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim.Optimizer:
|
||||||
return RMSprop(
|
return RMSprop(
|
||||||
module.parameters(),
|
params,
|
||||||
lr=lr,
|
lr=lr,
|
||||||
alpha=self.alpha,
|
alpha=self.alpha,
|
||||||
eps=self.eps,
|
eps=self.eps,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user