Support obs_rms persistence for MuJoCo by adding a general mechanism
for attaching persistence to Environments instances
This commit is contained in:
parent
f6d49774a2
commit
3691ed2abc
@ -1,16 +1,22 @@
|
||||
import pickle
|
||||
import warnings
|
||||
import logging
|
||||
|
||||
import gymnasium as gym
|
||||
|
||||
from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
|
||||
from tianshou.highlevel.config import SamplingConfig
|
||||
from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory
|
||||
from tianshou.highlevel.persistence import Persistence, RestoreEvent, PersistEvent
|
||||
from tianshou.highlevel.world import World
|
||||
|
||||
try:
|
||||
import envpool
|
||||
except ImportError:
|
||||
envpool = None
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def make_mujoco_env(task: str, seed: int, num_train_envs: int, num_test_envs: int, obs_norm: bool):
|
||||
"""Wrapper function for Mujoco env.
|
||||
@ -40,6 +46,29 @@ def make_mujoco_env(task: str, seed: int, num_train_envs: int, num_test_envs: in
|
||||
return env, train_envs, test_envs
|
||||
|
||||
|
||||
class MujocoEnvObsRmsPersistence(Persistence):
    """Saves and restores the observation normalisation statistics (``obs_rms``) of the environments.

    The statistics are pickled to :attr:`FILENAME` when a ``PERSIST_POLICY`` event is
    received and are loaded into both the training and the test environments when a
    ``RESTORE_POLICY`` event is received. All other events are ignored.
    """

    FILENAME = "env_obs_rms.pkl"

    def persist(self, event: PersistEvent, world: World) -> None:
        """Pickle the ``obs_rms`` of the training environments.

        :param event: the triggering event; only ``PersistEvent.PERSIST_POLICY`` is handled.
        :param world: supplies the environments and the target path (``world.persist_path``).
        """
        if event != PersistEvent.PERSIST_POLICY:
            return
        obs_rms = world.envs.train_envs.get_obs_rms()
        path = world.persist_path(self.FILENAME)
        log.info(f"Saving environment obs_rms value to {path}")
        with open(path, "wb") as f:
            pickle.dump(obs_rms, f)

    def restore(self, event: RestoreEvent, world: World) -> None:
        """Load the pickled ``obs_rms`` and apply it to the training and test environments.

        :param event: the triggering event; only ``RestoreEvent.RESTORE_POLICY`` is handled.
        :param world: supplies the environments and the source path (``world.restore_path``).
        """
        if event != RestoreEvent.RESTORE_POLICY:
            return
        path = world.restore_path(self.FILENAME)
        log.info(f"Restoring environment obs_rms value from {path}")
        # NOTE(security): unpickling executes arbitrary code embedded in the file;
        # only restore from directories whose contents are trusted.
        with open(path, "rb") as f:
            obs_rms = pickle.load(f)
        # The same statistics are applied to both sets of environments.
        world.envs.train_envs.set_obs_rms(obs_rms)
        world.envs.test_envs.set_obs_rms(obs_rms)
|
||||
|
||||
|
||||
class MujocoEnvFactory(EnvFactory):
|
||||
def __init__(self, task: str, seed: int, sampling_config: SamplingConfig):
|
||||
self.task = task
|
||||
@ -54,4 +83,6 @@ class MujocoEnvFactory(EnvFactory):
|
||||
num_test_envs=self.sampling_config.num_test_envs,
|
||||
obs_norm=True,
|
||||
)
|
||||
return ContinuousEnvironments(env=env, train_envs=train_envs, test_envs=test_envs)
|
||||
envs = ContinuousEnvironments(env=env, train_envs=train_envs, test_envs=test_envs)
|
||||
envs.set_persistence(MujocoEnvObsRmsPersistence())
|
||||
return envs
|
||||
|
@ -6,7 +6,7 @@ from typing import Any, TypeAlias
|
||||
import gymnasium as gym
|
||||
|
||||
from tianshou.env import BaseVectorEnv
|
||||
from tianshou.highlevel.persistence import PersistableConfigProtocol
|
||||
from tianshou.highlevel.persistence import PersistableConfigProtocol, Persistence
|
||||
from tianshou.utils.net.common import TActionShape
|
||||
from tianshou.utils.string import ToStringMixin
|
||||
|
||||
@ -37,6 +37,7 @@ class Environments(ToStringMixin, ABC):
|
||||
self.env = env
|
||||
self.train_envs = train_envs
|
||||
self.test_envs = test_envs
|
||||
self.persistence = []
|
||||
|
||||
def _tostring_includes(self) -> list[str]:
|
||||
return []
|
||||
@ -50,6 +51,9 @@ class Environments(ToStringMixin, ABC):
|
||||
"state_shape": self.get_observation_shape(),
|
||||
}
|
||||
|
||||
def set_persistence(self, *p: Persistence) -> None:
    """Replace the persistence handlers attached to this instance.

    :param p: the handlers to attach; previously attached handlers are discarded.
        Calling without arguments leaves an empty list.
    """
    # ``*p`` arrives as a tuple; store a list so the attribute keeps the same type
    # as the empty-list default assigned in the constructor.
    self.persistence = list(p)
|
||||
|
||||
@abstractmethod
|
||||
def get_action_shape(self) -> TActionShape:
|
||||
pass
|
||||
|
@ -1,3 +1,4 @@
|
||||
import os
|
||||
import logging
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Callable, Sequence
|
||||
@ -64,7 +65,7 @@ from tianshou.highlevel.params.policy_params import (
|
||||
TRPOParams,
|
||||
)
|
||||
from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
|
||||
from tianshou.highlevel.persistence import PersistableConfigProtocol, PolicyPersistence
|
||||
from tianshou.highlevel.persistence import PersistableConfigProtocol, PolicyPersistence, PersistenceGroup
|
||||
from tianshou.highlevel.trainer import (
|
||||
TrainerCallbacks,
|
||||
TrainerEpochCallbackTest,
|
||||
@ -127,7 +128,7 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
|
||||
|
||||
def _build_config_dict(self) -> dict:
|
||||
return {
|
||||
# TODO
|
||||
"experiment": self.pprints()
|
||||
}
|
||||
|
||||
def run(self, experiment_name: str | None = None, logger_run_id: str | None = None) -> None:
|
||||
@ -143,14 +144,21 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
|
||||
if experiment_name is None:
|
||||
experiment_name = datetime_tag()
|
||||
|
||||
log.info(f"Working directory: {os.getcwd()}")
|
||||
|
||||
self._set_seed()
|
||||
|
||||
# create environments
|
||||
envs = self.env_factory(self.env_config)
|
||||
policy_persistence = PolicyPersistence()
|
||||
log.info(f"Created {envs}")
|
||||
|
||||
# initialize persistence
|
||||
additional_persistence = PersistenceGroup(*envs.persistence)
|
||||
policy_persistence = PolicyPersistence(additional_persistence)
|
||||
|
||||
# initialize logger
|
||||
full_config = self._build_config_dict()
|
||||
full_config.update(envs.info())
|
||||
|
||||
logger = self.logger_factory.create_logger(
|
||||
log_name=experiment_name,
|
||||
run_id=logger_run_id,
|
||||
@ -170,12 +178,13 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
|
||||
train_collector=train_collector,
|
||||
test_collector=test_collector,
|
||||
logger=logger,
|
||||
restore_directory=self.config.policy_restore_directory
|
||||
)
|
||||
|
||||
if self.config.policy_restore_directory:
|
||||
policy_persistence.restore(
|
||||
policy,
|
||||
self.config.policy_restore_directory,
|
||||
world,
|
||||
self.config.device,
|
||||
)
|
||||
|
||||
|
@ -1,7 +1,8 @@
|
||||
import logging
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Protocol, Self, runtime_checkable
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Protocol, Self, runtime_checkable, Callable
|
||||
|
||||
import torch
|
||||
|
||||
@ -23,16 +24,23 @@ class PersistableConfigProtocol(Protocol):
|
||||
pass
|
||||
|
||||
|
||||
class Persistence(ABC):
|
||||
def path(self, world: World, filename: str) -> str:
|
||||
return os.path.join(world.directory, filename)
|
||||
class PersistEvent(Enum):
    """Events that are passed to persistence handlers when state is to be saved."""

    # The policy neural network is persisted (a new best policy was found).
    PERSIST_POLICY = "persist_policy"
|
||||
|
||||
|
||||
class RestoreEvent(Enum):
    """Events that are passed to persistence handlers when state is to be restored."""

    # The parameters of the policy neural network are restored.
    RESTORE_POLICY = "restore_policy"
|
||||
|
||||
|
||||
class Persistence(ABC):
    """Interface for components that save and restore state in response to events."""

    @abstractmethod
    def persist(self, event: PersistEvent, world: World) -> None:
        """Save state in reaction to the given event.

        :param event: the event that triggered persistence.
        :param world: the world object handed to the handler.
        """

    @abstractmethod
    def restore(self, event: RestoreEvent, world: World) -> None:
        """Restore state in reaction to the given event.

        :param event: the event that triggered restoration.
        :param world: the world object handed to the handler.
        """
|
||||
|
||||
|
||||
@ -40,31 +48,38 @@ class PersistenceGroup(Persistence):
|
||||
def __init__(self, *p: Persistence):
|
||||
self.items = p
|
||||
|
||||
def persist(self, event: PersistEvent, world: World) -> None:
    """Forward the persist event to every item of the group, in order.

    :param event: the event that triggered persistence.
    :param world: the world object passed on to each item.
    """
    for item in self.items:
        item.persist(event, world)
|
||||
|
||||
def restore(self, event: RestoreEvent, world: World) -> None:
    """Forward the restore event to every item of the group, in order.

    :param event: the event that triggered restoration.
    :param world: the world object passed on to each item.
    """
    for item in self.items:
        item.restore(event, world)
|
||||
|
||||
|
||||
class PolicyPersistence:
    """Saves and restores the policy's state dict, optionally triggering further persistence.

    Paths are resolved through the world object (``persist_path`` / ``restore_path``).
    """

    FILENAME = "policy.dat"

    def __init__(self, additional_persistence: Persistence | None = None):
        """
        :param additional_persistence: a handler that is notified whenever the policy is
            persisted or restored; None if nothing else needs to be handled.
        """
        self.additional_persistence = additional_persistence

    def persist(self, policy: torch.nn.Module, world: World) -> None:
        """Save the policy's state dict and notify the additional persistence, if any.

        :param policy: the module whose ``state_dict()`` is saved.
        :param world: supplies the target path and is passed on to the additional persistence.
        """
        path = world.persist_path(self.FILENAME)
        log.info(f"Saving policy in {path}")
        torch.save(policy.state_dict(), path)
        if self.additional_persistence is not None:
            self.additional_persistence.persist(PersistEvent.PERSIST_POLICY, world)

    def restore(self, policy: torch.nn.Module, world: World, device: "TDevice") -> None:
        """Load the saved state dict into the policy and notify the additional persistence, if any.

        :param policy: the module that receives the loaded state dict.
        :param world: supplies the source path and is passed on to the additional persistence.
        :param device: passed to ``torch.load`` as ``map_location``.
        """
        path = world.restore_path(self.FILENAME)
        log.info(f"Restoring policy from {path}")
        # NOTE(security): torch.load may unpickle arbitrary objects; only restore
        # from directories whose contents are trusted.
        state_dict = torch.load(path, map_location=device)
        policy.load_state_dict(state_dict)
        if self.additional_persistence is not None:
            self.additional_persistence.restore(RestoreEvent.RESTORE_POLICY, world)

    def get_save_best_fn(self, world) -> Callable[[torch.nn.Module], None]:
        """Create a callback which persists the policy it is given, using the given world.

        :param world: the world object that is passed to :meth:`persist` on every call.
        :return: a function taking the policy to persist.
        """

        def save_best_fn(pol: torch.nn.Module) -> None:
            self.persist(pol, world)

        return save_best_fn
|
||||
|
@ -17,11 +17,16 @@ class World:
|
||||
train_collector: "Collector"
|
||||
test_collector: "Collector"
|
||||
logger: "Logger"
|
||||
restore_directory: str
|
||||
trainer: Optional["BaseTrainer"] = None
|
||||
|
||||
|
||||
@property
def persist_directory(self):
    """The directory into which files are persisted: the logger's ``log_path``."""
    return self.logger.log_path
|
||||
|
||||
def persist_path(self, filename: str) -> str:
    """Return the path of the given file within the persistence directory.

    :param filename: the name of the file, relative to ``persist_directory``.
    :return: ``persist_directory`` joined with ``filename``.
    """
    return os.path.join(self.persist_directory, filename)
|
||||
|
||||
def restore_path(self, filename: str) -> str:
    """Return the path of the given file within the restore directory.

    :param filename: the name of the file, relative to ``restore_directory``.
    :return: ``restore_directory`` joined with ``filename``.
    """
    base_dir = self.restore_directory
    return os.path.join(base_dir, filename)
|
||||
|
Loading…
x
Reference in New Issue
Block a user