Tianshou/tianshou/highlevel/params/lr_scheduler.py
Dominik Jain 78b6dd1f49 Adapt class naming scheme
* Use prefix convention (subclasses have the superclass name as a prefix) to
  facilitate discoverability of relevant classes via IDE autocompletion
* Use dual naming, adding an alternative concise name that omits the
  precise OO semantics and retains only the essential part of the name
  (which can be more pleasing to users not accustomed to
  convoluted OO naming)
2023-10-18 20:44:16 +02:00


from abc import ABC, abstractmethod

import numpy as np
import torch
from torch.optim.lr_scheduler import LambdaLR, LRScheduler

from tianshou.highlevel.config import RLSamplingConfig


class LRSchedulerFactory(ABC):
    """Factory for the creation of a learning rate scheduler."""

    @abstractmethod
    def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
        pass


class LRSchedulerFactoryLinear(LRSchedulerFactory):
    """Factory for a scheduler which linearly decays the learning rate to zero
    over the course of training."""

    def __init__(self, sampling_config: RLSamplingConfig):
        self.sampling_config = sampling_config

    def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
        # One gradient update is performed per collect step, so projecting
        # this over all epochs yields the total number of updates over which
        # the learning rate is decayed.
        max_update_num = (
            np.ceil(self.sampling_config.step_per_epoch / self.sampling_config.step_per_collect)
            * self.sampling_config.num_epochs
        )
        # LambdaLR scales the initial lr by the returned factor, which here
        # falls linearly from 1 to 0 as the update count approaches max_update_num.
        return LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
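
A minimal usage sketch (not part of the file above): it assumes RLSamplingConfig
accepts the three fields referenced by the factory as keyword arguments; the
values are illustrative, not defaults. With these numbers,
max_update_num = ceil(30000 / 2000) * 100 = 1500, so each scheduler step
shrinks the learning rate by lr0 / 1500.

import torch

from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear

# Assumed constructor signature; RLSamplingConfig definitely exposes these
# fields, since the factory reads them above.
config = RLSamplingConfig(num_epochs=100, step_per_epoch=30000, step_per_collect=2000)

model = torch.nn.Linear(4, 2)  # stand-in for a policy network
optim = torch.optim.Adam(model.parameters(), lr=3e-4)
scheduler = LRSchedulerFactoryLinear(config).create_scheduler(optim)

for _ in range(5):
    optim.step()      # one gradient update
    scheduler.step()  # lr decays linearly toward zero
    print(scheduler.get_last_lr())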