From 3691ed2abce2864453b208c82829d23f59186425 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Thu, 12 Oct 2023 15:01:49 +0200 Subject: [PATCH] Support obs_rms persistence for MuJoCo by adding a general mechanism for attaching persistence to Environments instances --- examples/mujoco/mujoco_env.py | 33 +++++++++++++++++++++- tianshou/highlevel/env.py | 6 +++- tianshou/highlevel/experiment.py | 19 +++++++++---- tianshou/highlevel/persistence.py | 47 ++++++++++++++++++++----------- tianshou/highlevel/world.py | 11 ++++++-- 5 files changed, 90 insertions(+), 26 deletions(-) diff --git a/examples/mujoco/mujoco_env.py b/examples/mujoco/mujoco_env.py index 620d6dd..fe08529 100644 --- a/examples/mujoco/mujoco_env.py +++ b/examples/mujoco/mujoco_env.py @@ -1,16 +1,22 @@ +import pickle import warnings +import logging import gymnasium as gym from tianshou.env import ShmemVectorEnv, VectorEnvNormObs from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory +from tianshou.highlevel.persistence import Persistence, RestoreEvent, PersistEvent +from tianshou.highlevel.world import World try: import envpool except ImportError: envpool = None +log = logging.getLogger(__name__) + def make_mujoco_env(task: str, seed: int, num_train_envs: int, num_test_envs: int, obs_norm: bool): """Wrapper function for Mujoco env. @@ -40,6 +46,29 @@ def make_mujoco_env(task: str, seed: int, num_train_envs: int, num_test_envs: in return env, train_envs, test_envs +class MujocoEnvObsRmsPersistence(Persistence): + FILENAME = "env_obs_rms.pkl" + + def persist(self, event: PersistEvent, world: World) -> None: + if event != PersistEvent.PERSIST_POLICY: + return + obs_rms = world.envs.train_envs.get_obs_rms() + path = world.persist_path(self.FILENAME) + log.info(f"Saving environment obs_rms value to {path}") + with open(path, "wb") as f: + pickle.dump(obs_rms, f) + + def restore(self, event: RestoreEvent, world: World): + if event != RestoreEvent.RESTORE_POLICY: + return + path = world.restore_path(self.FILENAME) + log.info(f"Restoring environment obs_rms value from {path}") + with open(path, "rb") as f: + obs_rms = pickle.load(f) + world.envs.train_envs.set_obs_rms(obs_rms) + world.envs.test_envs.set_obs_rms(obs_rms) + + class MujocoEnvFactory(EnvFactory): def __init__(self, task: str, seed: int, sampling_config: SamplingConfig): self.task = task @@ -54,4 +83,6 @@ class MujocoEnvFactory(EnvFactory): num_test_envs=self.sampling_config.num_test_envs, obs_norm=True, ) - return ContinuousEnvironments(env=env, train_envs=train_envs, test_envs=test_envs) + envs = ContinuousEnvironments(env=env, train_envs=train_envs, test_envs=test_envs) + envs.set_persistence(MujocoEnvObsRmsPersistence()) + return envs diff --git a/tianshou/highlevel/env.py b/tianshou/highlevel/env.py index 8f22236..179f941 100644 --- a/tianshou/highlevel/env.py +++ b/tianshou/highlevel/env.py @@ -6,7 +6,7 @@ from typing import Any, TypeAlias import gymnasium as gym from tianshou.env import BaseVectorEnv -from tianshou.highlevel.persistence import PersistableConfigProtocol +from tianshou.highlevel.persistence import PersistableConfigProtocol, Persistence from tianshou.utils.net.common import TActionShape from tianshou.utils.string import ToStringMixin @@ -37,6 +37,7 @@ class Environments(ToStringMixin, ABC): self.env = env self.train_envs = train_envs self.test_envs = test_envs + self.persistence = [] def _tostring_includes(self) -> list[str]: return [] @@ -50,6 +51,9 @@ class Environments(ToStringMixin, ABC): "state_shape": self.get_observation_shape(), } + def set_persistence(self, *p: Persistence): + self.persistence = p + @abstractmethod def get_action_shape(self) -> TActionShape: pass diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index a3cd264..0c9a91d 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -1,3 +1,4 @@ +import os import logging from abc import abstractmethod from collections.abc import Callable, Sequence @@ -64,7 +65,7 @@ from tianshou.highlevel.params.policy_params import ( TRPOParams, ) from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory -from tianshou.highlevel.persistence import PersistableConfigProtocol, PolicyPersistence +from tianshou.highlevel.persistence import PersistableConfigProtocol, PolicyPersistence, PersistenceGroup from tianshou.highlevel.trainer import ( TrainerCallbacks, TrainerEpochCallbackTest, @@ -127,7 +128,7 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin): def _build_config_dict(self) -> dict: return { - # TODO + "experiment": self.pprints() } def run(self, experiment_name: str | None = None, logger_run_id: str | None = None) -> None: @@ -143,14 +144,21 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin): if experiment_name is None: experiment_name = datetime_tag() + log.info(f"Working directory: {os.getcwd()}") + self._set_seed() + + # create environments envs = self.env_factory(self.env_config) - policy_persistence = PolicyPersistence() log.info(f"Created {envs}") + # initialize persistence + additional_persistence = PersistenceGroup(*envs.persistence) + policy_persistence = PolicyPersistence(additional_persistence) + + # initialize logger full_config = self._build_config_dict() full_config.update(envs.info()) - logger = self.logger_factory.create_logger( log_name=experiment_name, run_id=logger_run_id, @@ -170,12 +178,13 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin): train_collector=train_collector, test_collector=test_collector, logger=logger, + restore_directory=self.config.policy_restore_directory ) if self.config.policy_restore_directory: policy_persistence.restore( policy, - self.config.policy_restore_directory, + world, self.config.device, ) diff --git a/tianshou/highlevel/persistence.py b/tianshou/highlevel/persistence.py index 8617c82..4e4f606 100644 --- a/tianshou/highlevel/persistence.py +++ b/tianshou/highlevel/persistence.py @@ -1,7 +1,8 @@ import logging import os from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Protocol, Self, runtime_checkable +from enum import Enum +from typing import TYPE_CHECKING, Protocol, Self, runtime_checkable, Callable import torch @@ -23,16 +24,23 @@ class PersistableConfigProtocol(Protocol): pass -class Persistence(ABC): - def path(self, world: World, filename: str) -> str: - return os.path.join(world.directory, filename) +class PersistEvent(Enum): + PERSIST_POLICY = "persist_policy" + """Policy neural network is persisted (new best found)""" + +class RestoreEvent(Enum): + RESTORE_POLICY = "restore_policy" + """Policy neural network parameters are restored""" + + +class Persistence(ABC): @abstractmethod - def persist(self, world: World) -> None: + def persist(self, event: PersistEvent, world: World) -> None: pass @abstractmethod - def restore(self, world: World): + def restore(self, event: RestoreEvent, world: World): pass @@ -40,31 +48,38 @@ class PersistenceGroup(Persistence): def __init__(self, *p: Persistence): self.items = p - def persist(self, world: World) -> None: + def persist(self, event: PersistEvent, world: World) -> None: for item in self.items: - item.persist(world) + item.persist(event, world) - def restore(self, world: World): + def restore(self, event: RestoreEvent, world: World): for item in self.items: - item.restore(world) + item.restore(event, world) class PolicyPersistence: FILENAME = "policy.dat" - def persist(self, policy: torch.nn.Module, directory: str) -> None: - path = os.path.join(directory, self.FILENAME) + def __init__(self, additional_persistence: Persistence | None = None): + self.additional_persistence = additional_persistence + + def persist(self, policy: torch.nn.Module, world: World) -> None: + path = world.persist_path(self.FILENAME) log.info(f"Saving policy in {path}") torch.save(policy.state_dict(), path) + if self.additional_persistence is not None: + self.additional_persistence.persist(PersistEvent.PERSIST_POLICY, world) - def restore(self, policy: torch.nn.Module, directory: str, device: "TDevice") -> None: - path = os.path.join(directory, self.FILENAME) + def restore(self, policy: torch.nn.Module, world: World, device: "TDevice") -> None: + path = world.restore_path(self.FILENAME) log.info(f"Restoring policy from {path}") state_dict = torch.load(path, map_location=device) policy.load_state_dict(state_dict) + if self.additional_persistence is not None: + self.additional_persistence.restore(RestoreEvent.RESTORE_POLICY, world) - def get_save_best_fn(self, world): + def get_save_best_fn(self, world) -> Callable[[torch.nn.Module], None]: def save_best_fn(pol: torch.nn.Module) -> None: - self.persist(pol, world.directory) + self.persist(pol, world) return save_best_fn diff --git a/tianshou/highlevel/world.py b/tianshou/highlevel/world.py index 5d69a07..c5cb369 100644 --- a/tianshou/highlevel/world.py +++ b/tianshou/highlevel/world.py @@ -17,11 +17,16 @@ class World: train_collector: "Collector" test_collector: "Collector" logger: "Logger" + restore_directory: str trainer: Optional["BaseTrainer"] = None + @property - def directory(self): + def persist_directory(self): return self.logger.log_path - def path(self, filename: str) -> str: - return os.path.join(self.directory, filename) + def persist_path(self, filename: str) -> str: + return os.path.join(self.persist_directory, filename) + + def restore_path(self, filename: str) -> str: + return os.path.join(self.restore_directory, filename)