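Handle obs_norm setting in MuJoCo envs: MujocoEnvFactory now accepts an obs_norm argument (default True) instead of hard-coding it, and each MuJoCo example script passes it explicitly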

This commit is contained in:
Dominik Jain 2023-10-18 13:20:26 +02:00
parent 80b1b1ff9d
commit ed06ab7ff0
10 changed files with 20 additions and 23 deletions

View File

@ -1,6 +1,5 @@
#!/usr/bin/env python3
import datetime
import os
from collections.abc import Sequence
from typing import Literal
@ -17,6 +16,7 @@ from tianshou.highlevel.optim import OptimizerFactoryRMSprop
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
from tianshou.highlevel.params.policy_params import A2CParams
from tianshou.utils import logging
from tianshou.utils.logging import datetime_tag
def main(
@ -41,8 +41,7 @@ def main(
lr_decay: bool = True,
max_grad_norm: float = 0.5,
):
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
log_name = os.path.join(task, "a2c", str(experiment_config.seed), now)
log_name = os.path.join(task, "a2c", str(experiment_config.seed), datetime_tag())
sampling_config = SamplingConfig(
num_epochs=epoch,
@ -55,7 +54,7 @@ def main(
repeat_per_collect=repeat_per_collect,
)
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=True)
experiment = (
A2CExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@ -1,6 +1,5 @@
#!/usr/bin/env python3
import datetime
import os
from collections.abc import Sequence
@ -15,6 +14,7 @@ from tianshou.highlevel.experiment import (
from tianshou.highlevel.params.noise import MaxActionScaledGaussian
from tianshou.highlevel.params.policy_params import DDPGParams
from tianshou.utils import logging
from tianshou.utils.logging import datetime_tag
def main(
@ -37,8 +37,7 @@ def main(
training_num: int = 1,
test_num: int = 10,
):
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
log_name = os.path.join(task, "ddpg", str(experiment_config.seed), now)
log_name = os.path.join(task, "ddpg", str(experiment_config.seed), datetime_tag())
sampling_config = SamplingConfig(
num_epochs=epoch,
@ -54,7 +53,7 @@ def main(
start_timesteps_random=True,
)
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=False)
experiment = (
DDPGExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@ -70,10 +70,11 @@ class MujocoEnvObsRmsPersistence(Persistence):
class MujocoEnvFactory(EnvFactory):
def __init__(self, task: str, seed: int, sampling_config: SamplingConfig):
def __init__(self, task: str, seed: int, sampling_config: SamplingConfig, obs_norm=True):
self.task = task
self.sampling_config = sampling_config
self.seed = seed
self.obs_norm = obs_norm
def create_envs(self, config=None) -> ContinuousEnvironments:
env, train_envs, test_envs = make_mujoco_env(
@ -81,7 +82,7 @@ class MujocoEnvFactory(EnvFactory):
seed=self.seed,
num_train_envs=self.sampling_config.num_train_envs,
num_test_envs=self.sampling_config.num_test_envs,
obs_norm=True,
obs_norm=self.obs_norm,
)
envs = ContinuousEnvironments(env=env, train_envs=train_envs, test_envs=test_envs)
envs.set_persistence(MujocoEnvObsRmsPersistence())

View File

@ -56,7 +56,7 @@ def main(
repeat_per_collect=repeat_per_collect,
)
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=True)
experiment = (
NPGExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@ -1,6 +1,5 @@
#!/usr/bin/env python3
import datetime
import os
from collections.abc import Sequence
from typing import Literal
@ -19,6 +18,7 @@ from tianshou.highlevel.params.dist_fn import (
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
from tianshou.highlevel.params.policy_params import PPOParams
from tianshou.utils import logging
from tianshou.utils.logging import datetime_tag
def main(
@ -48,8 +48,7 @@ def main(
norm_adv: bool = False,
recompute_adv: bool = True,
):
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
log_name = os.path.join(task, "ppo", str(experiment_config.seed), datetime_tag())
sampling_config = SamplingConfig(
num_epochs=epoch,
@ -62,7 +61,7 @@ def main(
repeat_per_collect=repeat_per_collect,
)
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=True)
experiment = (
PPOExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@ -59,7 +59,7 @@ def main(
start_timesteps_random=True,
)
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=False)
experiment = (
REDQExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@ -49,7 +49,7 @@ def main(
repeat_per_collect=repeat_per_collect,
)
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=True)
experiment = (
PGExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@ -55,7 +55,7 @@ def main(
start_timesteps_random=True,
)
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=False)
experiment = (
SACExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@ -1,6 +1,5 @@
#!/usr/bin/env python3
import datetime
import os
from collections.abc import Sequence
@ -18,6 +17,7 @@ from tianshou.highlevel.params.noise import (
)
from tianshou.highlevel.params.policy_params import TD3Params
from tianshou.utils import logging
from tianshou.utils.logging import datetime_tag
def main(
@ -43,8 +43,7 @@ def main(
training_num: int = 1,
test_num: int = 10,
):
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
log_name = os.path.join(task, "td3", str(experiment_config.seed), now)
log_name = os.path.join(task, "td3", str(experiment_config.seed), datetime_tag())
sampling_config = SamplingConfig(
num_epochs=epoch,
@ -59,7 +58,7 @@ def main(
start_timesteps_random=True,
)
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=False)
experiment = (
TD3ExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@ -58,7 +58,7 @@ def main(
repeat_per_collect=repeat_per_collect,
)
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=True)
experiment = (
TRPOExperimentBuilder(env_factory, experiment_config, sampling_config)