from abc import ABC, abstractmethod
from typing import Union, Iterable, Dict, Any, Optional

import numpy as np
import torch
from torch import Tensor
from torch.optim import Adam
from torch.optim.lr_scheduler import LRScheduler, LambdaLR

from tianshou.config import RLSamplingConfig, NNConfig

TParams = Union[Iterable[Tensor], Iterable[Dict[str, Any]]]


class OptimizerFactory(ABC):
    """Abstract factory for the creation of a torch optimizer for a given module's parameters."""

    @abstractmethod
    def create_optimizer(self, module: torch.nn.Module) -> torch.optim.Optimizer:
        pass


class TorchOptimizerFactory(OptimizerFactory):
    """Factory which wraps an arbitrary torch optimizer class, passing the given keyword arguments to its constructor."""

    def __init__(self, optim_class, **kwargs):
        self.optim_class = optim_class
        self.kwargs = kwargs

    def create_optimizer(self, module: torch.nn.Module) -> torch.optim.Optimizer:
        return self.optim_class(module.parameters(), **self.kwargs)


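# Usage sketch (not part of the original module): TorchOptimizerFactory wraps an arbitrary
# torch optimizer class; `net` stands for any torch.nn.Module and the hyperparameters are
# purely illustrative.
#
#     factory = TorchOptimizerFactory(torch.optim.RMSprop, lr=1e-3, alpha=0.95)
#     optim = factory.create_optimizer(net)

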
class AdamOptimizerFactory(OptimizerFactory):
    """Factory for Adam optimizers with the given learning rate (all other parameters at their torch defaults)."""

    def __init__(self, lr):
        self.lr = lr

    def create_optimizer(self, module: torch.nn.Module) -> Adam:
        return Adam(module.parameters(), lr=self.lr)


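# Usage sketch (not part of the original module): AdamOptimizerFactory(lr=3e-4) behaves like
# TorchOptimizerFactory(Adam, lr=3e-4); `net` is again a placeholder torch.nn.Module.
#
#     optim = AdamOptimizerFactory(lr=3e-4).create_optimizer(net)

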
class LRSchedulerFactory(ABC):
    """Abstract factory for the creation of a learning rate scheduler for a given optimizer."""

    @abstractmethod
    def create_scheduler(self, optim: torch.optim.Optimizer) -> Optional[LRScheduler]:
        pass


class LinearLRSchedulerFactory(LRSchedulerFactory):
    """Factory for a scheduler which decays the learning rate linearly to zero over all gradient update steps.

    A scheduler is created only if lr_decay is enabled in the given NNConfig; otherwise None is returned.
    """

    def __init__(self, nn_config: NNConfig, sampling_config: RLSamplingConfig):
        self.nn_config = nn_config
        self.sampling_config = sampling_config

    def create_scheduler(self, optim: torch.optim.Optimizer) -> Optional[LRScheduler]:
        lr_scheduler = None
        if self.nn_config.lr_decay:
            # total number of gradient updates: collect cycles per epoch (rounded up) times the number of epochs
            max_update_num = (
                np.ceil(self.sampling_config.step_per_epoch / self.sampling_config.step_per_collect)
                * self.sampling_config.num_epochs
            )
            # decay the learning rate linearly from its initial value to 0 over max_update_num scheduler steps
            lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
        return lr_scheduler
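

# Usage sketch (not part of the original module): wiring an optimizer factory together with the
# linear LR decay schedule. `net` is a placeholder module, and `nn_config`/`sampling_config` are
# assumed to be NNConfig/RLSamplingConfig instances created elsewhere (e.g. with lr_decay=True,
# num_epochs=100, step_per_epoch=30000, step_per_collect=2000, giving
# ceil(30000 / 2000) * 100 = 1500 scheduler steps over which the learning rate falls to zero).
#
#     optim = AdamOptimizerFactory(lr=3e-4).create_optimizer(net)
#     lr_scheduler = LinearLRSchedulerFactory(nn_config, sampling_config).create_scheduler(optim)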