Support obs_rms persistence for MuJoCo by adding a general mechanism for attaching persistence to Environments instances
This commit is contained in:
Dominik Jain 2023-10-12 15:01:49 +02:00
parent f6d49774a2
commit 3691ed2abc
5 changed files with 90 additions and 26 deletions

View File

@ -1,16 +1,22 @@
import pickle
import warnings import warnings
import logging
import gymnasium as gym import gymnasium as gym
from tianshou.env import ShmemVectorEnv, VectorEnvNormObs from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory
from tianshou.highlevel.persistence import Persistence, RestoreEvent, PersistEvent
from tianshou.highlevel.world import World
try: try:
import envpool import envpool
except ImportError: except ImportError:
envpool = None envpool = None
log = logging.getLogger(__name__)
def make_mujoco_env(task: str, seed: int, num_train_envs: int, num_test_envs: int, obs_norm: bool): def make_mujoco_env(task: str, seed: int, num_train_envs: int, num_test_envs: int, obs_norm: bool):
"""Wrapper function for Mujoco env. """Wrapper function for Mujoco env.
@ -40,6 +46,29 @@ def make_mujoco_env(task: str, seed: int, num_train_envs: int, num_test_envs: in
return env, train_envs, test_envs return env, train_envs, test_envs
class MujocoEnvObsRmsPersistence(Persistence):
    """Saves and restores the ``obs_rms`` value of the vectorized environments.

    The value is obtained from the training environments via ``get_obs_rms()``
    and pickled to ``FILENAME`` at the world's persist path. On restore it is
    loaded from the world's restore path and applied to both the training and
    the test environments via ``set_obs_rms()``.
    """

    FILENAME = "env_obs_rms.pkl"

    def persist(self, event: PersistEvent, world: World) -> None:
        """Pickle the training environments' ``obs_rms`` when the policy is persisted.

        :param event: the persistence event; any event other than
            ``PersistEvent.PERSIST_POLICY`` is ignored.
        :param world: provides the environments and the target file path.
        """
        if event != PersistEvent.PERSIST_POLICY:
            return
        obs_rms = world.envs.train_envs.get_obs_rms()
        path = world.persist_path(self.FILENAME)
        log.info(f"Saving environment obs_rms value to {path}")
        with open(path, "wb") as f:
            pickle.dump(obs_rms, f)

    def restore(self, event: RestoreEvent, world: World) -> None:
        """Load the pickled ``obs_rms`` and apply it to the train and test environments.

        :param event: the restoration event; any event other than
            ``RestoreEvent.RESTORE_POLICY`` is ignored.
        :param world: provides the environments and the source file path.
        """
        if event != RestoreEvent.RESTORE_POLICY:
            return
        path = world.restore_path(self.FILENAME)
        log.info(f"Restoring environment obs_rms value from {path}")
        # NOTE(security): unpickling can execute arbitrary code contained in the
        # file, so the restore directory must only ever point to trusted data.
        with open(path, "rb") as f:
            obs_rms = pickle.load(f)
        # The same statistics are applied to both sets of environments.
        world.envs.train_envs.set_obs_rms(obs_rms)
        world.envs.test_envs.set_obs_rms(obs_rms)
class MujocoEnvFactory(EnvFactory): class MujocoEnvFactory(EnvFactory):
def __init__(self, task: str, seed: int, sampling_config: SamplingConfig): def __init__(self, task: str, seed: int, sampling_config: SamplingConfig):
self.task = task self.task = task
@ -54,4 +83,6 @@ class MujocoEnvFactory(EnvFactory):
num_test_envs=self.sampling_config.num_test_envs, num_test_envs=self.sampling_config.num_test_envs,
obs_norm=True, obs_norm=True,
) )
return ContinuousEnvironments(env=env, train_envs=train_envs, test_envs=test_envs) envs = ContinuousEnvironments(env=env, train_envs=train_envs, test_envs=test_envs)
envs.set_persistence(MujocoEnvObsRmsPersistence())
return envs

View File

