Tianshou/examples/mujoco/mujoco_env.py
maxhuettenrauch ade85ab32b
Feature/algo eval (#1074)
# Changes

## Dependencies

- New extra "eval"

## API Extensions
- `Experiment` and `ExperimentConfig` now have a `name`, which can be
overridden when `Experiment.run()` is called.
- When building an `Experiment` from an `ExperimentConfig`, the user has
the option to add info about seeds to the name.
- New method in `ExperimentConfig` called
`build_default_seeded_experiments`
- `SamplingConfig` now has an explicit training seed; `test_seed` is
inferred.
- New `evaluation` package for repeating the same experiment with
multiple seeds and aggregating the results (important extension!).
Currently in alpha state.
- Loggers can now restore logged data into Python via the new
`restore_logged_data`.

## Breaking Changes
- `AtariEnvFactory` (in examples) now receives explicit train and test
seeds
- `EnvFactoryRegistered` now requires an explicit `test_seed`
- `BaseLogger.prepare_dict_for_logging` is now abstract

---------

Co-authored-by: Maximilian Huettenrauch <m.huettenrauch@appliedai.de>
Co-authored-by: Michael Panchenko <m.panchenko@appliedai.de>
Co-authored-by: Michael Panchenko <35432522+MischaPanch@users.noreply.github.com>
2024-04-20 23:25:33 +00:00

118 lines
3.8 KiB
Python

import logging
import pickle
from gymnasium import Env
from tianshou.env import BaseVectorEnv, VectorEnvNormObs
from tianshou.highlevel.env import (
ContinuousEnvironments,
EnvFactoryRegistered,
EnvMode,
EnvPoolFactory,
VectorEnvType,
)
from tianshou.highlevel.persistence import Persistence, PersistEvent, RestoreEvent
from tianshou.highlevel.world import World
# EnvPool is an optional dependency: fall back to ``None`` when it cannot be imported,
# and expose a boolean flag that the factory below consults.
try:
    import envpool
except ImportError:
    envpool = None
envpool_is_available = envpool is not None

log = logging.getLogger(__name__)
def make_mujoco_env(
    task: str,
    seed: int,
    num_train_envs: int,
    num_test_envs: int,
    obs_norm: bool,
) -> tuple[Env, BaseVectorEnv, BaseVectorEnv]:
    """Build a single MuJoCo env together with vectorized training and test envs.

    If EnvPool is installed, EnvPool's MuJoCo env is used automatically.

    :param task: the task name handed to ``MujocoEnvFactory``
    :param seed: the training seed; the test envs are seeded with ``seed + num_train_envs``
    :param num_train_envs: number of training environments to create
    :param num_test_envs: number of test environments to create
    :param obs_norm: whether the vectorized envs get observation normalization
    :return: a tuple of (single env, training envs, test envs).
    """
    # NOTE(review): the offset by num_train_envs presumably keeps test seeds disjoint from
    # the per-env training seeds; confirm against EnvFactoryRegistered's seeding scheme.
    factory = MujocoEnvFactory(task, seed, seed + num_train_envs, obs_norm=obs_norm)
    created = factory.create_envs(num_train_envs, num_test_envs)
    return created.env, created.train_envs, created.test_envs
class MujocoEnvObsRmsPersistence(Persistence):
    """Saves and restores the environments' observation-normalization statistics (``obs_rms``).

    The statistics are pickled to, and unpickled from, a file named ``FILENAME`` whose
    location is resolved through the ``World`` (``persist_path`` / ``restore_path``).
    """

    FILENAME = "env_obs_rms.pkl"

    def persist(self, event: PersistEvent, world: World) -> None:
        """Pickle the training envs' ``obs_rms`` when the policy is persisted.

        :param event: the persistence event; only ``PersistEvent.PERSIST_POLICY`` triggers a save
        :param world: provides the training envs and the target path
        """
        if event != PersistEvent.PERSIST_POLICY:
            return  # type: ignore[unreachable] # since PersistEvent has only one member, mypy infers that line is unreachable
        rms_snapshot = world.envs.train_envs.get_obs_rms()
        target = world.persist_path(self.FILENAME)
        log.info(f"Saving environment obs_rms value to {target}")
        with open(target, "wb") as stream:
            pickle.dump(rms_snapshot, stream)

    def restore(self, event: RestoreEvent, world: World) -> None:
        """Load pickled ``obs_rms`` and install it on the train, test and (if any) watch envs.

        :param event: the restore event; only ``RestoreEvent.RESTORE_POLICY`` triggers a load
        :param world: provides the environments and the source path
        """
        if event != RestoreEvent.RESTORE_POLICY:
            return  # type: ignore[unreachable]
        source = world.restore_path(self.FILENAME)
        log.info(f"Restoring environment obs_rms value from {source}")
        with open(source, "rb") as stream:
            # NOTE(review): pickle.load can execute arbitrary code; only restore from trusted directories.
            restored_rms = pickle.load(stream)
        envs = world.envs
        envs.train_envs.set_obs_rms(restored_rms)
        envs.test_envs.set_obs_rms(restored_rms)
        watch_env = envs.watch_env
        if watch_env is not None:
            watch_env.set_obs_rms(restored_rms)
class MujocoEnvFactory(EnvFactoryRegistered):
    """Environment factory for MuJoCo tasks built on ``EnvFactoryRegistered``.

    An ``EnvPoolFactory`` is supplied when ``envpool`` could be imported. When ``obs_norm``
    is enabled, every vectorized env is wrapped in ``VectorEnvNormObs``, and the test/watch
    envs are given the training envs' statistics.
    """

    def __init__(
        self,
        task: str,
        train_seed: int,
        test_seed: int,
        obs_norm: bool = True,
        venv_type: VectorEnvType = VectorEnvType.SUBPROC_SHARED_MEM,
    ) -> None:
        """
        :param task: the task name
        :param train_seed: the seed used for the training envs
        :param test_seed: the seed used for the test envs
        :param obs_norm: whether to apply observation normalization to the vectorized envs
        :param venv_type: the vectorized-env implementation to use
        """
        pool_factory = EnvPoolFactory() if envpool_is_available else None
        super().__init__(
            task=task,
            train_seed=train_seed,
            test_seed=test_seed,
            venv_type=venv_type,
            envpool_factory=pool_factory,
        )
        self.obs_norm = obs_norm

    def create_venv(self, num_envs: int, mode: EnvMode) -> BaseVectorEnv:
        """Create vectorized environments, normalizing observations if enabled.

        :param num_envs: the number of environments
        :param mode: the mode for which to create
        :return: the vectorized environments
        """
        venv = super().create_venv(num_envs, mode)
        if not self.obs_norm:
            return venv
        # Only envs created in TRAIN mode update the running statistics.
        is_training = mode == EnvMode.TRAIN
        return VectorEnvNormObs(venv, update_obs_rms=is_training)

    def create_envs(
        self,
        num_training_envs: int,
        num_test_envs: int,
        create_watch_env: bool = False,
    ) -> ContinuousEnvironments:
        """Create train/test (and optionally watch) envs that share observation statistics.

        :param num_training_envs: the number of training environments
        :param num_test_envs: the number of test environments
        :param create_watch_env: whether to also create a watch env
        :return: the created continuous environments
        """
        envs = super().create_envs(num_training_envs, num_test_envs, create_watch_env)
        assert isinstance(envs, ContinuousEnvironments)
        if not self.obs_norm:
            return envs
        # Non-training envs normalize with the statistics gathered by the training envs.
        envs.test_envs.set_obs_rms(envs.train_envs.get_obs_rms())
        if envs.watch_env is not None:
            envs.watch_env.set_obs_rms(envs.train_envs.get_obs_rms())
        envs.set_persistence(MujocoEnvObsRmsPersistence())
        return envs