Allow configuring the policy persistence mode, adding a new mode
which stores the entire policy (new default), supporting applications where it is desired to be able to load the policy without having to instantiate an environment or recreate a corresponding policy object
This commit is contained in:
parent
86cca8ffc3
commit
a3dbe90515
@ -108,6 +108,8 @@ class ExperimentConfig:
|
|||||||
in this directory based on the run's experiment name"""
|
in this directory based on the run's experiment name"""
|
||||||
persistence_enabled: bool = True
|
persistence_enabled: bool = True
|
||||||
"""Whether persistence is enabled, allowing files to be stored"""
|
"""Whether persistence is enabled, allowing files to be stored"""
|
||||||
|
policy_persistence_mode: PolicyPersistence.Mode = PolicyPersistence.Mode.POLICY
|
||||||
|
"""Controls the way in which the policy is persisted"""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -220,7 +222,11 @@ class Experiment(ToStringMixin):
|
|||||||
|
|
||||||
# initialize persistence
|
# initialize persistence
|
||||||
additional_persistence = PersistenceGroup(*envs.persistence, enabled=use_persistence)
|
additional_persistence = PersistenceGroup(*envs.persistence, enabled=use_persistence)
|
||||||
policy_persistence = PolicyPersistence(additional_persistence, enabled=use_persistence)
|
policy_persistence = PolicyPersistence(
|
||||||
|
additional_persistence,
|
||||||
|
enabled=use_persistence,
|
||||||
|
mode=self.config.policy_persistence_mode,
|
||||||
|
)
|
||||||
if use_persistence:
|
if use_persistence:
|
||||||
log.info(f"Persistence directory: {os.path.abspath(persistence_dir)}")
|
log.info(f"Persistence directory: {os.path.abspath(persistence_dir)}")
|
||||||
self.save(persistence_dir)
|
self.save(persistence_dir)
|
||||||
|
@ -57,29 +57,63 @@ class PersistenceGroup(Persistence):
|
|||||||
|
|
||||||
|
|
||||||
class PolicyPersistence:
|
class PolicyPersistence:
|
||||||
FILENAME = "policy.dat"
|
"""Handles persistence of the policy."""
|
||||||
|
|
||||||
def __init__(self, additional_persistence: Persistence | None = None, enabled: bool = True):
|
class Mode(Enum):
|
||||||
|
POLICY_STATE_DICT = "policy_state_dict"
|
||||||
|
"""Persist only the policy's state dictionary"""
|
||||||
|
POLICY = "policy"
|
||||||
|
"""Persist the entire policy. This is larger but has the advantage of the policy being loadable
|
||||||
|
without requiring an environment to be instantiated.
|
||||||
|
It has the potential disadvantage that upon breaking code changes in the policy implementation
|
||||||
|
(e.g. renamed/moved class), it will no longer be loadable.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_filename(self) -> str:
|
||||||
|
return self.value + ".pt"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
additional_persistence: Persistence | None = None,
|
||||||
|
enabled: bool = True,
|
||||||
|
mode: Mode = Mode.POLICY,
|
||||||
|
):
|
||||||
""":param additional_persistence: a persistence instance which is to be invoked whenever
|
""":param additional_persistence: a persistence instance which is to be invoked whenever
|
||||||
this object is used to persist/restore data
|
this object is used to persist/restore data
|
||||||
:param enabled: whether persistence is enabled (restoration is always enabled)
|
:param enabled: whether persistence is enabled (restoration is always enabled)
|
||||||
|
:param mode: the persistence mode
|
||||||
"""
|
"""
|
||||||
self.additional_persistence = additional_persistence
|
self.additional_persistence = additional_persistence
|
||||||
self.enabled = enabled
|
self.enabled = enabled
|
||||||
|
self.mode = mode
|
||||||
|
|
||||||
def persist(self, policy: torch.nn.Module, world: World) -> None:
|
def persist(self, policy: torch.nn.Module, world: World) -> None:
|
||||||
if not self.enabled:
|
if not self.enabled:
|
||||||
return
|
return
|
||||||
path = world.persist_path(self.FILENAME)
|
path = world.persist_path(self.mode.get_filename())
|
||||||
log.info(f"Saving policy in {path}")
|
match self.mode:
|
||||||
torch.save(policy.state_dict(), path)
|
case self.Mode.POLICY_STATE_DICT:
|
||||||
|
log.info(f"Saving policy state dictionary in {path}")
|
||||||
|
torch.save(policy.state_dict(), path)
|
||||||
|
case self.Mode.POLICY:
|
||||||
|
log.info(f"Saving policy object in {path}")
|
||||||
|
torch.save(policy, path)
|
||||||
|
case _:
|
||||||
|
raise NotImplementedError
|
||||||
if self.additional_persistence is not None:
|
if self.additional_persistence is not None:
|
||||||
self.additional_persistence.persist(PersistEvent.PERSIST_POLICY, world)
|
self.additional_persistence.persist(PersistEvent.PERSIST_POLICY, world)
|
||||||
|
|
||||||
def restore(self, policy: torch.nn.Module, world: World, device: "TDevice") -> None:
|
def restore(self, policy: torch.nn.Module, world: World, device: "TDevice") -> None:
|
||||||
path = world.restore_path(self.FILENAME)
|
path = world.restore_path(self.mode.get_filename())
|
||||||
log.info(f"Restoring policy from {path}")
|
log.info(f"Restoring policy from {path}")
|
||||||
state_dict = torch.load(path, map_location=device)
|
match self.mode:
|
||||||
|
case self.Mode.POLICY_STATE_DICT:
|
||||||
|
state_dict = torch.load(path, map_location=device)
|
||||||
|
case self.Mode.POLICY:
|
||||||
|
loaded_policy: torch.nn.Module = torch.load(path, map_location=device)
|
||||||
|
state_dict = loaded_policy.state_dict()
|
||||||
|
case _:
|
||||||
|
raise NotImplementedError
|
||||||
policy.load_state_dict(state_dict)
|
policy.load_state_dict(state_dict)
|
||||||
if self.additional_persistence is not None:
|
if self.additional_persistence is not None:
|
||||||
self.additional_persistence.restore(RestoreEvent.RESTORE_POLICY, world)
|
self.additional_persistence.restore(RestoreEvent.RESTORE_POLICY, world)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user