Allow configuring the policy persistence mode, adding a new mode
which stores the entire policy (new default), supporting applications where it is desired to be able to load the policy without having to instantiate an environment or recreate a corresponding policy object
This commit is contained in:
parent
86cca8ffc3
commit
a3dbe90515
@ -108,6 +108,8 @@ class ExperimentConfig:
|
||||
in this directory based on the run's experiment name"""
|
||||
persistence_enabled: bool = True
|
||||
"""Whether persistence is enabled, allowing files to be stored"""
|
||||
policy_persistence_mode: PolicyPersistence.Mode = PolicyPersistence.Mode.POLICY
|
||||
"""Controls the way in which the policy is persisted"""
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -220,7 +222,11 @@ class Experiment(ToStringMixin):
|
||||
|
||||
# initialize persistence
|
||||
additional_persistence = PersistenceGroup(*envs.persistence, enabled=use_persistence)
|
||||
policy_persistence = PolicyPersistence(additional_persistence, enabled=use_persistence)
|
||||
policy_persistence = PolicyPersistence(
|
||||
additional_persistence,
|
||||
enabled=use_persistence,
|
||||
mode=self.config.policy_persistence_mode,
|
||||
)
|
||||
if use_persistence:
|
||||
log.info(f"Persistence directory: {os.path.abspath(persistence_dir)}")
|
||||
self.save(persistence_dir)
|
||||
|
@ -57,29 +57,63 @@ class PersistenceGroup(Persistence):
|
||||
|
||||
|
||||
class PolicyPersistence:
|
||||
FILENAME = "policy.dat"
|
||||
"""Handles persistence of the policy."""
|
||||
|
||||
def __init__(self, additional_persistence: Persistence | None = None, enabled: bool = True):
|
||||
class Mode(Enum):
|
||||
POLICY_STATE_DICT = "policy_state_dict"
|
||||
"""Persist only the policy's state dictionary"""
|
||||
POLICY = "policy"
|
||||
"""Persist the entire policy. This is larger but has the advantage of the policy being loadable
|
||||
without requiring an environment to be instantiated.
|
||||
It has the potential disadvantage that upon breaking code changes in the policy implementation
|
||||
(e.g. renamed/moved class), it will no longer be loadable.
|
||||
"""
|
||||
|
||||
def get_filename(self) -> str:
|
||||
return self.value + ".pt"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
additional_persistence: Persistence | None = None,
|
||||
enabled: bool = True,
|
||||
mode: Mode = Mode.POLICY,
|
||||
):
|
||||
""":param additional_persistence: a persistence instance which is to be invoked whenever
|
||||
this object is used to persist/restore data
|
||||
:param enabled: whether persistence is enabled (restoration is always enabled)
|
||||
:param mode: the persistence mode
|
||||
"""
|
||||
self.additional_persistence = additional_persistence
|
||||
self.enabled = enabled
|
||||
self.mode = mode
|
||||
|
||||
def persist(self, policy: torch.nn.Module, world: World) -> None:
|
||||
if not self.enabled:
|
||||
return
|
||||
path = world.persist_path(self.FILENAME)
|
||||
log.info(f"Saving policy in {path}")
|
||||
torch.save(policy.state_dict(), path)
|
||||
path = world.persist_path(self.mode.get_filename())
|
||||
match self.mode:
|
||||
case self.Mode.POLICY_STATE_DICT:
|
||||
log.info(f"Saving policy state dictionary in {path}")
|
||||
torch.save(policy.state_dict(), path)
|
||||
case self.Mode.POLICY:
|
||||
log.info(f"Saving policy object in {path}")
|
||||
torch.save(policy, path)
|
||||
case _:
|
||||
raise NotImplementedError
|
||||
if self.additional_persistence is not None:
|
||||
self.additional_persistence.persist(PersistEvent.PERSIST_POLICY, world)
|
||||
|
||||
def restore(self, policy: torch.nn.Module, world: World, device: "TDevice") -> None:
|
||||
path = world.restore_path(self.FILENAME)
|
||||
path = world.restore_path(self.mode.get_filename())
|
||||
log.info(f"Restoring policy from {path}")
|
||||
state_dict = torch.load(path, map_location=device)
|
||||
match self.mode:
|
||||
case self.Mode.POLICY_STATE_DICT:
|
||||
state_dict = torch.load(path, map_location=device)
|
||||
case self.Mode.POLICY:
|
||||
loaded_policy: torch.nn.Module = torch.load(path, map_location=device)
|
||||
state_dict = loaded_policy.state_dict()
|
||||
case _:
|
||||
raise NotImplementedError
|
||||
policy.load_state_dict(state_dict)
|
||||
if self.additional_persistence is not None:
|
||||
self.additional_persistence.restore(RestoreEvent.RESTORE_POLICY, world)
|
||||
|
Loading…
x
Reference in New Issue
Block a user