Support obs_rms persistence for MuJoCo by adding a general mechanism

for attaching persistence to Environments instances
This commit is contained in:
Dominik Jain 2023-10-12 15:01:49 +02:00
parent f6d49774a2
commit 3691ed2abc
5 changed files with 90 additions and 26 deletions

View File

@ -1,16 +1,22 @@
import pickle
import warnings
import logging
import gymnasium as gym
from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory
from tianshou.highlevel.persistence import Persistence, RestoreEvent, PersistEvent
from tianshou.highlevel.world import World
try:
import envpool
except ImportError:
envpool = None
log = logging.getLogger(__name__)
def make_mujoco_env(task: str, seed: int, num_train_envs: int, num_test_envs: int, obs_norm: bool):
"""Wrapper function for Mujoco env.
@ -40,6 +46,29 @@ def make_mujoco_env(task: str, seed: int, num_train_envs: int, num_test_envs: in
return env, train_envs, test_envs
class MujocoEnvObsRmsPersistence(Persistence):
FILENAME = "env_obs_rms.pkl"
def persist(self, event: PersistEvent, world: World) -> None:
if event != PersistEvent.PERSIST_POLICY:
return
obs_rms = world.envs.train_envs.get_obs_rms()
path = world.persist_path(self.FILENAME)
log.info(f"Saving environment obs_rms value to {path}")
with open(path, "wb") as f:
pickle.dump(obs_rms, f)
def restore(self, event: RestoreEvent, world: World):
if event != RestoreEvent.RESTORE_POLICY:
return
path = world.restore_path(self.FILENAME)
log.info(f"Restoring environment obs_rms value from {path}")
with open(path, "rb") as f:
obs_rms = pickle.load(f)
world.envs.train_envs.set_obs_rms(obs_rms)
world.envs.test_envs.set_obs_rms(obs_rms)
class MujocoEnvFactory(EnvFactory):
def __init__(self, task: str, seed: int, sampling_config: SamplingConfig):
self.task = task
@ -54,4 +83,6 @@ class MujocoEnvFactory(EnvFactory):
num_test_envs=self.sampling_config.num_test_envs,
obs_norm=True,
)
return ContinuousEnvironments(env=env, train_envs=train_envs, test_envs=test_envs)
envs = ContinuousEnvironments(env=env, train_envs=train_envs, test_envs=test_envs)
envs.set_persistence(MujocoEnvObsRmsPersistence())
return envs

View File

@ -6,7 +6,7 @@ from typing import Any, TypeAlias
import gymnasium as gym
from tianshou.env import BaseVectorEnv
from tianshou.highlevel.persistence import PersistableConfigProtocol
from tianshou.highlevel.persistence import PersistableConfigProtocol, Persistence
from tianshou.utils.net.common import TActionShape
from tianshou.utils.string import ToStringMixin
@ -37,6 +37,7 @@ class Environments(ToStringMixin, ABC):
self.env = env
self.train_envs = train_envs
self.test_envs = test_envs
self.persistence = []
def _tostring_includes(self) -> list[str]:
return []
@ -50,6 +51,9 @@ class Environments(ToStringMixin, ABC):
"state_shape": self.get_observation_shape(),
}
def set_persistence(self, *p: Persistence):
self.persistence = p
@abstractmethod
def get_action_shape(self) -> TActionShape:
pass

View File

@ -1,3 +1,4 @@
import os
import logging
from abc import abstractmethod
from collections.abc import Callable, Sequence
@ -64,7 +65,7 @@ from tianshou.highlevel.params.policy_params import (
TRPOParams,
)
from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
from tianshou.highlevel.persistence import PersistableConfigProtocol, PolicyPersistence
from tianshou.highlevel.persistence import PersistableConfigProtocol, PolicyPersistence, PersistenceGroup
from tianshou.highlevel.trainer import (
TrainerCallbacks,
TrainerEpochCallbackTest,
@ -127,7 +128,7 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
def _build_config_dict(self) -> dict:
return {
# TODO
"experiment": self.pprints()
}
def run(self, experiment_name: str | None = None, logger_run_id: str | None = None) -> None:
@ -143,14 +144,21 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
if experiment_name is None:
experiment_name = datetime_tag()
log.info(f"Working directory: {os.getcwd()}")
self._set_seed()
# create environments
envs = self.env_factory(self.env_config)
policy_persistence = PolicyPersistence()
log.info(f"Created {envs}")
# initialize persistence
additional_persistence = PersistenceGroup(*envs.persistence)
policy_persistence = PolicyPersistence(additional_persistence)
# initialize logger
full_config = self._build_config_dict()
full_config.update(envs.info())
logger = self.logger_factory.create_logger(
log_name=experiment_name,
run_id=logger_run_id,
@ -170,12 +178,13 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
train_collector=train_collector,
test_collector=test_collector,
logger=logger,
restore_directory=self.config.policy_restore_directory
)
if self.config.policy_restore_directory:
policy_persistence.restore(
policy,
self.config.policy_restore_directory,
world,
self.config.device,
)

View File

@ -1,7 +1,8 @@
import logging
import os
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Protocol, Self, runtime_checkable
from enum import Enum
from typing import TYPE_CHECKING, Protocol, Self, runtime_checkable, Callable
import torch
@ -23,16 +24,23 @@ class PersistableConfigProtocol(Protocol):
pass
class Persistence(ABC):
def path(self, world: World, filename: str) -> str:
return os.path.join(world.directory, filename)
class PersistEvent(Enum):
PERSIST_POLICY = "persist_policy"
"""Policy neural network is persisted (new best found)"""
class RestoreEvent(Enum):
RESTORE_POLICY = "restore_policy"
"""Policy neural network parameters are restored"""
class Persistence(ABC):
@abstractmethod
def persist(self, world: World) -> None:
def persist(self, event: PersistEvent, world: World) -> None:
pass
@abstractmethod
def restore(self, world: World):
def restore(self, event: RestoreEvent, world: World):
pass
@ -40,31 +48,38 @@ class PersistenceGroup(Persistence):
def __init__(self, *p: Persistence):
self.items = p
def persist(self, world: World) -> None:
def persist(self, event: PersistEvent, world: World) -> None:
for item in self.items:
item.persist(world)
item.persist(event, world)
def restore(self, world: World):
def restore(self, event: RestoreEvent, world: World):
for item in self.items:
item.restore(world)
item.restore(event, world)
class PolicyPersistence:
FILENAME = "policy.dat"
def persist(self, policy: torch.nn.Module, directory: str) -> None:
path = os.path.join(directory, self.FILENAME)
def __init__(self, additional_persistence: Persistence | None = None):
self.additional_persistence = additional_persistence
def persist(self, policy: torch.nn.Module, world: World) -> None:
path = world.persist_path(self.FILENAME)
log.info(f"Saving policy in {path}")
torch.save(policy.state_dict(), path)
if self.additional_persistence is not None:
self.additional_persistence.persist(PersistEvent.PERSIST_POLICY, world)
def restore(self, policy: torch.nn.Module, directory: str, device: "TDevice") -> None:
path = os.path.join(directory, self.FILENAME)
def restore(self, policy: torch.nn.Module, world: World, device: "TDevice") -> None:
path = world.restore_path(self.FILENAME)
log.info(f"Restoring policy from {path}")
state_dict = torch.load(path, map_location=device)
policy.load_state_dict(state_dict)
if self.additional_persistence is not None:
self.additional_persistence.restore(RestoreEvent.RESTORE_POLICY, world)
def get_save_best_fn(self, world):
def get_save_best_fn(self, world) -> Callable[[torch.nn.Module], None]:
def save_best_fn(pol: torch.nn.Module) -> None:
self.persist(pol, world.directory)
self.persist(pol, world)
return save_best_fn

View File

@ -17,11 +17,16 @@ class World:
train_collector: "Collector"
test_collector: "Collector"
logger: "Logger"
restore_directory: str
trainer: Optional["BaseTrainer"] = None
@property
def directory(self):
def persist_directory(self):
return self.logger.log_path
def path(self, filename: str) -> str:
return os.path.join(self.directory, filename)
def persist_path(self, filename: str) -> str:
return os.path.join(self.persist_directory, filename)
def restore_path(self, filename: str) -> str:
return os.path.join(self.restore_directory, filename)