Handle obs_norm setting in MuJoCo envs
This commit is contained in:
parent
80b1b1ff9d
commit
ed06ab7ff0
@ -1,6 +1,5 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import datetime
|
||||
import os
|
||||
from collections.abc import Sequence
|
||||
from typing import Literal
|
||||
@ -17,6 +16,7 @@ from tianshou.highlevel.optim import OptimizerFactoryRMSprop
|
||||
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
|
||||
from tianshou.highlevel.params.policy_params import A2CParams
|
||||
from tianshou.utils import logging
|
||||
from tianshou.utils.logging import datetime_tag
|
||||
|
||||
|
||||
def main(
|
||||
@ -41,8 +41,7 @@ def main(
|
||||
lr_decay: bool = True,
|
||||
max_grad_norm: float = 0.5,
|
||||
):
|
||||
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
|
||||
log_name = os.path.join(task, "a2c", str(experiment_config.seed), now)
|
||||
log_name = os.path.join(task, "a2c", str(experiment_config.seed), datetime_tag())
|
||||
|
||||
sampling_config = SamplingConfig(
|
||||
num_epochs=epoch,
|
||||
@ -55,7 +54,7 @@ def main(
|
||||
repeat_per_collect=repeat_per_collect,
|
||||
)
|
||||
|
||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
|
||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=True)
|
||||
|
||||
experiment = (
|
||||
A2CExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||
|
@ -1,6 +1,5 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import datetime
|
||||
import os
|
||||
from collections.abc import Sequence
|
||||
|
||||
@ -15,6 +14,7 @@ from tianshou.highlevel.experiment import (
|
||||
from tianshou.highlevel.params.noise import MaxActionScaledGaussian
|
||||
from tianshou.highlevel.params.policy_params import DDPGParams
|
||||
from tianshou.utils import logging
|
||||
from tianshou.utils.logging import datetime_tag
|
||||
|
||||
|
||||
def main(
|
||||
@ -37,8 +37,7 @@ def main(
|
||||
training_num: int = 1,
|
||||
test_num: int = 10,
|
||||
):
|
||||
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
|
||||
log_name = os.path.join(task, "ddpg", str(experiment_config.seed), now)
|
||||
log_name = os.path.join(task, "ddpg", str(experiment_config.seed), datetime_tag())
|
||||
|
||||
sampling_config = SamplingConfig(
|
||||
num_epochs=epoch,
|
||||
@ -54,7 +53,7 @@ def main(
|
||||
start_timesteps_random=True,
|
||||
)
|
||||
|
||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
|
||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=False)
|
||||
|
||||
experiment = (
|
||||
DDPGExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||
|
@ -70,10 +70,11 @@ class MujocoEnvObsRmsPersistence(Persistence):
|
||||
|
||||
|
||||
class MujocoEnvFactory(EnvFactory):
|
||||
def __init__(self, task: str, seed: int, sampling_config: SamplingConfig):
|
||||
def __init__(self, task: str, seed: int, sampling_config: SamplingConfig, obs_norm=True):
|
||||
self.task = task
|
||||
self.sampling_config = sampling_config
|
||||
self.seed = seed
|
||||
self.obs_norm = obs_norm
|
||||
|
||||
def create_envs(self, config=None) -> ContinuousEnvironments:
|
||||
env, train_envs, test_envs = make_mujoco_env(
|
||||
@ -81,7 +82,7 @@ class MujocoEnvFactory(EnvFactory):
|
||||
seed=self.seed,
|
||||
num_train_envs=self.sampling_config.num_train_envs,
|
||||
num_test_envs=self.sampling_config.num_test_envs,
|
||||
obs_norm=True,
|
||||
obs_norm=self.obs_norm,
|
||||
)
|
||||
envs = ContinuousEnvironments(env=env, train_envs=train_envs, test_envs=test_envs)
|
||||
envs.set_persistence(MujocoEnvObsRmsPersistence())
|
||||
|
@ -56,7 +56,7 @@ def main(
|
||||
repeat_per_collect=repeat_per_collect,
|
||||
)
|
||||
|
||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
|
||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=True)
|
||||
|
||||
experiment = (
|
||||
NPGExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||
|
@ -1,6 +1,5 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import datetime
|
||||
import os
|
||||
from collections.abc import Sequence
|
||||
from typing import Literal
|
||||
@ -19,6 +18,7 @@ from tianshou.highlevel.params.dist_fn import (
|
||||
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
|
||||
from tianshou.highlevel.params.policy_params import PPOParams
|
||||
from tianshou.utils import logging
|
||||
from tianshou.utils.logging import datetime_tag
|
||||
|
||||
|
||||
def main(
|
||||
@ -48,8 +48,7 @@ def main(
|
||||
norm_adv: bool = False,
|
||||
recompute_adv: bool = True,
|
||||
):
|
||||
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
|
||||
log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
|
||||
log_name = os.path.join(task, "ppo", str(experiment_config.seed), datetime_tag())
|
||||
|
||||
sampling_config = SamplingConfig(
|
||||
num_epochs=epoch,
|
||||
@ -62,7 +61,7 @@ def main(
|
||||
repeat_per_collect=repeat_per_collect,
|
||||
)
|
||||
|
||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
|
||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=True)
|
||||
|
||||
experiment = (
|
||||
PPOExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||
|
@ -59,7 +59,7 @@ def main(
|
||||
start_timesteps_random=True,
|
||||
)
|
||||
|
||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
|
||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=False)
|
||||
|
||||
experiment = (
|
||||
REDQExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||
|
@ -49,7 +49,7 @@ def main(
|
||||
repeat_per_collect=repeat_per_collect,
|
||||
)
|
||||
|
||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
|
||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=True)
|
||||
|
||||
experiment = (
|
||||
PGExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||
|
@ -55,7 +55,7 @@ def main(
|
||||
start_timesteps_random=True,
|
||||
)
|
||||
|
||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
|
||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=False)
|
||||
|
||||
experiment = (
|
||||
SACExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||
|
@ -1,6 +1,5 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import datetime
|
||||
import os
|
||||
from collections.abc import Sequence
|
||||
|
||||
@ -18,6 +17,7 @@ from tianshou.highlevel.params.noise import (
|
||||
)
|
||||
from tianshou.highlevel.params.policy_params import TD3Params
|
||||
from tianshou.utils import logging
|
||||
from tianshou.utils.logging import datetime_tag
|
||||
|
||||
|
||||
def main(
|
||||
@ -43,8 +43,7 @@ def main(
|
||||
training_num: int = 1,
|
||||
test_num: int = 10,
|
||||
):
|
||||
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
|
||||
log_name = os.path.join(task, "td3", str(experiment_config.seed), now)
|
||||
log_name = os.path.join(task, "td3", str(experiment_config.seed), datetime_tag())
|
||||
|
||||
sampling_config = SamplingConfig(
|
||||
num_epochs=epoch,
|
||||
@ -59,7 +58,7 @@ def main(
|
||||
start_timesteps_random=True,
|
||||
)
|
||||
|
||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
|
||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=False)
|
||||
|
||||
experiment = (
|
||||
TD3ExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||
|
@ -58,7 +58,7 @@ def main(
|
||||
repeat_per_collect=repeat_per_collect,
|
||||
)
|
||||
|
||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
|
||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=True)
|
||||
|
||||
experiment = (
|
||||
TRPOExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||
|
Loading…
x
Reference in New Issue
Block a user