@ -6,7 +6,7 @@ from typing import Any, TypeAlias
import gymnasium as gym import gymnasium as gym
from tianshou.env import BaseVectorEnv from tianshou.env import BaseVectorEnv
from tianshou.highlevel.persistence import PersistableConfigProtocol from tianshou.highlevel.persistence import PersistableConfigProtocol, Persistence
from tianshou.utils.net.common import TActionShape from tianshou.utils.net.common import TActionShape
from tianshou.utils.string import ToStringMixin from tianshou.utils.string import ToStringMixin
@ -37,6 +37,7 @@ class Environments(ToStringMixin, ABC):
self.env = env self.env = env
self.train_envs = train_envs self.train_envs = train_envs
self.test_envs = test_envs self.test_envs = test_envs
self.persistence = []
def _tostring_includes(self) -> list[str]: def _tostring_includes(self) -> list[str]:
return [] return []
@ -50,6 +51,9 @@ class Environments(ToStringMixin, ABC):
"state_shape": self.get_observation_shape(), "state_shape": self.get_observation_shape(),
} }
def set_persistence(self, *p: Persistence) -> None:
    """Replace the persistence handlers attached to this instance.

    The handlers are stored as a list so that the attribute keeps the same
    type it is initialised with (an empty list), whether or not handlers
    have been set.

    :param p: the persistence handlers to attach; any previously attached
        handlers are discarded.
    """
    self.persistence = list(p)
@abstractmethod @abstractmethod
def get_action_shape(self) -> TActionShape: def get_action_shape(self) -> TActionShape:
pass pass

View File

