From 6d6c85e594d7fae15168f5b53b5889ef938c9378 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 14 Nov 2023 19:23:18 +0100 Subject: [PATCH] Fix an issue where policies built with LRSchedulerFactoryLinear were not picklable (#992) - [X] I have marked all applicable categories: + [X] exception-raising fix + [ ] algorithm implementation fix + [ ] documentation modification + [ ] new feature - [X] I have reformatted the code using `make format` (**required**) - [X] I have checked the code using `make commit-checks` (**required**) - [ ] If applicable, I have mentioned the relevant/related issue(s) - [ ] If applicable, I have listed every items in this Pull Request below The cause was the use of a lambda function in the state of a generated object. --- docs/spelling_wordlist.txt | 1 + test/offline/test_cql.py | 2 +- tianshou/highlevel/params/lr_scheduler.py | 16 +++++++++++----- tianshou/highlevel/persistence.py | 5 ++++- 4 files changed, 17 insertions(+), 7 deletions(-) diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index a19949a..782d8cd 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -188,3 +188,4 @@ MLP backpropagation dataclass superset +picklable diff --git a/test/offline/test_cql.py b/test/offline/test_cql.py index 7c82668..1a7c86b 100644 --- a/test/offline/test_cql.py +++ b/test/offline/test_cql.py @@ -27,7 +27,7 @@ def get_args(): parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Pendulum-v1") parser.add_argument("--reward-threshold", type=float, default=None) - parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--seed", type=int, default=1) parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) parser.add_argument("--actor-lr", type=float, default=1e-3) parser.add_argument("--critic-lr", type=float, default=1e-3) diff --git a/tianshou/highlevel/params/lr_scheduler.py b/tianshou/highlevel/params/lr_scheduler.py index 286df4c..0b0cf35 100644 --- a/tianshou/highlevel/params/lr_scheduler.py +++ b/tianshou/highlevel/params/lr_scheduler.py @@ -21,8 +21,14 @@ class LRSchedulerFactoryLinear(LRSchedulerFactory): self.sampling_config = sampling_config def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler: - max_update_num = ( - np.ceil(self.sampling_config.step_per_epoch / self.sampling_config.step_per_collect) - * self.sampling_config.num_epochs - ) - return LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) + return LambdaLR(optim, lr_lambda=self._LRLambda(self.sampling_config).compute) + + class _LRLambda: + def __init__(self, sampling_config: SamplingConfig): + self.max_update_num = ( + np.ceil(sampling_config.step_per_epoch / sampling_config.step_per_collect) + * sampling_config.num_epochs + ) + + def compute(self, epoch: int) -> float: + return 1.0 - epoch / self.max_update_num diff --git a/tianshou/highlevel/persistence.py b/tianshou/highlevel/persistence.py index 686a26c..4fc4f9c 100644 --- a/tianshou/highlevel/persistence.py +++ b/tianshou/highlevel/persistence.py @@ -61,12 +61,15 @@ class PolicyPersistence: class Mode(Enum): POLICY_STATE_DICT = "policy_state_dict" - """Persist only the policy's state dictionary""" + """Persist only the policy's state dictionary. Note that for a policy to be restored from + such a dictionary, it is necessary to first create a structurally equivalent object which can + accept the respective state.""" POLICY = "policy" """Persist the entire policy. This is larger but has the advantage of the policy being loadable without requiring an environment to be instantiated. It has the potential disadvantage that upon breaking code changes in the policy implementation (e.g. renamed/moved class), it will no longer be loadable. + Note that a precondition is that the policy be picklable in its entirety. """ def get_filename(self) -> str: