Fix/add watch env with obs rms (#1061)
Supports deciding whether to watch the agent performing on the env using high-level interfaces
This commit is contained in:
parent
49781e715e
commit
7c970df53f
@ -1,10 +1,11 @@
|
|||||||
import logging
|
import logging
|
||||||
import pickle
|
import pickle
|
||||||
|
|
||||||
from tianshou.env import VectorEnvNormObs
|
from tianshou.env import BaseVectorEnv, VectorEnvNormObs
|
||||||
from tianshou.highlevel.env import (
|
from tianshou.highlevel.env import (
|
||||||
ContinuousEnvironments,
|
ContinuousEnvironments,
|
||||||
EnvFactoryRegistered,
|
EnvFactoryRegistered,
|
||||||
|
EnvMode,
|
||||||
EnvPoolFactory,
|
EnvPoolFactory,
|
||||||
VectorEnvType,
|
VectorEnvType,
|
||||||
)
|
)
|
||||||
@ -56,6 +57,8 @@ class MujocoEnvObsRmsPersistence(Persistence):
|
|||||||
obs_rms = pickle.load(f)
|
obs_rms = pickle.load(f)
|
||||||
world.envs.train_envs.set_obs_rms(obs_rms)
|
world.envs.train_envs.set_obs_rms(obs_rms)
|
||||||
world.envs.test_envs.set_obs_rms(obs_rms)
|
world.envs.test_envs.set_obs_rms(obs_rms)
|
||||||
|
if world.envs.watch_env is not None:
|
||||||
|
world.envs.watch_env.set_obs_rms(obs_rms)
|
||||||
|
|
||||||
|
|
||||||
class MujocoEnvFactory(EnvFactoryRegistered):
|
class MujocoEnvFactory(EnvFactoryRegistered):
|
||||||
@ -68,15 +71,31 @@ class MujocoEnvFactory(EnvFactoryRegistered):
|
|||||||
)
|
)
|
||||||
self.obs_norm = obs_norm
|
self.obs_norm = obs_norm
|
||||||
|
|
||||||
def create_envs(self, num_training_envs: int, num_test_envs: int) -> ContinuousEnvironments:
|
def create_venv(self, num_envs: int, mode: EnvMode) -> BaseVectorEnv:
|
||||||
envs = super().create_envs(num_training_envs, num_test_envs)
|
"""Create vectorized environments.
|
||||||
assert isinstance(envs, ContinuousEnvironments)
|
|
||||||
|
|
||||||
|
:param num_envs: the number of environments
|
||||||
|
:param mode: the mode for which to create
|
||||||
|
:return: the vectorized environments
|
||||||
|
"""
|
||||||
|
env = super().create_venv(num_envs, mode)
|
||||||
# obs norm wrapper
|
# obs norm wrapper
|
||||||
if self.obs_norm:
|
if self.obs_norm:
|
||||||
envs.train_envs = VectorEnvNormObs(envs.train_envs)
|
env = VectorEnvNormObs(env, update_obs_rms=mode == EnvMode.TRAIN)
|
||||||
envs.test_envs = VectorEnvNormObs(envs.test_envs, update_obs_rms=False)
|
return env
|
||||||
envs.test_envs.set_obs_rms(envs.train_envs.get_obs_rms())
|
|
||||||
envs.set_persistence(MujocoEnvObsRmsPersistence())
|
|
||||||
|
|
||||||
|
def create_envs(
|
||||||
|
self,
|
||||||
|
num_training_envs: int,
|
||||||
|
num_test_envs: int,
|
||||||
|
create_watch_env: bool = False,
|
||||||
|
) -> ContinuousEnvironments:
|
||||||
|
envs = super().create_envs(num_training_envs, num_test_envs, create_watch_env)
|
||||||
|
assert isinstance(envs, ContinuousEnvironments)
|
||||||
|
|
||||||
|
if self.obs_norm:
|
||||||
|
envs.test_envs.set_obs_rms(envs.train_envs.get_obs_rms())
|
||||||
|
if envs.watch_env is not None:
|
||||||
|
envs.watch_env.set_obs_rms(envs.train_envs.get_obs_rms())
|
||||||
|
envs.set_persistence(MujocoEnvObsRmsPersistence())
|
||||||
return envs
|
return envs
|
||||||
|
@ -89,10 +89,17 @@ class VectorEnvType(Enum):
|
|||||||
class Environments(ToStringMixin, ABC):
|
class Environments(ToStringMixin, ABC):
|
||||||
"""Represents (vectorized) environments for a learning process."""
|
"""Represents (vectorized) environments for a learning process."""
|
||||||
|
|
||||||
def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
|
def __init__(
|
||||||
|
self,
|
||||||
|
env: gym.Env,
|
||||||
|
train_envs: BaseVectorEnv,
|
||||||
|
test_envs: BaseVectorEnv,
|
||||||
|
watch_env: BaseVectorEnv | None = None,
|
||||||
|
):
|
||||||
self.env = env
|
self.env = env
|
||||||
self.train_envs = train_envs
|
self.train_envs = train_envs
|
||||||
self.test_envs = test_envs
|
self.test_envs = test_envs
|
||||||
|
self.watch_env = watch_env
|
||||||
self.persistence: Sequence[Persistence] = []
|
self.persistence: Sequence[Persistence] = []
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -102,6 +109,7 @@ class Environments(ToStringMixin, ABC):
|
|||||||
venv_type: VectorEnvType,
|
venv_type: VectorEnvType,
|
||||||
num_training_envs: int,
|
num_training_envs: int,
|
||||||
num_test_envs: int,
|
num_test_envs: int,
|
||||||
|
create_watch_env: bool = False,
|
||||||
) -> "Environments":
|
) -> "Environments":
|
||||||
"""Creates a suitable subtype instance from a factory function that creates a single instance and the type of environment (continuous/discrete).
|
"""Creates a suitable subtype instance from a factory function that creates a single instance and the type of environment (continuous/discrete).
|
||||||
|
|
||||||
@ -110,16 +118,21 @@ class Environments(ToStringMixin, ABC):
|
|||||||
:param venv_type: the vector environment type to use for parallelization
|
:param venv_type: the vector environment type to use for parallelization
|
||||||
:param num_training_envs: the number of training environments to create
|
:param num_training_envs: the number of training environments to create
|
||||||
:param num_test_envs: the number of test environments to create
|
:param num_test_envs: the number of test environments to create
|
||||||
|
:param create_watch_env: whether to create an environment for watching the agent
|
||||||
:return: the instance
|
:return: the instance
|
||||||
"""
|
"""
|
||||||
train_envs = venv_type.create_venv([lambda: factory_fn(EnvMode.TRAIN)] * num_training_envs)
|
train_envs = venv_type.create_venv([lambda: factory_fn(EnvMode.TRAIN)] * num_training_envs)
|
||||||
test_envs = venv_type.create_venv([lambda: factory_fn(EnvMode.TEST)] * num_test_envs)
|
test_envs = venv_type.create_venv([lambda: factory_fn(EnvMode.TEST)] * num_test_envs)
|
||||||
|
if create_watch_env:
|
||||||
|
watch_env = venv_type.create_venv([lambda: factory_fn(EnvMode.WATCH)])
|
||||||
|
else:
|
||||||
|
watch_env = None
|
||||||
env = factory_fn(EnvMode.TRAIN)
|
env = factory_fn(EnvMode.TRAIN)
|
||||||
match env_type:
|
match env_type:
|
||||||
case EnvType.CONTINUOUS:
|
case EnvType.CONTINUOUS:
|
||||||
return ContinuousEnvironments(env, train_envs, test_envs)
|
return ContinuousEnvironments(env, train_envs, test_envs, watch_env)
|
||||||
case EnvType.DISCRETE:
|
case EnvType.DISCRETE:
|
||||||
return DiscreteEnvironments(env, train_envs, test_envs)
|
return DiscreteEnvironments(env, train_envs, test_envs, watch_env)
|
||||||
case _:
|
case _:
|
||||||
raise ValueError(f"Environment type {env_type} not handled")
|
raise ValueError(f"Environment type {env_type} not handled")
|
||||||
|
|
||||||
@ -164,8 +177,14 @@ class Environments(ToStringMixin, ABC):
|
|||||||
class ContinuousEnvironments(Environments):
|
class ContinuousEnvironments(Environments):
|
||||||
"""Represents (vectorized) continuous environments."""
|
"""Represents (vectorized) continuous environments."""
|
||||||
|
|
||||||
def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
|
def __init__(
|
||||||
super().__init__(env, train_envs, test_envs)
|
self,
|
||||||
|
env: gym.Env,
|
||||||
|
train_envs: BaseVectorEnv,
|
||||||
|
test_envs: BaseVectorEnv,
|
||||||
|
watch_env: BaseVectorEnv | None = None,
|
||||||
|
):
|
||||||
|
super().__init__(env, train_envs, test_envs, watch_env)
|
||||||
self.state_shape, self.action_shape, self.max_action = self._get_continuous_env_info(env)
|
self.state_shape, self.action_shape, self.max_action = self._get_continuous_env_info(env)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -174,6 +193,7 @@ class ContinuousEnvironments(Environments):
|
|||||||
venv_type: VectorEnvType,
|
venv_type: VectorEnvType,
|
||||||
num_training_envs: int,
|
num_training_envs: int,
|
||||||
num_test_envs: int,
|
num_test_envs: int,
|
||||||
|
create_watch_env: bool = False,
|
||||||
) -> "ContinuousEnvironments":
|
) -> "ContinuousEnvironments":
|
||||||
"""Creates an instance from a factory function that creates a single instance.
|
"""Creates an instance from a factory function that creates a single instance.
|
||||||
|
|
||||||
@ -181,6 +201,7 @@ class ContinuousEnvironments(Environments):
|
|||||||
:param venv_type: the vector environment type to use for parallelization
|
:param venv_type: the vector environment type to use for parallelization
|
||||||
:param num_training_envs: the number of training environments to create
|
:param num_training_envs: the number of training environments to create
|
||||||
:param num_test_envs: the number of test environments to create
|
:param num_test_envs: the number of test environments to create
|
||||||
|
:param create_watch_env: whether to create an environment for watching the agent
|
||||||
:return: the instance
|
:return: the instance
|
||||||
"""
|
"""
|
||||||
return cast(
|
return cast(
|
||||||
@ -191,6 +212,7 @@ class ContinuousEnvironments(Environments):
|
|||||||
venv_type,
|
venv_type,
|
||||||
num_training_envs,
|
num_training_envs,
|
||||||
num_test_envs,
|
num_test_envs,
|
||||||
|
create_watch_env,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -228,8 +250,14 @@ class ContinuousEnvironments(Environments):
|
|||||||
class DiscreteEnvironments(Environments):
|
class DiscreteEnvironments(Environments):
|
||||||
"""Represents (vectorized) discrete environments."""
|
"""Represents (vectorized) discrete environments."""
|
||||||
|
|
||||||
def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
|
def __init__(
|
||||||
super().__init__(env, train_envs, test_envs)
|
self,
|
||||||
|
env: gym.Env,
|
||||||
|
train_envs: BaseVectorEnv,
|
||||||
|
test_envs: BaseVectorEnv,
|
||||||
|
watch_env: BaseVectorEnv | None = None,
|
||||||
|
):
|
||||||
|
super().__init__(env, train_envs, test_envs, watch_env)
|
||||||
self.observation_shape = env.observation_space.shape or env.observation_space.n # type: ignore
|
self.observation_shape = env.observation_space.shape or env.observation_space.n # type: ignore
|
||||||
self.action_shape = env.action_space.shape or env.action_space.n # type: ignore
|
self.action_shape = env.action_space.shape or env.action_space.n # type: ignore
|
||||||
|
|
||||||
@ -239,6 +267,7 @@ class DiscreteEnvironments(Environments):
|
|||||||
venv_type: VectorEnvType,
|
venv_type: VectorEnvType,
|
||||||
num_training_envs: int,
|
num_training_envs: int,
|
||||||
num_test_envs: int,
|
num_test_envs: int,
|
||||||
|
create_watch_env: bool = False,
|
||||||
) -> "DiscreteEnvironments":
|
) -> "DiscreteEnvironments":
|
||||||
"""Creates an instance from a factory function that creates a single instance.
|
"""Creates an instance from a factory function that creates a single instance.
|
||||||
|
|
||||||
@ -246,6 +275,7 @@ class DiscreteEnvironments(Environments):
|
|||||||
:param venv_type: the vector environment type to use for parallelization
|
:param venv_type: the vector environment type to use for parallelization
|
||||||
:param num_training_envs: the number of training environments to create
|
:param num_training_envs: the number of training environments to create
|
||||||
:param num_test_envs: the number of test environments to create
|
:param num_test_envs: the number of test environments to create
|
||||||
|
:param create_watch_env: whether to create an environment for watching the agent
|
||||||
:return: the instance
|
:return: the instance
|
||||||
"""
|
"""
|
||||||
return cast(
|
return cast(
|
||||||
@ -256,6 +286,7 @@ class DiscreteEnvironments(Environments):
|
|||||||
venv_type,
|
venv_type,
|
||||||
num_training_envs,
|
num_training_envs,
|
||||||
num_test_envs,
|
num_test_envs,
|
||||||
|
create_watch_env,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -329,21 +360,28 @@ class EnvFactory(ToStringMixin, ABC):
|
|||||||
"""
|
"""
|
||||||
return self.venv_type.create_venv([lambda: self.create_env(mode)] * num_envs)
|
return self.venv_type.create_venv([lambda: self.create_env(mode)] * num_envs)
|
||||||
|
|
||||||
def create_envs(self, num_training_envs: int, num_test_envs: int) -> Environments:
|
def create_envs(
|
||||||
|
self,
|
||||||
|
num_training_envs: int,
|
||||||
|
num_test_envs: int,
|
||||||
|
create_watch_env: bool = False,
|
||||||
|
) -> Environments:
|
||||||
"""Create environments for learning.
|
"""Create environments for learning.
|
||||||
|
|
||||||
:param num_training_envs: the number of training environments
|
:param num_training_envs: the number of training environments
|
||||||
:param num_test_envs: the number of test environments
|
:param num_test_envs: the number of test environments
|
||||||
|
:param create_watch_env: whether to create an environment for watching the agent
|
||||||
:return: the environments
|
:return: the environments
|
||||||
"""
|
"""
|
||||||
env = self.create_env(EnvMode.TRAIN)
|
env = self.create_env(EnvMode.TRAIN)
|
||||||
train_envs = self.create_venv(num_training_envs, EnvMode.TRAIN)
|
train_envs = self.create_venv(num_training_envs, EnvMode.TRAIN)
|
||||||
test_envs = self.create_venv(num_test_envs, EnvMode.TEST)
|
test_envs = self.create_venv(num_test_envs, EnvMode.TEST)
|
||||||
|
watch_env = self.create_venv(1, EnvMode.WATCH) if create_watch_env else None
|
||||||
match EnvType.from_env(env):
|
match EnvType.from_env(env):
|
||||||
case EnvType.DISCRETE:
|
case EnvType.DISCRETE:
|
||||||
return DiscreteEnvironments(env, train_envs, test_envs)
|
return DiscreteEnvironments(env, train_envs, test_envs, watch_env)
|
||||||
case EnvType.CONTINUOUS:
|
case EnvType.CONTINUOUS:
|
||||||
return ContinuousEnvironments(env, train_envs, test_envs)
|
return ContinuousEnvironments(env, train_envs, test_envs, watch_env)
|
||||||
case _:
|
case _:
|
||||||
raise ValueError
|
raise ValueError
|
||||||
|
|
||||||
|
@ -10,6 +10,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from tianshou.data import Collector, InfoStats
|
from tianshou.data import Collector, InfoStats
|
||||||
|
from tianshou.env import BaseVectorEnv
|
||||||
from tianshou.highlevel.agent import (
|
from tianshou.highlevel.agent import (
|
||||||
A2CAgentFactory,
|
A2CAgentFactory,
|
||||||
AgentFactory,
|
AgentFactory,
|
||||||
@ -26,7 +27,7 @@ from tianshou.highlevel.agent import (
|
|||||||
TRPOAgentFactory,
|
TRPOAgentFactory,
|
||||||
)
|
)
|
||||||
from tianshou.highlevel.config import SamplingConfig
|
from tianshou.highlevel.config import SamplingConfig
|
||||||
from tianshou.highlevel.env import EnvFactory, EnvMode
|
from tianshou.highlevel.env import EnvFactory
|
||||||
from tianshou.highlevel.logger import LoggerFactory, LoggerFactoryDefault, TLogger
|
from tianshou.highlevel.logger import LoggerFactory, LoggerFactoryDefault, TLogger
|
||||||
from tianshou.highlevel.module.actor import (
|
from tianshou.highlevel.module.actor import (
|
||||||
ActorFactory,
|
ActorFactory,
|
||||||
@ -222,6 +223,7 @@ class Experiment(ToStringMixin):
|
|||||||
envs = self.env_factory.create_envs(
|
envs = self.env_factory.create_envs(
|
||||||
self.sampling_config.num_train_envs,
|
self.sampling_config.num_train_envs,
|
||||||
self.sampling_config.num_test_envs,
|
self.sampling_config.num_test_envs,
|
||||||
|
create_watch_env=self.config.watch,
|
||||||
)
|
)
|
||||||
log.info(f"Created {envs}")
|
log.info(f"Created {envs}")
|
||||||
|
|
||||||
@ -289,11 +291,12 @@ class Experiment(ToStringMixin):
|
|||||||
|
|
||||||
# watch agent performance
|
# watch agent performance
|
||||||
if self.config.watch:
|
if self.config.watch:
|
||||||
|
assert envs.watch_env is not None
|
||||||
log.info("Watching agent performance")
|
log.info("Watching agent performance")
|
||||||
self._watch_agent(
|
self._watch_agent(
|
||||||
self.config.watch_num_episodes,
|
self.config.watch_num_episodes,
|
||||||
policy,
|
policy,
|
||||||
self.env_factory,
|
envs.watch_env,
|
||||||
self.config.watch_render,
|
self.config.watch_render,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -303,11 +306,10 @@ class Experiment(ToStringMixin):
|
|||||||
def _watch_agent(
|
def _watch_agent(
|
||||||
num_episodes: int,
|
num_episodes: int,
|
||||||
policy: BasePolicy,
|
policy: BasePolicy,
|
||||||
env_factory: EnvFactory,
|
env: BaseVectorEnv,
|
||||||
render: float,
|
render: float,
|
||||||
) -> None:
|
) -> None:
|
||||||
policy.eval()
|
policy.eval()
|
||||||
env = env_factory.create_env(EnvMode.WATCH)
|
|
||||||
collector = Collector(policy, env)
|
collector = Collector(policy, env)
|
||||||
result = collector.collect(n_episode=num_episodes, render=render)
|
result = collector.collect(n_episode=num_episodes, render=render)
|
||||||
assert result.returns_stat is not None # for mypy
|
assert result.returns_stat is not None # for mypy
|
||||||
|
Loading…
x
Reference in New Issue
Block a user