Dominik Jain 78b6dd1f49 Adapt class naming scheme
* Use prefix convention (subclasses have superclass names as prefix) to
  facilitate discoverability of relevant classes via IDE autocompletion
* Use dual naming, adding an alternative concise name that omits the
  precise OO semantics and retains only the essential part of the name
  (which can be more pleasing to users not accustomed to
  convoluted OO naming); see the naming sketch below
2023-10-18 20:44:16 +02:00
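A rough sketch of the two conventions, using the classes defined in the file below; the concise alias name `AdamOptim` is a hypothetical illustration of dual naming, not something introduced by this commit:

    from abc import ABC

    class OptimizerFactory(ABC): ...                    # superclass
    class OptimizerFactoryAdam(OptimizerFactory): ...   # prefix convention: completing "OptimizerFactory" in an IDE surfaces all subclasses
    AdamOptim = OptimizerFactoryAdam                    # dual naming (hypothetical alias): drops the OO-specific "Factory" part, keeps the essential name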


from abc import ABC, abstractmethod
from typing import Any

import torch
from torch.optim import Adam


class OptimizerFactory(ABC):
    # TODO: Is it OK to assume that all optimizers have a learning rate argument?
    #   Right now, the learning rate is typically a configuration parameter.
    #   If we drop the assumption, we can't have that and will need to move the parameter
    #   to the optimizer factory, which is inconvenient for the user.
    @abstractmethod
    def create_optimizer(self, module: torch.nn.Module, lr: float) -> torch.optim.Optimizer:
        """Creates an optimizer for the parameters of the given module, using the given learning rate."""


class OptimizerFactoryTorch(OptimizerFactory):
    def __init__(self, optim_class: Any, **kwargs):
        """:param optim_class: the optimizer class (e.g. a subclass of `torch.optim.Optimizer`),
            which will be passed the module parameters, the learning rate as `lr` and the
            kwargs provided.
        :param kwargs: keyword arguments to provide at optimizer construction
        """
        self.optim_class = optim_class
        self.kwargs = kwargs

    def create_optimizer(self, module: torch.nn.Module, lr: float) -> torch.optim.Optimizer:
        return self.optim_class(module.parameters(), lr=lr, **self.kwargs)
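
# Usage sketch (illustrative, not part of this module): wrapping an arbitrary torch optimizer
# class with extra construction kwargs; `net` is assumed to be some torch.nn.Module defined elsewhere.
#
#     factory = OptimizerFactoryTorch(torch.optim.SGD, momentum=0.9)
#     optim = factory.create_optimizer(net, lr=1e-2)  # == SGD(net.parameters(), lr=1e-2, momentum=0.9)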


class OptimizerFactoryAdam(OptimizerFactory):
    def __init__(self, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-08, weight_decay: float = 0):
        self.weight_decay = weight_decay
        self.eps = eps
        self.betas = betas

    def create_optimizer(self, module: torch.nn.Module, lr: float) -> Adam:
        return Adam(
            module.parameters(),
            lr=lr,
            betas=self.betas,
            eps=self.eps,
            weight_decay=self.weight_decay,
        )
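
# Usage sketch (illustrative, not part of this module): the factory captures the optimizer
# configuration up front and defers construction until the module and learning rate are known;
# `net` is assumed to be some torch.nn.Module defined elsewhere.
#
#     factory = OptimizerFactoryAdam(weight_decay=1e-4)
#     optim = factory.create_optimizer(net, lr=1e-3)  # == Adam(net.parameters(), lr=1e-3, weight_decay=1e-4, ...)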