Tianshou/tianshou/highlevel/persistence.py
Dominik Jain 3691ed2abc Support obs_rms persistence for MuJoCo by adding a general mechanism
for attaching persistence to Environments instances
2023-10-18 20:44:17 +02:00

86 lines
2.5 KiB
Python

import logging
import os
from abc import ABC, abstractmethod
from enum import Enum
from typing import TYPE_CHECKING, Protocol, Self, runtime_checkable, Callable
import torch
from tianshou.highlevel.world import World
if TYPE_CHECKING:
from tianshou.highlevel.module.core import TDevice
log = logging.getLogger(__name__)
@runtime_checkable
class PersistableConfigProtocol(Protocol):
@classmethod
def load(cls, path: os.PathLike[str]) -> Self:
pass
def save(self, path: os.PathLike[str]) -> None:
pass
class PersistEvent(Enum):
    """Enumeration of the events that are passed to ``Persistence.persist``."""

    PERSIST_POLICY = "persist_policy"
    """Policy neural network is persisted (new best found)"""
class RestoreEvent(Enum):
    """Enumeration of the events that are passed to ``Persistence.restore``."""

    RESTORE_POLICY = "restore_policy"
    """Policy neural network parameters are restored"""
class Persistence(ABC):
    """Abstract base class for handlers that react to persist and restore events.

    Subclasses must implement both :meth:`persist` and :meth:`restore`.
    """

    @abstractmethod
    def persist(self, event: PersistEvent, world: World) -> None:
        """Handle a persist event.

        :param event: the event that triggered persistence
        :param world: the world object handed through by the caller
        """

    @abstractmethod
    def restore(self, event: RestoreEvent, world: World) -> None:
        """Handle a restore event.

        :param event: the event that triggered restoration
        :param world: the world object handed through by the caller
        """
class PersistenceGroup(Persistence):
    """Groups several persistence handlers and forwards every event to each of them.

    Events are forwarded to the members in the order in which they were passed
    to the constructor.
    """

    def __init__(self, *p: Persistence) -> None:
        """
        :param p: the persistence handlers that make up the group
        """
        self.items = p

    def persist(self, event: PersistEvent, world: World) -> None:
        """Forward the persist event to every member of the group.

        :param event: the event that triggered persistence
        :param world: the world object passed on unchanged to each member
        """
        for member in self.items:
            member.persist(event, world)

    def restore(self, event: RestoreEvent, world: World) -> None:
        """Forward the restore event to every member of the group.

        :param event: the event that triggered restoration
        :param world: the world object passed on unchanged to each member
        """
        for member in self.items:
            member.restore(event, world)
class PolicyPersistence:
    """Saves and restores the ``state_dict`` of a policy module.

    An optional additional :class:`Persistence` is notified after the policy has
    been saved or restored.
    """

    # Name of the policy file; the full path is resolved through the world object.
    FILENAME = "policy.dat"

    def __init__(self, additional_persistence: Persistence | None = None) -> None:
        """
        :param additional_persistence: a persistence handler which is invoked with
            ``PersistEvent.PERSIST_POLICY`` after the policy was saved and with
            ``RestoreEvent.RESTORE_POLICY`` after it was restored; None to disable
        """
        self.additional_persistence = additional_persistence

    def persist(self, policy: torch.nn.Module, world: World) -> None:
        """Save the policy's state dict and then trigger the additional persistence, if any.

        :param policy: the module whose ``state_dict`` is saved
        :param world: provides the target path via ``persist_path`` and is passed
            on to the additional persistence
        """
        path = world.persist_path(self.FILENAME)
        log.info(f"Saving policy in {path}")
        torch.save(policy.state_dict(), path)
        if self.additional_persistence is not None:
            self.additional_persistence.persist(PersistEvent.PERSIST_POLICY, world)

    def restore(self, policy: torch.nn.Module, world: World, device: "TDevice") -> None:
        """Load a saved state dict into the policy and then trigger the additional restoration, if any.

        :param policy: the module which receives the loaded state dict
        :param world: provides the source path via ``restore_path`` and is passed
            on to the additional persistence
        :param device: passed to ``torch.load`` as ``map_location``
        """
        path = world.restore_path(self.FILENAME)
        log.info(f"Restoring policy from {path}")
        # NOTE(review): torch.load may unpickle arbitrary objects (depending on the
        # torch version's ``weights_only`` default), so only restore from trusted files.
        state_dict = torch.load(path, map_location=device)
        policy.load_state_dict(state_dict)
        if self.additional_persistence is not None:
            self.additional_persistence.restore(RestoreEvent.RESTORE_POLICY, world)

    def get_save_best_fn(self, world: World) -> Callable[[torch.nn.Module], None]:
        """Create a single-argument callback which persists the policy it is given.

        :param world: the world object that the callback passes to :meth:`persist`
        :return: a function taking a policy module and persisting it
        """

        def save_best_fn(pol: torch.nn.Module) -> None:
            self.persist(pol, world)

        return save_best_fn