Fix an issue where policies built with `LRSchedulerFactoryLinear` were not picklable (#992)

- [X] I have marked all applicable categories:
    + [X] exception-raising fix
    + [ ] algorithm implementation fix
    + [ ] documentation modification
    + [ ] new feature
- [X] I have reformatted the code using `make format` (**required**)
- [X] I have checked the code using `make commit-checks` (**required**)
- [ ] If applicable, I have mentioned the relevant/related issue(s)
- [ ] If applicable, I have listed all items in this Pull Request below

The cause was a lambda function stored in the state of a generated object: lambdas cannot be pickled, so any policy holding a reference to the generated scheduler became unpicklable.
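
For illustration, here is a minimal sketch of the failure mode and of the fix pattern; the class names are hypothetical, not tianshou's actual ones:

```python
import pickle


class LambdaHolder:
    """Old behaviour: a lambda stored in instance state (hypothetical name)."""

    def __init__(self, max_update_num: float):
        # The closure over max_update_num makes this function unpicklable.
        self.lr_lambda = lambda epoch: 1 - epoch / max_update_num


class CallableHolder:
    """The fix pattern: plain attributes plus a regular method (hypothetical name)."""

    def __init__(self, max_update_num: float):
        self.max_update_num = max_update_num

    def compute(self, epoch: int) -> float:
        return 1.0 - epoch / self.max_update_num


try:
    pickle.dumps(LambdaHolder(1000.0))
except (AttributeError, pickle.PicklingError) as exc:
    print("lambda version fails:", exc)

pickle.dumps(CallableHolder(1000.0))  # succeeds: no closures involved
print("callable-class version pickles")
```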
Dominik Jain, 2023-11-14 19:23:18 +01:00 (committed by GitHub)
commit 6d6c85e594, parent 962c6d1e11
4 changed files with 17 additions and 7 deletions

```diff
@@ -188,3 +188,4 @@ MLP
 backpropagation
 dataclass
 superset
+picklable
```

```diff
@@ -27,7 +27,7 @@ def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--task", type=str, default="Pendulum-v1")
     parser.add_argument("--reward-threshold", type=float, default=None)
-    parser.add_argument("--seed", type=int, default=0)
+    parser.add_argument("--seed", type=int, default=1)
     parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64])
     parser.add_argument("--actor-lr", type=float, default=1e-3)
     parser.add_argument("--critic-lr", type=float, default=1e-3)
```

```diff
@@ -21,8 +21,14 @@ class LRSchedulerFactoryLinear(LRSchedulerFactory):
         self.sampling_config = sampling_config
 
     def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
-        max_update_num = (
-            np.ceil(self.sampling_config.step_per_epoch / self.sampling_config.step_per_collect)
-            * self.sampling_config.num_epochs
-        )
-        return LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
+        return LambdaLR(optim, lr_lambda=self._LRLambda(self.sampling_config).compute)
+
+    class _LRLambda:
+        def __init__(self, sampling_config: SamplingConfig):
+            self.max_update_num = (
+                np.ceil(sampling_config.step_per_epoch / sampling_config.step_per_collect)
+                * sampling_config.num_epochs
+            )
+
+        def compute(self, epoch: int) -> float:
+            return 1.0 - epoch / self.max_update_num
```
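
As a quick sanity check of the pattern, here is a sketch using plain PyTorch, with a hypothetical `LinearDecay` class standing in for `_LRLambda`; the resulting scheduler now survives a pickle round trip:

```python
import pickle

import torch
from torch.optim.lr_scheduler import LambdaLR


class LinearDecay:
    """Stands in for _LRLambda: all state lives in picklable attributes."""

    def __init__(self, max_update_num: float):
        self.max_update_num = max_update_num

    def compute(self, epoch: int) -> float:
        return 1.0 - epoch / self.max_update_num


optim = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)
scheduler = LambdaLR(optim, lr_lambda=LinearDecay(1000.0).compute)

# With the old lambda-based lr_lambda this call raised a pickling error;
# a bound method of a module-level class pickles by reference.
restored = pickle.loads(pickle.dumps(scheduler))
print(restored.get_last_lr())
```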

```diff
@@ -61,12 +61,15 @@ class PolicyPersistence:
     class Mode(Enum):
         POLICY_STATE_DICT = "policy_state_dict"
-        """Persist only the policy's state dictionary"""
+        """Persist only the policy's state dictionary. Note that for a policy to be restored from
+        such a dictionary, it is necessary to first create a structurally equivalent object which can
+        accept the respective state."""
         POLICY = "policy"
         """Persist the entire policy. This is larger but has the advantage of the policy being loadable
         without requiring an environment to be instantiated.
         It has the potential disadvantage that upon breaking code changes in the policy implementation
         (e.g. renamed/moved class), it will no longer be loadable.
+        Note that a precondition is that the policy be picklable in its entirety.
         """
 
         def get_filename(self) -> str:
```
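
To illustrate the distinction between the two modes, here is a sketch using plain `torch.save`/`torch.load` rather than tianshou's actual `PolicyPersistence` API, with a hypothetical `TinyPolicy`:

```python
import torch


class TinyPolicy(torch.nn.Module):
    """Hypothetical stand-in for an actual policy network."""

    def __init__(self) -> None:
        super().__init__()
        self.net = torch.nn.Linear(4, 2)


policy = TinyPolicy()

# POLICY_STATE_DICT semantics: only parameters are stored, so restoring
# requires first building a structurally equivalent object.
torch.save(policy.state_dict(), "policy_state.pt")
restored = TinyPolicy()
restored.load_state_dict(torch.load("policy_state.pt"))

# POLICY semantics: the whole object is pickled, hence the policy must be
# picklable in its entirety -- exactly what the scheduler lambda broke.
torch.save(policy, "policy.pt")
restored_full = torch.load("policy.pt", weights_only=False)  # full unpickling
```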