Fix/add watch env with obs rms (#1061)
Supports deciding whether to watch the agent performing on the env using high-level interfaces
This commit is contained in:
parent
49781e715e
commit
7c970df53f
@ -1,10 +1,11 @@
|
||||
import logging
|
||||
import pickle
|
||||
|
||||
from tianshou.env import VectorEnvNormObs
|
||||
from tianshou.env import BaseVectorEnv, VectorEnvNormObs
|
||||
from tianshou.highlevel.env import (
|
||||
ContinuousEnvironments,
|
||||
EnvFactoryRegistered,
|
||||
EnvMode,
|
||||
EnvPoolFactory,
|
||||
VectorEnvType,
|
||||
)
|
||||
@ -56,6 +57,8 @@ class MujocoEnvObsRmsPersistence(Persistence):
|
||||
obs_rms = pickle.load(f)
|
||||
world.envs.train_envs.set_obs_rms(obs_rms)
|
||||
world.envs.test_envs.set_obs_rms(obs_rms)
|
||||
if world.envs.watch_env is not None:
|
||||
world.envs.watch_env.set_obs_rms(obs_rms)
|
||||
|
||||
|
||||
class MujocoEnvFactory(EnvFactoryRegistered):
|
||||
@ -68,15 +71,31 @@ class MujocoEnvFactory(EnvFactoryRegistered):
|
||||
)
|
||||
self.obs_norm = obs_norm
|
||||
|
||||
def create_envs(self, num_training_envs: int, num_test_envs: int) -> ContinuousEnvironments:
|
||||
envs = super().create_envs(num_training_envs, num_test_envs)
|
||||
assert isinstance(envs, ContinuousEnvironments)
|
||||
def create_venv(self, num_envs: int, mode: EnvMode) -> BaseVectorEnv:
|
||||
"""Create vectorized environments.
|
||||
|
||||
:param num_envs: the number of environments
|
||||
:param mode: the mode for which to create
|
||||
:return: the vectorized environments
|
||||
"""
|
||||
env = super().create_venv(num_envs, mode)
|
||||
# obs norm wrapper
|
||||
if self.obs_norm:
|
||||
envs.train_envs = VectorEnvNormObs(envs.train_envs)
|
||||
envs.test_envs = VectorEnvNormObs(envs.test_envs, update_obs_rms=False)
|
||||
envs.test_envs.set_obs_rms(envs.train_envs.get_obs_rms())
|
||||
envs.set_persistence(MujocoEnvObsRmsPersistence())
|
||||
env = VectorEnvNormObs(env, update_obs_rms=mode == EnvMode.TRAIN)
|
||||
return env
|
||||
|
||||
def create_envs(
|
||||
self,
|
||||
num_training_envs: int,
|
||||
num_test_envs: int,
|
||||
create_watch_env: bool = False,
|
||||
) -> ContinuousEnvironments:
|
||||
envs = super().create_envs(num_training_envs, num_test_envs, create_watch_env)
|
||||
assert isinstance(envs, ContinuousEnvironments)
|
||||
|
||||
if self.obs_norm:
|
||||
envs.test_envs.set_obs_rms(envs.train_envs.get_obs_rms())
|
||||
if envs.watch_env is not None:
|
||||
envs.watch_env.set_obs_rms(envs.train_envs.get_obs_rms())
|
||||
envs.set_persistence(MujocoEnvObsRmsPersistence())
|
||||
return envs
|
||||
|
@ -89,10 +89,17 @@ class VectorEnvType(Enum):
|
||||
class Environments(ToStringMixin, ABC):
|
||||
"""Represents (vectorized) environments for a learning process."""
|
||||
|
||||
def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
|
||||
def __init__(
|
||||
self,
|
||||
env: gym.Env,
|
||||
train_envs: BaseVectorEnv,
|
||||
test_envs: BaseVectorEnv,
|
||||
watch_env: BaseVectorEnv | None = None,
|
||||
):
|
||||
self.env = env
|
||||
self.train_envs = train_envs
|
||||
self.test_envs = test_envs
|
||||
self.watch_env = watch_env
|
||||
self.persistence: Sequence[Persistence] = []
|
||||
|
||||
@staticmethod
|
||||
@ -102,6 +109,7 @@ class Environments(ToStringMixin, ABC):
|
||||
venv_type: VectorEnvType,
|
||||
num_training_envs: int,
|
||||
num_test_envs: int,
|
||||
create_watch_env: bool = False,
|
||||
) -> "Environments":
|
||||
"""Creates a suitable subtype instance from a factory function that creates a single instance and the type of environment (continuous/discrete).
|
||||
|
||||
@ -110,16 +118,21 @@ class Environments(ToStringMixin, ABC):
|
||||
:param venv_type: the vector environment type to use for parallelization
|
||||
:param num_training_envs: the number of training environments to create
|
||||
:param num_test_envs: the number of test environments to create
|
||||
:param create_watch_env: whether to create an environment for watching the agent
|
||||
:return: the instance
|
||||
"""
|
||||
train_envs = venv_type.create_venv([lambda: factory_fn(EnvMode.TRAIN)] * num_training_envs)
|
||||
test_envs = venv_type.create_venv([lambda: factory_fn(EnvMode.TEST)] * num_test_envs)
|
||||
if create_watch_env:
|
||||
watch_env = venv_type.create_venv([lambda: factory_fn(EnvMode.WATCH)])
|
||||
else:
|
||||
watch_env = None
|
||||
env = factory_fn(EnvMode.TRAIN)
|
||||
match env_type:
|
||||
case EnvType.CONTINUOUS:
|
||||
return ContinuousEnvironments(env, train_envs, test_envs)
|
||||
return ContinuousEnvironments(env, train_envs, test_envs, watch_env)
|
||||
case EnvType.DISCRETE:
|
||||
return DiscreteEnvironments(env, train_envs, test_envs)
|
||||
return DiscreteEnvironments(env, train_envs, test_envs, watch_env)
|
||||
case _:
|
||||
raise ValueError(f"Environment type {env_type} not handled")
|
||||
|
||||
@ -164,8 +177,14 @@ class Environments(ToStringMixin, ABC):
|
||||
class ContinuousEnvironments(Environments):
|
||||
"""Represents (vectorized) continuous environments."""
|
||||
|
||||
def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
|
||||
super().__init__(env, train_envs, test_envs)
|
||||
def __init__(
|
||||
self,
|
||||
env: gym.Env,
|
||||
train_envs: BaseVectorEnv,
|
||||
test_envs: BaseVectorEnv,
|
||||
watch_env: BaseVectorEnv | None = None,
|
||||
):
|
||||
super().__init__(env, train_envs, test_envs, watch_env)
|
||||
self.state_shape, self.action_shape, self.max_action = self._get_continuous_env_info(env)
|
||||
|
||||
@staticmethod
|
||||
@ -174,6 +193,7 @@ class ContinuousEnvironments(Environments):
|
||||
venv_type: VectorEnvType,
|
||||
num_training_envs: int,
|
||||
num_test_envs: int,
|
||||
create_watch_env: bool = False,
|
||||
) -> "ContinuousEnvironments":
|
||||
"""Creates an instance from a factory function that creates a single instance.
|
||||
|
||||
@ -181,6 +201,7 @@ class ContinuousEnvironments(Environments):
|
||||
:param venv_type: the vector environment type to use for parallelization
|
||||
:param num_training_envs: the number of training environments to create
|
||||
:param num_test_envs: the number of test environments to create
|
||||
:param create_watch_env: whether to create an environment for watching the agent
|
||||
:return: the instance
|
||||
"""
|
||||
return cast(
|
||||
@ -191,6 +212,7 @@ class ContinuousEnvironments(Environments):
|
||||
venv_type,
|
||||
num_training_envs,
|
||||
num_test_envs,
|
||||
create_watch_env,
|
||||
),
|
||||
)
|
||||
|
||||
@ -228,8 +250,14 @@ class ContinuousEnvironments(Environments):
|
||||
class DiscreteEnvironments(Environments):
|
||||
"""Represents (vectorized) discrete environments."""
|
||||
|
||||
def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
|
||||
super().__init__(env, train_envs, test_envs)
|
||||
def __init__(
|
||||
self,
|
||||
env: gym.Env,
|
||||
train_envs: BaseVectorEnv,
|
||||
test_envs: BaseVectorEnv,
|
||||
watch_env: BaseVectorEnv | None = None,
|
||||
):
|
||||
super().__init__(env, train_envs, test_envs, watch_env)
|
||||
self.observation_shape = env.observation_space.shape or env.observation_space.n # type: ignore
|
||||
self.action_shape = env.action_space.shape or env.action_space.n # type: ignore
|
||||
|
||||
@ -239,6 +267,7 @@ class DiscreteEnvironments(Environments):
|
||||
venv_type: VectorEnvType,
|
||||
num_training_envs: int,
|
||||
num_test_envs: int,
|
||||
create_watch_env: bool = False,
|
||||
) -> "DiscreteEnvironments":
|
||||
"""Creates an instance from a factory function that creates a single instance.
|
||||
|
||||
@ -246,6 +275,7 @@ class DiscreteEnvironments(Environments):
|
||||
:param venv_type: the vector environment type to use for parallelization
|
||||
:param num_training_envs: the number of training environments to create
|
||||
:param num_test_envs: the number of test environments to create
|
||||
:param create_watch_env: whether to create an environment for watching the agent
|
||||
:return: the instance
|
||||
"""
|
||||
return cast(
|
||||
@ -256,6 +286,7 @@ class DiscreteEnvironments(Environments):
|
||||
venv_type,
|
||||
num_training_envs,
|
||||
num_test_envs,
|
||||
create_watch_env,
|
||||
),
|
||||
)
|
||||
|
||||
@ -329,21 +360,28 @@ class EnvFactory(ToStringMixin, ABC):
|
||||
"""
|
||||
return self.venv_type.create_venv([lambda: self.create_env(mode)] * num_envs)
|
||||
|
||||
def create_envs(self, num_training_envs: int, num_test_envs: int) -> Environments:
|
||||
def create_envs(
|
||||
self,
|
||||
num_training_envs: int,
|
||||
num_test_envs: int,
|
||||
create_watch_env: bool = False,
|
||||
) -> Environments:
|
||||
"""Create environments for learning.
|
||||
|
||||
:param num_training_envs: the number of training environments
|
||||
:param num_test_envs: the number of test environments
|
||||
:param create_watch_env: whether to create an environment for watching the agent
|
||||
:return: the environments
|
||||
"""
|
||||
env = self.create_env(EnvMode.TRAIN)
|
||||
train_envs = self.create_venv(num_training_envs, EnvMode.TRAIN)
|
||||
test_envs = self.create_venv(num_test_envs, EnvMode.TEST)
|
||||
watch_env = self.create_venv(1, EnvMode.WATCH) if create_watch_env else None
|
||||
match EnvType.from_env(env):
|
||||
case EnvType.DISCRETE:
|
||||
return DiscreteEnvironments(env, train_envs, test_envs)
|
||||
return DiscreteEnvironments(env, train_envs, test_envs, watch_env)
|
||||
case EnvType.CONTINUOUS:
|
||||
return ContinuousEnvironments(env, train_envs, test_envs)
|
||||
return ContinuousEnvironments(env, train_envs, test_envs, watch_env)
|
||||
case _:
|
||||
raise ValueError
|
||||
|
||||
|
@ -10,6 +10,7 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from tianshou.data import Collector, InfoStats
|
||||
from tianshou.env import BaseVectorEnv
|
||||
from tianshou.highlevel.agent import (
|
||||
A2CAgentFactory,
|
||||
AgentFactory,
|
||||
@ -26,7 +27,7 @@ from tianshou.highlevel.agent import (
|
||||
TRPOAgentFactory,
|
||||
)
|
||||
from tianshou.highlevel.config import SamplingConfig
|
||||
from tianshou.highlevel.env import EnvFactory, EnvMode
|
||||
from tianshou.highlevel.env import EnvFactory
|
||||
from tianshou.highlevel.logger import LoggerFactory, LoggerFactoryDefault, TLogger
|
||||
from tianshou.highlevel.module.actor import (
|
||||
ActorFactory,
|
||||
@ -222,6 +223,7 @@ class Experiment(ToStringMixin):
|
||||
envs = self.env_factory.create_envs(
|
||||
self.sampling_config.num_train_envs,
|
||||
self.sampling_config.num_test_envs,
|
||||
create_watch_env=self.config.watch,
|
||||
)
|
||||
log.info(f"Created {envs}")
|
||||
|
||||
@ -289,11 +291,12 @@ class Experiment(ToStringMixin):
|
||||
|
||||
# watch agent performance
|
||||
if self.config.watch:
|
||||
assert envs.watch_env is not None
|
||||
log.info("Watching agent performance")
|
||||
self._watch_agent(
|
||||
self.config.watch_num_episodes,
|
||||
policy,
|
||||
self.env_factory,
|
||||
envs.watch_env,
|
||||
self.config.watch_render,
|
||||
)
|
||||
|
||||
@ -303,11 +306,10 @@ class Experiment(ToStringMixin):
|
||||
def _watch_agent(
|
||||
num_episodes: int,
|
||||
policy: BasePolicy,
|
||||
env_factory: EnvFactory,
|
||||
env: BaseVectorEnv,
|
||||
render: float,
|
||||
) -> None:
|
||||
policy.eval()
|
||||
env = env_factory.create_env(EnvMode.WATCH)
|
||||
collector = Collector(policy, env)
|
||||
result = collector.collect(n_episode=num_episodes, render=render)
|
||||
assert result.returns_stat is not None # for mypy
|
||||
|
Loading…
x
Reference in New Issue
Block a user