from abc import ABC, abstractmethod

import numpy as np
import torch
from torch.optim.lr_scheduler import LambdaLR, LRScheduler

from tianshou.highlevel.config import SamplingConfig
from tianshou.utils.string import ToStringMixin


class LRSchedulerFactory(ToStringMixin, ABC):
    """Factory for the creation of a learning rate scheduler."""

    @abstractmethod
    def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
        pass


class LRSchedulerFactoryLinear(LRSchedulerFactory):
    """Factory for a scheduler which linearly decays the learning rate towards zero over the course of training."""

    def __init__(self, sampling_config: SamplingConfig):
        self.sampling_config = sampling_config

    def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
        # Total number of policy updates: the number of collect cycles per epoch
        # (each collecting step_per_collect environment steps) times the number
        # of epochs.
        max_update_num = (
            np.ceil(self.sampling_config.step_per_epoch / self.sampling_config.step_per_collect)
            * self.sampling_config.num_epochs
        )
        # LambdaLR scales the base learning rate by the returned factor, which
        # decays linearly from 1 towards 0 as the number of scheduler steps
        # (one per update; LambdaLR names the argument `epoch`) approaches
        # max_update_num.
        return LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
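

# A minimal usage sketch (illustrative, not part of the library). It assumes
# that SamplingConfig can be constructed from just these three fields, with the
# remaining fields taking their defaults, and uses a single toy parameter in
# place of a real model. With 100 epochs of 10000 steps collected in chunks of
# 2000, there are 100 * ceil(10000 / 2000) = 500 updates, so the learning rate
# anneals linearly from 1e-3 towards 0 over those updates.
if __name__ == "__main__":
    param = torch.nn.Parameter(torch.zeros(1))  # stand-in for model parameters
    optim = torch.optim.Adam([param], lr=1e-3)
    factory = LRSchedulerFactoryLinear(
        SamplingConfig(num_epochs=100, step_per_epoch=10000, step_per_collect=2000),
    )
    scheduler = factory.create_scheduler(optim)
    scheduler.step()  # call once after each policy update in the training loop
    print(optim.param_groups[0]["lr"])  # 1e-3 * (1 - 1/500) = 0.000998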