diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index a19949a..782d8cd 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -188,3 +188,4 @@ MLP backpropagation dataclass superset +picklable diff --git a/test/offline/test_cql.py b/test/offline/test_cql.py index 7c82668..1a7c86b 100644 --- a/test/offline/test_cql.py +++ b/test/offline/test_cql.py @@ -27,7 +27,7 @@ def get_args(): parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Pendulum-v1") parser.add_argument("--reward-threshold", type=float, default=None) - parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--seed", type=int, default=1) parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) parser.add_argument("--actor-lr", type=float, default=1e-3) parser.add_argument("--critic-lr", type=float, default=1e-3) diff --git a/tianshou/highlevel/params/lr_scheduler.py b/tianshou/highlevel/params/lr_scheduler.py index 286df4c..0b0cf35 100644 --- a/tianshou/highlevel/params/lr_scheduler.py +++ b/tianshou/highlevel/params/lr_scheduler.py @@ -21,8 +21,14 @@ class LRSchedulerFactoryLinear(LRSchedulerFactory): self.sampling_config = sampling_config def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler: - max_update_num = ( - np.ceil(self.sampling_config.step_per_epoch / self.sampling_config.step_per_collect) - * self.sampling_config.num_epochs - ) - return LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) + return LambdaLR(optim, lr_lambda=self._LRLambda(self.sampling_config).compute) + + class _LRLambda: + def __init__(self, sampling_config: SamplingConfig): + self.max_update_num = ( + np.ceil(sampling_config.step_per_epoch / sampling_config.step_per_collect) + * sampling_config.num_epochs + ) + + def compute(self, epoch: int) -> float: + return 1.0 - epoch / self.max_update_num diff --git a/tianshou/highlevel/persistence.py b/tianshou/highlevel/persistence.py index 686a26c..4fc4f9c 100644 --- a/tianshou/highlevel/persistence.py +++ b/tianshou/highlevel/persistence.py @@ -61,12 +61,15 @@ class PolicyPersistence: class Mode(Enum): POLICY_STATE_DICT = "policy_state_dict" - """Persist only the policy's state dictionary""" + """Persist only the policy's state dictionary. Note that for a policy to be restored from + such a dictionary, it is necessary to first create a structurally equivalent object which can + accept the respective state.""" POLICY = "policy" """Persist the entire policy. This is larger but has the advantage of the policy being loadable without requiring an environment to be instantiated. It has the potential disadvantage that upon breaking code changes in the policy implementation (e.g. renamed/moved class), it will no longer be loadable. + Note that a precondition is that the policy be picklable in its entirety. """ def get_filename(self) -> str: