Allow configuring the policy persistence mode, adding a new mode

which stores the entire policy (new default), supporting applications
where it is desired to be able to load the policy without having
to instantiate an environment or recreate a corresponding policy
object.
This commit is contained in:
Dominik Jain 2023-10-26 13:19:33 +02:00
parent 86cca8ffc3
commit a3dbe90515
2 changed files with 48 additions and 8 deletions

View File

@ -108,6 +108,8 @@ class ExperimentConfig:
in this directory based on the run's experiment name"""
persistence_enabled: bool = True
"""Whether persistence is enabled, allowing files to be stored"""
policy_persistence_mode: PolicyPersistence.Mode = PolicyPersistence.Mode.POLICY
"""Controls the way in which the policy is persisted"""
@dataclass
@ -220,7 +222,11 @@ class Experiment(ToStringMixin):
# initialize persistence
additional_persistence = PersistenceGroup(*envs.persistence, enabled=use_persistence)
policy_persistence = PolicyPersistence(additional_persistence, enabled=use_persistence)
policy_persistence = PolicyPersistence(
additional_persistence,
enabled=use_persistence,
mode=self.config.policy_persistence_mode,
)
if use_persistence:
log.info(f"Persistence directory: {os.path.abspath(persistence_dir)}")
self.save(persistence_dir)

View File

@ -57,29 +57,63 @@ class PersistenceGroup(Persistence):
# NOTE(review): this span looks like a flattened diff hunk -- indentation is lost and
# pre-change and post-change lines appear interleaved (e.g. two __init__ signatures,
# two `path = ...` assignments per method). Confirm against the real file before editing.
class PolicyPersistence:
# File name used by the FILENAME-based persist_path/restore_path lines further down.
FILENAME = "policy.dat"
"""Handles persistence of the policy."""
# NOTE(review): presumably the pre-change constructor signature (no `mode`) -- confirm.
def __init__(self, additional_persistence: Persistence | None = None, enabled: bool = True):
# Enumerates how the policy is written to disk; each member's string value is also
# the stem of the file name produced by get_filename().
class Mode(Enum):
POLICY_STATE_DICT = "policy_state_dict"
"""Persist only the policy's state dictionary"""
POLICY = "policy"
"""Persist the entire policy. This is larger but has the advantage of the policy being loadable
without requiring an environment to be instantiated.
It has the potential disadvantage that upon breaking code changes in the policy implementation
(e.g. renamed/moved class), it will no longer be loadable.
"""
# Returns the member's value with ".pt" appended, e.g. "policy.pt" for Mode.POLICY.
def get_filename(self) -> str:
return self.value + ".pt"
# Constructor taking the persistence mode; it only stores its three arguments.
def __init__(
self,
additional_persistence: Persistence | None = None,
enabled: bool = True,
mode: Mode = Mode.POLICY,
):
""":param additional_persistence: a persistence instance which is to be invoked whenever
this object is used to persist/restore data
:param enabled: whether persistence is enabled (restoration is always enabled)
:param mode: the persistence mode
"""
self.additional_persistence = additional_persistence
self.enabled = enabled
self.mode = mode
# Saves `policy` to a path obtained from world.persist_path, then forwards a
# PERSIST_POLICY event to additional_persistence if one was supplied.
def persist(self, policy: torch.nn.Module, world: World) -> None:
# Nothing is written (and additional_persistence is not invoked) when disabled.
if not self.enabled:
return
# NOTE(review): the next three lines use FILENAME and always save the state dict;
# presumably they are the pre-change implementation -- confirm.
path = world.persist_path(self.FILENAME)
log.info(f"Saving policy in {path}")
torch.save(policy.state_dict(), path)
# Mode-based variant: the file name is derived from the selected mode.
path = world.persist_path(self.mode.get_filename())
# POLICY_STATE_DICT saves policy.state_dict(); POLICY saves the policy object itself;
# any other mode raises NotImplementedError.
match self.mode:
case self.Mode.POLICY_STATE_DICT:
log.info(f"Saving policy state dictionary in {path}")
torch.save(policy.state_dict(), path)
case self.Mode.POLICY:
log.info(f"Saving policy object in {path}")
torch.save(policy, path)
case _:
raise NotImplementedError
if self.additional_persistence is not None:
self.additional_persistence.persist(PersistEvent.PERSIST_POLICY, world)
# Loads persisted data from a path obtained from world.restore_path and applies it to the
# given `policy` via load_state_dict, then forwards a RESTORE_POLICY event to
# additional_persistence if one was supplied. There is no check of self.enabled here.
def restore(self, policy: torch.nn.Module, world: World, device: "TDevice") -> None:
# NOTE(review): FILENAME-based path; presumably the pre-change line -- confirm.
path = world.restore_path(self.FILENAME)
path = world.restore_path(self.mode.get_filename())
log.info(f"Restoring policy from {path}")
# NOTE(review): unconditional load; presumably the pre-change line -- confirm.
# NOTE(review): torch.load unpickles the file contents; presumably only trusted
# persistence directories are restored from -- confirm.
state_dict = torch.load(path, map_location=device)
match self.mode:
case self.Mode.POLICY_STATE_DICT:
state_dict = torch.load(path, map_location=device)
case self.Mode.POLICY:
# The stored policy object is loaded in full, but only its state dict is used:
# the caller's `policy` instance is updated in place rather than replaced.
loaded_policy: torch.nn.Module = torch.load(path, map_location=device)
state_dict = loaded_policy.state_dict()
case _:
raise NotImplementedError
policy.load_state_dict(state_dict)
if self.additional_persistence is not None:
self.additional_persistence.restore(RestoreEvent.RESTORE_POLICY, world)