from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import Any

import numpy as np
import torch
from torch import Tensor
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR, LRScheduler

from tianshou.highlevel.experiment import RLSamplingConfig

TParams = Iterable[Tensor] | Iterable[dict[str, Any]]


class OptimizerFactory(ABC):
    @abstractmethod
    def create_optimizer(self, module: torch.nn.Module, lr: float) -> torch.optim.Optimizer:
        pass

class TorchOptimizerFactory(OptimizerFactory):
    """Factory for an arbitrary torch optimizer class, created with fixed keyword arguments."""

    def __init__(self, optim_class: Any, **kwargs):
        self.optim_class = optim_class
        self.kwargs = kwargs

    def create_optimizer(self, module: torch.nn.Module, lr: float) -> torch.optim.Optimizer:
        return self.optim_class(module.parameters(), lr=lr, **self.kwargs)
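
# An illustrative sketch (not from the original source): TorchOptimizerFactory can
# wrap any optimizer class from torch.optim, forwarding the extra keyword arguments
# to its constructor, e.g.
#
#     factory = TorchOptimizerFactory(torch.optim.RMSprop, alpha=0.99, weight_decay=1e-5)
#     optim = factory.create_optimizer(module, lr=1e-3)
#
# creates RMSprop optimizers with those fixed hyperparameters; `module` stands for
# any torch.nn.Module and is a hypothetical name here.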

class AdamOptimizerFactory(OptimizerFactory):
    """Factory for Adam optimizers with default hyperparameters."""

    def create_optimizer(self, module: torch.nn.Module, lr: float) -> Adam:
        return Adam(module.parameters(), lr=lr)

class LRSchedulerFactory(ABC):
    """Factory for learning rate schedulers attached to a given optimizer."""

    @abstractmethod
    def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
        pass

class LinearLRSchedulerFactory(LRSchedulerFactory):
    """Factory for schedulers that decay the learning rate linearly to zero over the run."""

    def __init__(self, sampling_config: RLSamplingConfig):
        self.sampling_config = sampling_config

    def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
        # Total number of gradient updates over the whole run: the number of
        # collect-and-update cycles per epoch (rounded up) times the number of epochs.
        max_update_num = (
            np.ceil(self.sampling_config.step_per_epoch / self.sampling_config.step_per_collect)
            * self.sampling_config.num_epochs
        )
        # Scale the base learning rate by a factor that falls linearly from 1 to 0.
        return LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
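
if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. It assumes that
    # RLSamplingConfig accepts num_epochs, step_per_epoch and step_per_collect
    # as keyword arguments (adjust to the actual dataclass definition).
    net = torch.nn.Linear(4, 2)
    optim = AdamOptimizerFactory().create_optimizer(net, lr=3e-4)
    config = RLSamplingConfig(num_epochs=100, step_per_epoch=30000, step_per_collect=2000)
    scheduler = LinearLRSchedulerFactory(config).create_scheduler(optim)
    optim.step()
    scheduler.step()  # the learning rate decays linearly toward zero with each update
    print(scheduler.get_last_lr())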