diff --git a/examples/mujoco/mujoco_a2c_hl.py b/examples/mujoco/mujoco_a2c_hl.py index 64a469d..ae19684 100644 --- a/examples/mujoco/mujoco_a2c_hl.py +++ b/examples/mujoco/mujoco_a2c_hl.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 -import datetime import os from collections.abc import Sequence from typing import Literal @@ -17,6 +16,7 @@ from tianshou.highlevel.optim import OptimizerFactoryRMSprop from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear from tianshou.highlevel.params.policy_params import A2CParams from tianshou.utils import logging +from tianshou.utils.logging import datetime_tag def main( @@ -41,8 +41,7 @@ def main( lr_decay: bool = True, max_grad_norm: float = 0.5, ): - now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") - log_name = os.path.join(task, "a2c", str(experiment_config.seed), now) + log_name = os.path.join(task, "a2c", str(experiment_config.seed), datetime_tag()) sampling_config = SamplingConfig( num_epochs=epoch, @@ -55,7 +54,7 @@ def main( repeat_per_collect=repeat_per_collect, ) - env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config) + env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=True) experiment = ( A2CExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/examples/mujoco/mujoco_ddpg_hl.py b/examples/mujoco/mujoco_ddpg_hl.py index d4e676a..f5901dc 100644 --- a/examples/mujoco/mujoco_ddpg_hl.py +++ b/examples/mujoco/mujoco_ddpg_hl.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 -import datetime import os from collections.abc import Sequence @@ -15,6 +14,7 @@ from tianshou.highlevel.experiment import ( from tianshou.highlevel.params.noise import MaxActionScaledGaussian from tianshou.highlevel.params.policy_params import DDPGParams from tianshou.utils import logging +from tianshou.utils.logging import datetime_tag def main( @@ -37,8 +37,7 @@ def main( training_num: int = 1, test_num: int = 10, ): - now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") - log_name = os.path.join(task, "ddpg", str(experiment_config.seed), now) + log_name = os.path.join(task, "ddpg", str(experiment_config.seed), datetime_tag()) sampling_config = SamplingConfig( num_epochs=epoch, @@ -54,7 +53,7 @@ def main( start_timesteps_random=True, ) - env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config) + env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=False) experiment = ( DDPGExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/examples/mujoco/mujoco_env.py b/examples/mujoco/mujoco_env.py index cecbc01..c2be4c1 100644 --- a/examples/mujoco/mujoco_env.py +++ b/examples/mujoco/mujoco_env.py @@ -70,10 +70,11 @@ class MujocoEnvObsRmsPersistence(Persistence): class MujocoEnvFactory(EnvFactory): - def __init__(self, task: str, seed: int, sampling_config: SamplingConfig): + def __init__(self, task: str, seed: int, sampling_config: SamplingConfig, obs_norm=True): self.task = task self.sampling_config = sampling_config self.seed = seed + self.obs_norm = obs_norm def create_envs(self, config=None) -> ContinuousEnvironments: env, train_envs, test_envs = make_mujoco_env( @@ -81,7 +82,7 @@ class MujocoEnvFactory(EnvFactory): seed=self.seed, num_train_envs=self.sampling_config.num_train_envs, num_test_envs=self.sampling_config.num_test_envs, - obs_norm=True, + obs_norm=self.obs_norm, ) envs = ContinuousEnvironments(env=env, train_envs=train_envs, test_envs=test_envs) envs.set_persistence(MujocoEnvObsRmsPersistence()) diff --git a/examples/mujoco/mujoco_npg_hl.py b/examples/mujoco/mujoco_npg_hl.py index ce04658..27c916b 100644 --- a/examples/mujoco/mujoco_npg_hl.py +++ b/examples/mujoco/mujoco_npg_hl.py @@ -56,7 +56,7 @@ def main( repeat_per_collect=repeat_per_collect, ) - env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config) + env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=True) experiment = ( NPGExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/examples/mujoco/mujoco_ppo_hl.py b/examples/mujoco/mujoco_ppo_hl.py index f30a63d..8ad374d 100644 --- a/examples/mujoco/mujoco_ppo_hl.py +++ b/examples/mujoco/mujoco_ppo_hl.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 -import datetime import os from collections.abc import Sequence from typing import Literal @@ -19,6 +18,7 @@ from tianshou.highlevel.params.dist_fn import ( from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear from tianshou.highlevel.params.policy_params import PPOParams from tianshou.utils import logging +from tianshou.utils.logging import datetime_tag def main( @@ -48,8 +48,7 @@ def main( norm_adv: bool = False, recompute_adv: bool = True, ): - now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") - log_name = os.path.join(task, "ppo", str(experiment_config.seed), now) + log_name = os.path.join(task, "ppo", str(experiment_config.seed), datetime_tag()) sampling_config = SamplingConfig( num_epochs=epoch, @@ -62,7 +61,7 @@ def main( repeat_per_collect=repeat_per_collect, ) - env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config) + env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=True) experiment = ( PPOExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/examples/mujoco/mujoco_redq_hl.py b/examples/mujoco/mujoco_redq_hl.py index 8a924ac..52a296d 100644 --- a/examples/mujoco/mujoco_redq_hl.py +++ b/examples/mujoco/mujoco_redq_hl.py @@ -59,7 +59,7 @@ def main( start_timesteps_random=True, ) - env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config) + env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=False) experiment = ( REDQExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/examples/mujoco/mujoco_reinforce_hl.py b/examples/mujoco/mujoco_reinforce_hl.py index 1db6d86..33b1462 100644 --- a/examples/mujoco/mujoco_reinforce_hl.py +++ b/examples/mujoco/mujoco_reinforce_hl.py @@ -49,7 +49,7 @@ def main( repeat_per_collect=repeat_per_collect, ) - env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config) + env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=True) experiment = ( PGExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/examples/mujoco/mujoco_sac_hl.py b/examples/mujoco/mujoco_sac_hl.py index c0be24d..930aacf 100644 --- a/examples/mujoco/mujoco_sac_hl.py +++ b/examples/mujoco/mujoco_sac_hl.py @@ -55,7 +55,7 @@ def main( start_timesteps_random=True, ) - env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config) + env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=False) experiment = ( SACExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/examples/mujoco/mujoco_td3_hl.py b/examples/mujoco/mujoco_td3_hl.py index 2a09409..fd68549 100644 --- a/examples/mujoco/mujoco_td3_hl.py +++ b/examples/mujoco/mujoco_td3_hl.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 -import datetime import os from collections.abc import Sequence @@ -18,6 +17,7 @@ from tianshou.highlevel.params.noise import ( ) from tianshou.highlevel.params.policy_params import TD3Params from tianshou.utils import logging +from tianshou.utils.logging import datetime_tag def main( @@ -43,8 +43,7 @@ def main( training_num: int = 1, test_num: int = 10, ): - now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") - log_name = os.path.join(task, "td3", str(experiment_config.seed), now) + log_name = os.path.join(task, "td3", str(experiment_config.seed), datetime_tag()) sampling_config = SamplingConfig( num_epochs=epoch, @@ -59,7 +58,7 @@ def main( start_timesteps_random=True, ) - env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config) + env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=False) experiment = ( TD3ExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/examples/mujoco/mujoco_trpo_hl.py b/examples/mujoco/mujoco_trpo_hl.py index 72cb2e8..7df9580 100644 --- a/examples/mujoco/mujoco_trpo_hl.py +++ b/examples/mujoco/mujoco_trpo_hl.py @@ -58,7 +58,7 @@ def main( repeat_per_collect=repeat_per_collect, ) - env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config) + env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=True) experiment = ( TRPOExperimentBuilder(env_factory, experiment_config, sampling_config)