Fix an issue where policies built with LRSchedulerFactoryLinear were not picklable (#992)
- [X] I have marked all applicable categories:
  + [X] exception-raising fix
  + [ ] algorithm implementation fix
  + [ ] documentation modification
  + [ ] new feature
- [X] I have reformatted the code using `make format` (**required**)
- [X] I have checked the code using `make commit-checks` (**required**)
- [ ] If applicable, I have mentioned the relevant/related issue(s)
- [ ] If applicable, I have listed every item in this Pull Request below

The cause was the use of a lambda function in the state of a generated object.
This commit is contained in:
parent
962c6d1e11
commit
6d6c85e594
@ -188,3 +188,4 @@ MLP
|
||||
backpropagation
|
||||
dataclass
|
||||
superset
|
||||
picklable
|
||||
|
@ -27,7 +27,7 @@ def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--task", type=str, default="Pendulum-v1")
|
||||
parser.add_argument("--reward-threshold", type=float, default=None)
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--seed", type=int, default=1)
|
||||
parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64])
|
||||
parser.add_argument("--actor-lr", type=float, default=1e-3)
|
||||
parser.add_argument("--critic-lr", type=float, default=1e-3)
|
||||
|
@ -21,8 +21,14 @@ class LRSchedulerFactoryLinear(LRSchedulerFactory):
|
||||
self.sampling_config = sampling_config
|
||||
|
||||
def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
    """Create a linearly decaying learning-rate scheduler for the given optimizer.

    :param optim: the optimizer whose learning rate shall be scheduled.
    :return: a ``LambdaLR`` scheduler that decays the LR linearly from its
        initial value to zero over the course of training.

    The decay function is provided by the picklable helper class ``_LRLambda``
    rather than an inline ``lambda``: a lambda stored in the scheduler's state
    made the resulting policy objects unpicklable (see issue #992).
    """
    return LambdaLR(optim, lr_lambda=self._LRLambda(self.sampling_config).compute)
|
||||
|
||||
class _LRLambda:
|
||||
def __init__(self, sampling_config: SamplingConfig):
|
||||
self.max_update_num = (
|
||||
np.ceil(sampling_config.step_per_epoch / sampling_config.step_per_collect)
|
||||
* sampling_config.num_epochs
|
||||
)
|
||||
|
||||
def compute(self, epoch: int) -> float:
|
||||
return 1.0 - epoch / self.max_update_num
|
||||
|
@ -61,12 +61,15 @@ class PolicyPersistence:
|
||||
|
||||
class Mode(Enum):
    """The persistence mode, which controls how much of a policy is saved."""

    POLICY_STATE_DICT = "policy_state_dict"
    """Persist only the policy's state dictionary. Note that for a policy to be restored from
    such a dictionary, it is necessary to first create a structurally equivalent object which can
    accept the respective state."""
    POLICY = "policy"
    """Persist the entire policy. This is larger but has the advantage of the policy being loadable
    without requiring an environment to be instantiated.
    It has the potential disadvantage that upon breaking code changes in the policy implementation
    (e.g. renamed/moved class), it will no longer be loadable.
    Note that a precondition is that the policy be picklable in its entirety.
    """
|
||||
|
||||
def get_filename(self) -> str:
|
||||
|
Loading…
x
Reference in New Issue
Block a user