diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py
index 70c7be7..0ef81a3 100644
--- a/tianshou/highlevel/experiment.py
+++ b/tianshou/highlevel/experiment.py
@@ -108,6 +108,8 @@ class ExperimentConfig:
     in this directory based on the run's experiment name"""
     persistence_enabled: bool = True
     """Whether persistence is enabled, allowing files to be stored"""
+    policy_persistence_mode: PolicyPersistence.Mode = PolicyPersistence.Mode.POLICY
+    """Controls the way in which the policy is persisted"""
 
 
 @dataclass
@@ -220,7 +222,11 @@ class Experiment(ToStringMixin):
 
         # initialize persistence
         additional_persistence = PersistenceGroup(*envs.persistence, enabled=use_persistence)
-        policy_persistence = PolicyPersistence(additional_persistence, enabled=use_persistence)
+        policy_persistence = PolicyPersistence(
+            additional_persistence,
+            enabled=use_persistence,
+            mode=self.config.policy_persistence_mode,
+        )
         if use_persistence:
             log.info(f"Persistence directory: {os.path.abspath(persistence_dir)}")
             self.save(persistence_dir)
diff --git a/tianshou/highlevel/persistence.py b/tianshou/highlevel/persistence.py
index 951ca08..686a26c 100644
--- a/tianshou/highlevel/persistence.py
+++ b/tianshou/highlevel/persistence.py
@@ -57,29 +57,63 @@ class PersistenceGroup(Persistence):
 
 
 class PolicyPersistence:
-    FILENAME = "policy.dat"
+    """Handles persistence of the policy."""
 
-    def __init__(self, additional_persistence: Persistence | None = None, enabled: bool = True):
+    class Mode(Enum):
+        POLICY_STATE_DICT = "policy_state_dict"
+        """Persist only the policy's state dictionary"""
+        POLICY = "policy"
+        """Persist the entire policy. This is larger but has the advantage of the policy being loadable
+        without requiring an environment to be instantiated.
+        It has the potential disadvantage that upon breaking code changes in the policy implementation
+        (e.g. renamed/moved class), it will no longer be loadable.
+        """
+
+        def get_filename(self) -> str:
+            return self.value + ".pt"
+
+    def __init__(
+        self,
+        additional_persistence: Persistence | None = None,
+        enabled: bool = True,
+        mode: Mode = Mode.POLICY,
+    ):
         """:param additional_persistence: a persistence instance which is to be invoked whenever
             this object is used to persist/restore data
         :param enabled: whether persistence is enabled (restoration is always enabled)
+        :param mode: the persistence mode
         """
         self.additional_persistence = additional_persistence
         self.enabled = enabled
+        self.mode = mode
 
     def persist(self, policy: torch.nn.Module, world: World) -> None:
         if not self.enabled:
             return
-        path = world.persist_path(self.FILENAME)
-        log.info(f"Saving policy in {path}")
-        torch.save(policy.state_dict(), path)
+        path = world.persist_path(self.mode.get_filename())
+        match self.mode:
+            case self.Mode.POLICY_STATE_DICT:
+                log.info(f"Saving policy state dictionary in {path}")
+                torch.save(policy.state_dict(), path)
+            case self.Mode.POLICY:
+                log.info(f"Saving policy object in {path}")
+                torch.save(policy, path)
+            case _:
+                raise NotImplementedError
         if self.additional_persistence is not None:
             self.additional_persistence.persist(PersistEvent.PERSIST_POLICY, world)
 
     def restore(self, policy: torch.nn.Module, world: World, device: "TDevice") -> None:
-        path = world.restore_path(self.FILENAME)
+        path = world.restore_path(self.mode.get_filename())
         log.info(f"Restoring policy from {path}")
-        state_dict = torch.load(path, map_location=device)
+        match self.mode:
+            case self.Mode.POLICY_STATE_DICT:
+                state_dict = torch.load(path, map_location=device)
+            case self.Mode.POLICY:
+                loaded_policy: torch.nn.Module = torch.load(path, map_location=device)
+                state_dict = loaded_policy.state_dict()
+            case _:
+                raise NotImplementedError
         policy.load_state_dict(state_dict)
         if self.additional_persistence is not None:
             self.additional_persistence.restore(RestoreEvent.RESTORE_POLICY, world)