Tianshou/tianshou/highlevel/params/lr_scheduler.py
Dominik Jain 6d6c85e594
Fix an issue where policies built with LRSchedulerFactoryLinear were not picklable (#992)
- [X] I have marked all applicable categories:
    + [X] exception-raising fix
    + [ ] algorithm implementation fix
    + [ ] documentation modification
    + [ ] new feature
- [X] I have reformatted the code using `make format` (**required**)
- [X] I have checked the code using `make commit-checks` (**required**)
- [ ] If applicable, I have mentioned the relevant/related issue(s)
- [ ] If applicable, I have listed every item in this Pull Request below

The cause was a lambda function stored in the state of a generated object,
which cannot be pickled.
2023-11-14 10:23:18 -08:00
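
The commit message attributes the failure to a lambda held in the state of a generated
object. A minimal sketch of the underlying Python behaviour (the class names LinearDecay
and Holder are hypothetical, not part of Tianshou): an object whose state contains a
lambda cannot be pickled, whereas the bound method of a named, module-level class can.

import pickle


class LinearDecay:
    """Hypothetical named callable; its bound method can live in picklable state."""

    def __init__(self, max_update_num: float):
        self.max_update_num = max_update_num

    def compute(self, epoch: int) -> float:
        return 1.0 - epoch / self.max_update_num


class Holder:
    """Hypothetical stand-in for an object that keeps its lr_lambda in its state."""

    def __init__(self, lr_lambda):
        self.lr_lambda = lr_lambda


# A lambda captured in object state cannot be pickled ...
try:
    pickle.dumps(Holder(lambda epoch: 1.0 - epoch / 1000.0))
except (pickle.PicklingError, AttributeError) as exc:
    print("lambda in state:", exc)

# ... whereas the bound method of a module-level class round-trips fine.
restored = pickle.loads(pickle.dumps(Holder(LinearDecay(1000.0).compute)))
assert restored.lr_lambda(0) == 1.0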


from abc import ABC, abstractmethod

import numpy as np
import torch
from torch.optim.lr_scheduler import LambdaLR, LRScheduler

from tianshou.highlevel.config import SamplingConfig
from tianshou.utils.string import ToStringMixin


class LRSchedulerFactory(ToStringMixin, ABC):
    """Factory for the creation of a learning rate scheduler."""

    @abstractmethod
    def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
        pass


class LRSchedulerFactoryLinear(LRSchedulerFactory):
    """Factory for a scheduler which linearly decays the learning rate towards zero over training."""

    def __init__(self, sampling_config: SamplingConfig):
        self.sampling_config = sampling_config

    def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
        return LambdaLR(optim, lr_lambda=self._LRLambda(self.sampling_config).compute)

    class _LRLambda:
        """Picklable callable used in place of a lambda; storing a lambda in the
        scheduler's state is what prevented pickling (see commit message above).
        """

        def __init__(self, sampling_config: SamplingConfig):
            self.max_update_num = (
                np.ceil(sampling_config.step_per_epoch / sampling_config.step_per_collect)
                * sampling_config.num_epochs
            )

        def compute(self, epoch: int) -> float:
            return 1.0 - epoch / self.max_update_num
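
A usage sketch of the factory (assumption: the SamplingConfig fields not set below have
defaults; only num_epochs, step_per_epoch and step_per_collect are read by _LRLambda).
Because lr_lambda is a bound method of _LRLambda rather than a lambda, the resulting
scheduler, and any policy holding it, survives pickling, which is what the commit fixes.

import pickle

import torch

from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear

# Illustrative values; only the three fields used by _LRLambda are set explicitly.
sampling_config = SamplingConfig(num_epochs=100, step_per_epoch=30000, step_per_collect=2000)

model = torch.nn.Linear(4, 2)
optim = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = LRSchedulerFactoryLinear(sampling_config).create_scheduler(optim)

# The scheduler round-trips through pickle now that no lambda is stored in its state.
pickle.loads(pickle.dumps(scheduler))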