Make OptimizerFactory more flexible by adding a second method which
allows the creation of an optimizer given arbitrary parameters (rather than a module)
parent bf391853dc
commit 76cbd7efc2
@@ -1,11 +1,14 @@
 from abc import ABC, abstractmethod
-from typing import Any, Protocol
+from collections.abc import Iterable
+from typing import Any, Protocol, TypeAlias

 import torch
 from torch.optim import Adam, RMSprop

 from tianshou.utils.string import ToStringMixin

+TParams: TypeAlias = Iterable[torch.Tensor] | Iterable[dict[str, Any]]
+

 class OptimizerWithLearningRateProtocol(Protocol):
     def __call__(self, parameters: Any, lr: float, **kwargs: Any) -> torch.optim.Optimizer:
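The new TParams alias above covers both plain tensor iterables and per-parameter-group dicts. For orientation, here is a hedged usage sketch of the API after this change; the factory construction is illustrative only, with OptimizerFactoryTorch's constructor arguments assumed from the attribute assignments later in this diff, and the factory class assumed importable from the module modified here.

    import torch
    from torch.optim import Adam

    model = torch.nn.Linear(4, 2)

    # Assumed constructor: OptimizerFactoryTorch(optim_class, **kwargs), based on the
    # self.optim_class / self.kwargs assignments shown further down in this diff.
    factory = OptimizerFactoryTorch(Adam, betas=(0.9, 0.999))

    # Existing, module-based entry point (unchanged for callers):
    optim_from_module = factory.create_optimizer(model, lr=1e-3)

    # New entry point: arbitrary parameters, e.g. per-parameter-group dicts,
    # matching TParams = Iterable[torch.Tensor] | Iterable[dict[str, Any]]:
    param_groups = [
        {"params": [model.weight], "lr": 1e-3},
        {"params": [model.bias], "lr": 1e-4},
    ]
    optim_from_params = factory.create_optimizer_for_params(param_groups, lr=1e-3)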
@@ -13,8 +16,15 @@ class OptimizerWithLearningRateProtocol(Protocol):


 class OptimizerFactory(ABC, ToStringMixin):
+    def create_optimizer(
+        self,
+        module: torch.nn.Module,
+        lr: float,
+    ) -> torch.optim.Optimizer:
+        return self.create_optimizer_for_params(module.parameters(), lr)
+
     @abstractmethod
-    def create_optimizer(self, module: torch.nn.Module, lr: float) -> torch.optim.Optimizer:
+    def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim.Optimizer:
         pass

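Because the base class now implements create_optimizer as a thin wrapper around create_optimizer_for_params, a subclass only has to implement the params-based method and gets the module-based one for free. A minimal sketch under that assumption; the SGD factory below is hypothetical and not part of this commit, and OptimizerFactory and TParams are assumed importable from the module modified here.

    import torch
    from torch.optim import SGD

    class OptimizerFactorySGD(OptimizerFactory):  # hypothetical subclass, for illustration only
        def __init__(self, momentum: float = 0.0):
            self.momentum = momentum

        def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim.Optimizer:
            return SGD(params, lr=lr, momentum=self.momentum)

    model = torch.nn.Linear(8, 1)
    factory = OptimizerFactorySGD(momentum=0.9)
    # The first call delegates to the second via module.parameters():
    optim_a = factory.create_optimizer(model, lr=0.01)
    optim_b = factory.create_optimizer_for_params([model.weight, model.bias], lr=0.01)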
@@ -30,8 +40,8 @@ class OptimizerFactoryTorch(OptimizerFactory):
         self.optim_class = optim_class
         self.kwargs = kwargs

-    def create_optimizer(self, module: torch.nn.Module, lr: float) -> torch.optim.Optimizer:
-        return self.optim_class(module.parameters(), lr=lr, **self.kwargs)
+    def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim.Optimizer:
+        return self.optim_class(params, lr=lr, **self.kwargs)


 class OptimizerFactoryAdam(OptimizerFactory):
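One use case the params-based method enables (illustrative, not taken from this commit): a single optimizer over the parameters of several modules, which the module-based signature could not express directly. The constructor call is assumed as above.

    import itertools

    import torch
    from torch.optim import RMSprop

    actor = torch.nn.Linear(4, 2)
    critic = torch.nn.Linear(4, 1)

    # Assumed constructor as before: OptimizerFactoryTorch(optim_class, **kwargs).
    factory = OptimizerFactoryTorch(RMSprop, alpha=0.99)
    joint_optim = factory.create_optimizer_for_params(
        itertools.chain(actor.parameters(), critic.parameters()),
        lr=7e-4,
    )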
@@ -45,9 +55,9 @@ class OptimizerFactoryAdam(OptimizerFactory):
         self.eps = eps
         self.betas = betas

-    def create_optimizer(self, module: torch.nn.Module, lr: float) -> Adam:
+    def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim.Optimizer:
         return Adam(
-            module.parameters(),
+            params,
             lr=lr,
             betas=self.betas,
             eps=self.eps,
@@ -70,9 +80,9 @@ class OptimizerFactoryRMSprop(OptimizerFactory):
         self.weight_decay = weight_decay
         self.eps = eps

-    def create_optimizer(self, module: torch.nn.Module, lr: float) -> RMSprop:
+    def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim.Optimizer:
         return RMSprop(
-            module.parameters(),
+            params,
             lr=lr,
             alpha=self.alpha,
             eps=self.eps,
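For the concrete factories, both entry points now go through the same code path. Assuming OptimizerFactoryAdam's constructor accepts the betas and eps keyword arguments suggested by the attribute assignments above (its defaults are not visible in this diff) and that the class is imported from the module modified here, the two calls below build equivalent optimizers:

    import torch

    model = torch.nn.Linear(16, 4)

    # Constructor arguments assumed from the attribute assignments shown above.
    adam_factory = OptimizerFactoryAdam(betas=(0.9, 0.999), eps=1e-8)
    opt_from_module = adam_factory.create_optimizer(model, lr=3e-4)
    opt_from_params = adam_factory.create_optimizer_for_params(model.parameters(), lr=3e-4)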