* Use a prefix convention (subclasses carry their superclass's name as a prefix) to make relevant classes easy to discover via IDE autocompletion.
* Use dual naming: add an alternative, concise name that omits the precise OO semantics and retains only the essential part of the name (which may be more pleasing to users who are not accustomed to convoluted OO naming).
26 lines · 828 B · Python
from abc import ABC, abstractmethod
|
|
|
|
import numpy as np
|
|
import torch
|
|
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
|
|
|
|
from tianshou.highlevel.config import RLSamplingConfig
|
|
|
|
|
|
class LRSchedulerFactory(ABC):
    """Abstract factory producing a learning-rate scheduler for a given optimizer.

    Concrete subclasses decide which schedule to apply; callers only need
    to hand over the optimizer whose learning rate should be controlled.
    """

    @abstractmethod
    def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
        """Build a learning-rate scheduler attached to ``optim``.

        :param optim: the optimizer whose learning rate(s) the scheduler adjusts
        :return: the newly created scheduler
        """
|
|
|
|
|
|
class LRSchedulerFactoryLinear(LRSchedulerFactory):
    """Creates schedulers that decay the learning rate linearly to zero over training.

    The total number of scheduler steps is derived from the sampling configuration as
    ``ceil(step_per_epoch / step_per_collect) * num_epochs``.
    NOTE(review): this presumes the scheduler is stepped once per collect/update
    cycle -- confirm against the trainer/policy that calls ``scheduler.step()``.
    """

    def __init__(self, sampling_config: RLSamplingConfig):
        """:param sampling_config: source of ``step_per_epoch``, ``step_per_collect`` and
        ``num_epochs``, which determine the length of the linear schedule
        """
        self.sampling_config = sampling_config

    def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
        """Create a ``LambdaLR`` scheduler that scales each base LR by a linearly decaying factor.

        The factor is 1 at step 0, reaches 0 after the computed total number of
        updates, and stays at 0 afterwards (it is never allowed to go negative).

        :param optim: the optimizer whose learning rate(s) the scheduler adjusts
        :return: the linear-decay scheduler attached to ``optim``
        :raises ValueError: if the configuration yields a non-positive number of updates
        """
        cfg = self.sampling_config
        # Cast to a Python int so that plain floats (not numpy scalars) end up in the
        # optimizer's param groups.
        max_update_num = int(np.ceil(cfg.step_per_epoch / cfg.step_per_collect)) * cfg.num_epochs
        if max_update_num <= 0:
            # Without this guard, 0 total updates makes the factor 0/0 = NaN, which would
            # silently poison the optimizer's learning rate.
            raise ValueError(
                f"Linear LR schedule requires a positive number of updates, got {max_update_num} "
                f"(step_per_epoch={cfg.step_per_epoch}, step_per_collect={cfg.step_per_collect}, "
                f"num_epochs={cfg.num_epochs})",
            )

        def lr_lambda(step: int) -> float:
            # Clamp at 0: stepping past max_update_num must not produce a negative learning rate.
            return max(0.0, 1.0 - step / max_update_num)

        return LambdaLR(optim, lr_lambda=lr_lambda)
|