@ -1,3 +1,4 @@
import os
import logging import logging
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence
@ -64,7 +65,7 @@ from tianshou.highlevel.params.policy_params import (
TRPOParams, TRPOParams,
) )
from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
from tianshou.highlevel.persistence import PersistableConfigProtocol, PolicyPersistence from tianshou.highlevel.persistence import PersistableConfigProtocol, PolicyPersistence, PersistenceGroup
from tianshou.highlevel.trainer import ( from tianshou.highlevel.trainer import (
TrainerCallbacks, TrainerCallbacks,
TrainerEpochCallbackTest, TrainerEpochCallbackTest,
@ -127,7 +128,7 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
def _build_config_dict(self) -> dict: def _build_config_dict(self) -> dict:
return { return {
# TODO "experiment": self.pprints()
} }
def run(self, experiment_name: str | None = None, logger_run_id: str | None = None) -> None: def run(self, experiment_name: str | None = None, logger_run_id: str | None = None) -> None:
@ -143,14 +144,21 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
if experiment_name is None: if experiment_name is None:
experiment_name = datetime_tag() experiment_name = datetime_tag()
log.info(f"Working directory: {os.getcwd()}")
self._set_seed() self._set_seed()
# create environments
envs = self.env_factory(self.env_config) envs = self.env_factory(self.env_config)
policy_persistence = PolicyPersistence()
log.info(f"Created {envs}") log.info(f"Created {envs}")
# initialize persistence
additional_persistence = PersistenceGroup(*envs.persistence)
policy_persistence = PolicyPersistence(additional_persistence)
# initialize logger
full_config = self._build_config_dict() full_config = self._build_config_dict()
full_config.update(envs.info()) full_config.update(envs.info())
logger = self.logger_factory.create_logger( logger = self.logger_factory.create_logger(
log_name=experiment_name, log_name=experiment_name,
run_id=logger_run_id, run_id=logger_run_id,
@ -170,12 +178,13 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
train_collector=train_collector, train_collector=train_collector,
test_collector=test_collector, test_collector=test_collector,
logger=logger, logger=logger,
restore_directory=self.config.policy_restore_directory
) )
if self.config.policy_restore_directory: if self.config.policy_restore_directory:
policy_persistence.restore( policy_persistence.restore(
policy, policy,
self.config.policy_restore_directory, world,
self.config.device, self.config.device,
) )

View File

@ -1,7 +1,8 @@
import logging import logging
import os import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Protocol, Self, runtime_checkable from enum import Enum
from typing import TYPE_CHECKING, Protocol, Self, runtime_checkable, Callable
import torch import torch
@ -23,16 +24,23 @@ class PersistableConfigProtocol(Protocol):
pass pass
class Persistence(ABC): class PersistEvent(Enum):
def path(self, world: World, filename: str) -> str: PERSIST_POLICY = "persist_policy"
return os.path.join(world.directory, filename) """Policy neural network is persisted (new best found)"""
class RestoreEvent(Enum):
RESTORE_POLICY = "restore_policy"
"""Policy neural network parameters are restored"""
class Persistence(ABC):
@abstractmethod @abstractmethod
def persist(self, world: World) -> None: def persist(self, event: PersistEvent, world: World) -> None:
pass pass
@abstractmethod @abstractmethod
def restore(self, world: World): def restore(self, event: RestoreEvent, world: World):
pass pass
@ -40,31 +48,38 @@ class PersistenceGroup(Persistence):
def __init__(self, *p: Persistence): def __init__(self, *p: Persistence):
self.items = p self.items = p
def persist(self, world: World) -> None: def persist(self, event: PersistEvent, world: World) -> None:
for item in self.items: for item in self.items:
item.persist(world) item.persist(event, world)
def restore(self, world: World): def restore(self, event: RestoreEvent, world: World):
for item in self.items: for item in self.items:
item.restore(world) item.restore(event, world)
class PolicyPersistence: class PolicyPersistence:
FILENAME = "policy.dat" FILENAME = "policy.dat"
def persist(self, policy: torch.nn.Module, directory: str) -> None: def __init__(self, additional_persistence: Persistence | None = None):
path = os.path.join(directory, self.FILENAME) self.additional_persistence = additional_persistence
def persist(self, policy: torch.nn.Module, world: World) -> None:
path = world.persist_path(self.FILENAME)
log.info(f"Saving policy in {path}") log.info(f"Saving policy in {path}")
torch.save(policy.state_dict(), path) torch.save(policy.state_dict(), path)
if self.additional_persistence is not None:
self.additional_persistence.persist(PersistEvent.PERSIST_POLICY, world)
def restore(self, policy: torch.nn.Module, directory: str, device: "TDevice") -> None: def restore(self, policy: torch.nn.Module, world: World, device: "TDevice") -> None:
path = os.path.join(directory, self.FILENAME) path = world.restore_path(self.FILENAME)
log.info(f"Restoring policy from {path}") log.info(f"Restoring policy from {path}")
state_dict = torch.load(path, map_location=device) state_dict = torch.load(path, map_location=device)
policy.load_state_dict(state_dict) policy.load_state_dict(state_dict)
if self.additional_persistence is not None:
self.additional_persistence.restore(RestoreEvent.RESTORE_POLICY, world)
def get_save_best_fn(self, world): def get_save_best_fn(self, world) -> Callable[[torch.nn.Module], None]:
def save_best_fn(pol: torch.nn.Module) -> None: def save_best_fn(pol: torch.nn.Module) -> None:
self.persist(pol, world.directory) self.persist(pol, world)
return save_best_fn return save_best_fn

View File

@ -17,11 +17,16 @@ class World:
train_collector: "Collector" train_collector: "Collector"
test_collector: "Collector" test_collector: "Collector"
logger: "Logger" logger: "Logger"
restore_directory: str
trainer: Optional["BaseTrainer"] = None trainer: Optional["BaseTrainer"] = None
@property @property
def directory(self): def persist_directory(self):
return self.logger.log_path return self.logger.log_path
def path(self, filename: str) -> str: def persist_path(self, filename: str) -> str:
return os.path.join(self.directory, filename) return os.path.join(self.persist_directory, filename)
# Returns the path of `filename` inside `self.restore_directory`, i.e. the
# location from which previously persisted files are read back.
# NOTE(review): `os.path.join` raises a TypeError if `restore_directory` is None;
# presumably callers only use this when a restore directory is configured — confirm.
def restore_path(self, filename: str) -> str:
return os.path.join(self.restore_directory, filename)