from abc import ABC, abstractmethod
from typing import Any, Dict, Iterable, Optional, Type, Union
import numpy as np
import torch
from torch import Tensor
from torch.optim import Adam
from torch.optim.lr_scheduler import LRScheduler, LambdaLR
from tianshou.config import RLSamplingConfig, NNConfig

TParams = Union[Iterable[Tensor], Iterable[Dict[str, Any]]]


class OptimizerFactory(ABC):
    """Base class for factories which create a torch optimizer for a module's parameters."""

    @abstractmethod
    def create_optimizer(self, module: torch.nn.Module) -> torch.optim.Optimizer:
        pass


class TorchOptimizerFactory(OptimizerFactory):
    """Factory for an arbitrary torch optimizer class; kwargs are passed on to its constructor."""

    def __init__(self, optim_class: Type[torch.optim.Optimizer], **kwargs: Any):
        self.optim_class = optim_class
        self.kwargs = kwargs

    def create_optimizer(self, module: torch.nn.Module) -> torch.optim.Optimizer:
        return self.optim_class(module.parameters(), **self.kwargs)


class AdamOptimizerFactory(OptimizerFactory):
    """Factory for Adam optimizers with a fixed learning rate."""

    def __init__(self, lr: float):
        self.lr = lr

    def create_optimizer(self, module: torch.nn.Module) -> Adam:
        return Adam(module.parameters(), lr=self.lr)
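

# Example usage (a minimal sketch; `net` stands for any torch.nn.Module and is
# not defined in this module):
#
#     factory = AdamOptimizerFactory(lr=3e-4)
#     optim = factory.create_optimizer(net)
#
# TorchOptimizerFactory covers the same use case for any optimizer class, e.g.
#
#     factory = TorchOptimizerFactory(torch.optim.SGD, lr=1e-2, momentum=0.9)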


class LRSchedulerFactory(ABC):
    """Base class for factories which create a learning rate scheduler (or None) for an optimizer."""

    @abstractmethod
    def create_scheduler(self, optim: torch.optim.Optimizer) -> Optional[LRScheduler]:
        pass


class LinearLRSchedulerFactory(LRSchedulerFactory):
    """Factory for a scheduler which decays the learning rate linearly to zero over all policy updates."""

    def __init__(self, nn_config: NNConfig, sampling_config: RLSamplingConfig):
        self.nn_config = nn_config
        self.sampling_config = sampling_config

    def create_scheduler(self, optim: torch.optim.Optimizer) -> Optional[LRScheduler]:
        lr_scheduler = None
        if self.nn_config.lr_decay:
            # Total number of policy updates: updates per epoch (rounded up) times the number of epochs.
            max_update_num = (
                np.ceil(self.sampling_config.step_per_epoch / self.sampling_config.step_per_collect)
                * self.sampling_config.num_epochs
            )
            # Decay the lr linearly to 0; the scheduler is stepped once per update.
            lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
        return lr_scheduler
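

# Example usage (a minimal sketch; assumes `net` is a torch.nn.Module and that
# `nn_config`/`sampling_config` are pre-built NNConfig/RLSamplingConfig instances
# with lr_decay enabled):
#
#     optim = AdamOptimizerFactory(lr=3e-4).create_optimizer(net)
#     scheduler = LinearLRSchedulerFactory(nn_config, sampling_config).create_scheduler(optim)
#     # create_scheduler returns None when lr_decay is disabled; otherwise call
#     # scheduler.step() after each policy update.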