Support obs_rms persistence for MuJoCo by adding a general mechanism
for attaching persistence to Environments instances
This commit is contained in:
parent
f6d49774a2
commit
3691ed2abc
@ -1,16 +1,22 @@
|
|||||||
|
import pickle
|
||||||
import warnings
|
import warnings
|
||||||
|
import logging
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
|
|
||||||
from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
|
from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
|
||||||
from tianshou.highlevel.config import SamplingConfig
|
from tianshou.highlevel.config import SamplingConfig
|
||||||
from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory
|
from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory
|
||||||
|
from tianshou.highlevel.persistence import Persistence, RestoreEvent, PersistEvent
|
||||||
|
from tianshou.highlevel.world import World
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import envpool
|
import envpool
|
||||||
except ImportError:
|
except ImportError:
|
||||||
envpool = None
|
envpool = None
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def make_mujoco_env(task: str, seed: int, num_train_envs: int, num_test_envs: int, obs_norm: bool):
|
def make_mujoco_env(task: str, seed: int, num_train_envs: int, num_test_envs: int, obs_norm: bool):
|
||||||
"""Wrapper function for Mujoco env.
|
"""Wrapper function for Mujoco env.
|
||||||
@ -40,6 +46,29 @@ def make_mujoco_env(task: str, seed: int, num_train_envs: int, num_test_envs: in
|
|||||||
return env, train_envs, test_envs
|
return env, train_envs, test_envs
|
||||||
|
|
||||||
|
|
||||||
|
class MujocoEnvObsRmsPersistence(Persistence):
|
||||||
|
FILENAME = "env_obs_rms.pkl"
|
||||||
|
|
||||||
|
def persist(self, event: PersistEvent, world: World) -> None:
|
||||||
|
if event != PersistEvent.PERSIST_POLICY:
|
||||||
|
return
|
||||||
|
obs_rms = world.envs.train_envs.get_obs_rms()
|
||||||
|
path = world.persist_path(self.FILENAME)
|
||||||
|
log.info(f"Saving environment obs_rms value to {path}")
|
||||||
|
with open(path, "wb") as f:
|
||||||
|
pickle.dump(obs_rms, f)
|
||||||
|
|
||||||
|
def restore(self, event: RestoreEvent, world: World):
|
||||||
|
if event != RestoreEvent.RESTORE_POLICY:
|
||||||
|
return
|
||||||
|
path = world.restore_path(self.FILENAME)
|
||||||
|
log.info(f"Restoring environment obs_rms value from {path}")
|
||||||
|
with open(path, "rb") as f:
|
||||||
|
obs_rms = pickle.load(f)
|
||||||
|
world.envs.train_envs.set_obs_rms(obs_rms)
|
||||||
|
world.envs.test_envs.set_obs_rms(obs_rms)
|
||||||
|
|
||||||
|
|
||||||
class MujocoEnvFactory(EnvFactory):
|
class MujocoEnvFactory(EnvFactory):
|
||||||
def __init__(self, task: str, seed: int, sampling_config: SamplingConfig):
|
def __init__(self, task: str, seed: int, sampling_config: SamplingConfig):
|
||||||
self.task = task
|
self.task = task
|
||||||
@ -54,4 +83,6 @@ class MujocoEnvFactory(EnvFactory):
|
|||||||
num_test_envs=self.sampling_config.num_test_envs,
|
num_test_envs=self.sampling_config.num_test_envs,
|
||||||
obs_norm=True,
|
obs_norm=True,
|
||||||
)
|
)
|
||||||
return ContinuousEnvironments(env=env, train_envs=train_envs, test_envs=test_envs)
|
envs = ContinuousEnvironments(env=env, train_envs=train_envs, test_envs=test_envs)
|
||||||
|
envs.set_persistence(MujocoEnvObsRmsPersistence())
|
||||||
|
return envs
|
||||||
|
@ -6,7 +6,7 @@ from typing import Any, TypeAlias
|
|||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
|
|
||||||
from tianshou.env import BaseVectorEnv
|
from tianshou.env import BaseVectorEnv
|
||||||
from tianshou.highlevel.persistence import PersistableConfigProtocol
|
from tianshou.highlevel.persistence import PersistableConfigProtocol, Persistence
|
||||||
from tianshou.utils.net.common import TActionShape
|
from tianshou.utils.net.common import TActionShape
|
||||||
from tianshou.utils.string import ToStringMixin
|
from tianshou.utils.string import ToStringMixin
|
||||||
|
|
||||||
@ -37,6 +37,7 @@ class Environments(ToStringMixin, ABC):
|
|||||||
self.env = env
|
self.env = env
|
||||||
self.train_envs = train_envs
|
self.train_envs = train_envs
|
||||||
self.test_envs = test_envs
|
self.test_envs = test_envs
|
||||||
|
self.persistence = []
|
||||||
|
|
||||||
def _tostring_includes(self) -> list[str]:
|
def _tostring_includes(self) -> list[str]:
|
||||||
return []
|
return []
|
||||||
@ -50,6 +51,9 @@ class Environments(ToStringMixin, ABC):
|
|||||||
"state_shape": self.get_observation_shape(),
|
"state_shape": self.get_observation_shape(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def set_persistence(self, *p: Persistence):
|
||||||
|
self.persistence = p
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_action_shape(self) -> TActionShape:
|
def get_action_shape(self) -> TActionShape:
|
||||||
pass
|
pass
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import os
|
||||||
import logging
|
import logging
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from collections.abc import Callable, Sequence
|
from collections.abc import Callable, Sequence
|
||||||
@ -64,7 +65,7 @@ from tianshou.highlevel.params.policy_params import (
|
|||||||
TRPOParams,
|
TRPOParams,
|
||||||
)
|
)
|
||||||
from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
|
from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
|
||||||
from tianshou.highlevel.persistence import PersistableConfigProtocol, PolicyPersistence
|
from tianshou.highlevel.persistence import PersistableConfigProtocol, PolicyPersistence, PersistenceGroup
|
||||||
from tianshou.highlevel.trainer import (
|
from tianshou.highlevel.trainer import (
|
||||||
TrainerCallbacks,
|
TrainerCallbacks,
|
||||||
TrainerEpochCallbackTest,
|
TrainerEpochCallbackTest,
|
||||||
@ -127,7 +128,7 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
|
|||||||
|
|
||||||
def _build_config_dict(self) -> dict:
|
def _build_config_dict(self) -> dict:
|
||||||
return {
|
return {
|
||||||
# TODO
|
"experiment": self.pprints()
|
||||||
}
|
}
|
||||||
|
|
||||||
def run(self, experiment_name: str | None = None, logger_run_id: str | None = None) -> None:
|
def run(self, experiment_name: str | None = None, logger_run_id: str | None = None) -> None:
|
||||||
@ -143,14 +144,21 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
|
|||||||
if experiment_name is None:
|
if experiment_name is None:
|
||||||
experiment_name = datetime_tag()
|
experiment_name = datetime_tag()
|
||||||
|
|
||||||
|
log.info(f"Working directory: {os.getcwd()}")
|
||||||
|
|
||||||
self._set_seed()
|
self._set_seed()
|
||||||
|
|
||||||
|
# create environments
|
||||||
envs = self.env_factory(self.env_config)
|
envs = self.env_factory(self.env_config)
|
||||||
policy_persistence = PolicyPersistence()
|
|
||||||
log.info(f"Created {envs}")
|
log.info(f"Created {envs}")
|
||||||
|
|
||||||
|
# initialize persistence
|
||||||
|
additional_persistence = PersistenceGroup(*envs.persistence)
|
||||||
|
policy_persistence = PolicyPersistence(additional_persistence)
|
||||||
|
|
||||||
|
# initialize logger
|
||||||
full_config = self._build_config_dict()
|
full_config = self._build_config_dict()
|
||||||
full_config.update(envs.info())
|
full_config.update(envs.info())
|
||||||
|
|
||||||
logger = self.logger_factory.create_logger(
|
logger = self.logger_factory.create_logger(
|
||||||
log_name=experiment_name,
|
log_name=experiment_name,
|
||||||
run_id=logger_run_id,
|
run_id=logger_run_id,
|
||||||
@ -170,12 +178,13 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
|
|||||||
train_collector=train_collector,
|
train_collector=train_collector,
|
||||||
test_collector=test_collector,
|
test_collector=test_collector,
|
||||||
logger=logger,
|
logger=logger,
|
||||||
|
restore_directory=self.config.policy_restore_directory
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.config.policy_restore_directory:
|
if self.config.policy_restore_directory:
|
||||||
policy_persistence.restore(
|
policy_persistence.restore(
|
||||||
policy,
|
policy,
|
||||||
self.config.policy_restore_directory,
|
world,
|
||||||
self.config.device,
|
self.config.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import TYPE_CHECKING, Protocol, Self, runtime_checkable
|
from enum import Enum
|
||||||
|
from typing import TYPE_CHECKING, Protocol, Self, runtime_checkable, Callable
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -23,16 +24,23 @@ class PersistableConfigProtocol(Protocol):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class Persistence(ABC):
|
class PersistEvent(Enum):
|
||||||
def path(self, world: World, filename: str) -> str:
|
PERSIST_POLICY = "persist_policy"
|
||||||
return os.path.join(world.directory, filename)
|
"""Policy neural network is persisted (new best found)"""
|
||||||
|
|
||||||
|
|
||||||
|
class RestoreEvent(Enum):
|
||||||
|
RESTORE_POLICY = "restore_policy"
|
||||||
|
"""Policy neural network parameters are restored"""
|
||||||
|
|
||||||
|
|
||||||
|
class Persistence(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def persist(self, world: World) -> None:
|
def persist(self, event: PersistEvent, world: World) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def restore(self, world: World):
|
def restore(self, event: RestoreEvent, world: World):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@ -40,31 +48,38 @@ class PersistenceGroup(Persistence):
|
|||||||
def __init__(self, *p: Persistence):
|
def __init__(self, *p: Persistence):
|
||||||
self.items = p
|
self.items = p
|
||||||
|
|
||||||
def persist(self, world: World) -> None:
|
def persist(self, event: PersistEvent, world: World) -> None:
|
||||||
for item in self.items:
|
for item in self.items:
|
||||||
item.persist(world)
|
item.persist(event, world)
|
||||||
|
|
||||||
def restore(self, world: World):
|
def restore(self, event: RestoreEvent, world: World):
|
||||||
for item in self.items:
|
for item in self.items:
|
||||||
item.restore(world)
|
item.restore(event, world)
|
||||||
|
|
||||||
|
|
||||||
class PolicyPersistence:
|
class PolicyPersistence:
|
||||||
FILENAME = "policy.dat"
|
FILENAME = "policy.dat"
|
||||||
|
|
||||||
def persist(self, policy: torch.nn.Module, directory: str) -> None:
|
def __init__(self, additional_persistence: Persistence | None = None):
|
||||||
path = os.path.join(directory, self.FILENAME)
|
self.additional_persistence = additional_persistence
|
||||||
|
|
||||||
|
def persist(self, policy: torch.nn.Module, world: World) -> None:
|
||||||
|
path = world.persist_path(self.FILENAME)
|
||||||
log.info(f"Saving policy in {path}")
|
log.info(f"Saving policy in {path}")
|
||||||
torch.save(policy.state_dict(), path)
|
torch.save(policy.state_dict(), path)
|
||||||
|
if self.additional_persistence is not None:
|
||||||
|
self.additional_persistence.persist(PersistEvent.PERSIST_POLICY, world)
|
||||||
|
|
||||||
def restore(self, policy: torch.nn.Module, directory: str, device: "TDevice") -> None:
|
def restore(self, policy: torch.nn.Module, world: World, device: "TDevice") -> None:
|
||||||
path = os.path.join(directory, self.FILENAME)
|
path = world.restore_path(self.FILENAME)
|
||||||
log.info(f"Restoring policy from {path}")
|
log.info(f"Restoring policy from {path}")
|
||||||
state_dict = torch.load(path, map_location=device)
|
state_dict = torch.load(path, map_location=device)
|
||||||
policy.load_state_dict(state_dict)
|
policy.load_state_dict(state_dict)
|
||||||
|
if self.additional_persistence is not None:
|
||||||
|
self.additional_persistence.restore(RestoreEvent.RESTORE_POLICY, world)
|
||||||
|
|
||||||
def get_save_best_fn(self, world):
|
def get_save_best_fn(self, world) -> Callable[[torch.nn.Module], None]:
|
||||||
def save_best_fn(pol: torch.nn.Module) -> None:
|
def save_best_fn(pol: torch.nn.Module) -> None:
|
||||||
self.persist(pol, world.directory)
|
self.persist(pol, world)
|
||||||
|
|
||||||
return save_best_fn
|
return save_best_fn
|
||||||
|
@ -17,11 +17,16 @@ class World:
|
|||||||
train_collector: "Collector"
|
train_collector: "Collector"
|
||||||
test_collector: "Collector"
|
test_collector: "Collector"
|
||||||
logger: "Logger"
|
logger: "Logger"
|
||||||
|
restore_directory: str
|
||||||
trainer: Optional["BaseTrainer"] = None
|
trainer: Optional["BaseTrainer"] = None
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def directory(self):
|
def persist_directory(self):
|
||||||
return self.logger.log_path
|
return self.logger.log_path
|
||||||
|
|
||||||
def path(self, filename: str) -> str:
|
def persist_path(self, filename: str) -> str:
|
||||||
return os.path.join(self.directory, filename)
|
return os.path.join(self.persist_directory, filename)
|
||||||
|
|
||||||
|
def restore_path(self, filename: str) -> str:
|
||||||
|
return os.path.join(self.restore_directory, filename)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user