from abc import ABC, abstractmethod

import numpy as np
import torch
from torch.optim.lr_scheduler import LambdaLR, LRScheduler

from tianshou.highlevel.config import SamplingConfig
from tianshou.utils.string import ToStringMixin


class LRSchedulerFactory(ToStringMixin, ABC):
    """Factory for the creation of a learning rate scheduler."""

    @abstractmethod
    def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
        """Create a learning rate scheduler bound to the given optimizer."""


class LRSchedulerFactoryLinear(LRSchedulerFactory):
    """Factory for a linearly decaying learning rate scheduler.

    The multiplicative LR factor decays linearly from 1 to 0 over the total
    number of gradient updates estimated from the sampling configuration.
    """

    def __init__(self, sampling_config: SamplingConfig):
        self.sampling_config = sampling_config

    def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
        """Create a ``LambdaLR`` scheduler applying the linear-decay factor."""
        return LambdaLR(optim, lr_lambda=self._LRLambda(self.sampling_config).compute)

    class _LRLambda:
        """Computes the multiplicative LR factor for a given update index."""

        def __init__(self, sampling_config: SamplingConfig):
            # Estimated total number of gradient updates over the whole run:
            # collect steps per epoch (rounded up) times the number of epochs.
            self.max_update_num = (
                np.ceil(sampling_config.step_per_epoch / sampling_config.step_per_collect)
                * sampling_config.num_epochs
            )

        def compute(self, epoch: int) -> float:
            """Return the LR factor for update index ``epoch`` (linear decay 1 -> 0).

            The factor is clamped at 0 so it never becomes negative if the
            scheduler is stepped more often than the estimate anticipated —
            an unclamped factor would produce a negative learning rate and
            make the optimizer ascend the loss.
            """
            return max(0.0, 1.0 - epoch / self.max_update_num)