- [X] I have marked all applicable categories:
    + [X] exception-raising fix
    + [ ] algorithm implementation fix
    + [ ] documentation modification
    + [ ] new feature
- [X] I have reformatted the code using `make format` (**required**)
- [X] I have checked the code using `make commit-checks` (**required**)
- [ ] If applicable, I have mentioned the relevant/related issue(s)
- [ ] If applicable, I have listed every item in this Pull Request below

The cause was the use of a lambda function in the state of a generated object.
35 lines
1.1 KiB
Python
35 lines
1.1 KiB
Python
from abc import ABC, abstractmethod
|
|
|
|
import numpy as np
|
|
import torch
|
|
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
|
|
|
|
from tianshou.highlevel.config import SamplingConfig
|
|
from tianshou.utils.string import ToStringMixin
|
|
|
|
|
|
class LRSchedulerFactory(ToStringMixin, ABC):
    """Abstract base for factories that build a learning rate scheduler for an optimizer."""

    @abstractmethod
    def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
        """Build the learning rate scheduler for the given optimizer.

        Concrete factories must override this method.

        :param optim: the optimizer the scheduler is created for
        :return: the scheduler instance
        """
|
|
|
|
|
|
class LRSchedulerFactoryLinear(LRSchedulerFactory):
    """Factory for a `LambdaLR` scheduler whose factor decreases linearly.

    The factor handed to `LambdaLR` starts at 1.0 and drops by
    ``1 / max_update_num`` per scheduler step, where ``max_update_num`` is
    derived from the sampling configuration (see `_LRLambda`).
    """

    def __init__(self, sampling_config: SamplingConfig):
        """
        :param sampling_config: the sampling configuration from which the
            total number of updates is derived
        """
        self.sampling_config = sampling_config

    def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
        """Create a `LambdaLR` scheduler for the given optimizer.

        :param optim: the optimizer whose learning rate shall be scheduled
        :return: the scheduler
        """
        # A bound method of a plain helper object is passed rather than a
        # lambda -- presumably so that the scheduler holds no lambda function
        # in its state (NOTE(review): confirm this is the pickling motivation).
        factor_fn = self._LRLambda(self.sampling_config)
        return LambdaLR(optim, lr_lambda=factor_fn.compute)

    class _LRLambda:
        """Callable helper computing the linearly decreasing factor."""

        def __init__(self, sampling_config: SamplingConfig):
            """
            :param sampling_config: the sampling configuration providing
                ``step_per_epoch``, ``step_per_collect`` and ``num_epochs``
            """
            # number of collect steps needed to cover one epoch, rounded up
            updates_per_epoch = np.ceil(
                sampling_config.step_per_epoch / sampling_config.step_per_collect,
            )
            self.max_update_num = updates_per_epoch * sampling_config.num_epochs

        def compute(self, epoch: int) -> float:
            """Return the factor for the given scheduler step.

            :param epoch: the step counter passed in by `LambdaLR`
            :return: ``1.0 - epoch / max_update_num``
            """
            fraction_done = epoch / self.max_update_num
            return 1.0 - fraction_done
|