Handle obs_norm setting in MuJoCo envs
This commit is contained in:
parent
80b1b1ff9d
commit
ed06ab7ff0
@ -1,6 +1,5 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
import datetime
|
|
||||||
import os
|
import os
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
@ -17,6 +16,7 @@ from tianshou.highlevel.optim import OptimizerFactoryRMSprop
|
|||||||
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
|
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
|
||||||
from tianshou.highlevel.params.policy_params import A2CParams
|
from tianshou.highlevel.params.policy_params import A2CParams
|
||||||
from tianshou.utils import logging
|
from tianshou.utils import logging
|
||||||
|
from tianshou.utils.logging import datetime_tag
|
||||||
|
|
||||||
|
|
||||||
def main(
|
def main(
|
||||||
@ -41,8 +41,7 @@ def main(
|
|||||||
lr_decay: bool = True,
|
lr_decay: bool = True,
|
||||||
max_grad_norm: float = 0.5,
|
max_grad_norm: float = 0.5,
|
||||||
):
|
):
|
||||||
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
|
log_name = os.path.join(task, "a2c", str(experiment_config.seed), datetime_tag())
|
||||||
log_name = os.path.join(task, "a2c", str(experiment_config.seed), now)
|
|
||||||
|
|
||||||
sampling_config = SamplingConfig(
|
sampling_config = SamplingConfig(
|
||||||
num_epochs=epoch,
|
num_epochs=epoch,
|
||||||
@ -55,7 +54,7 @@ def main(
|
|||||||
repeat_per_collect=repeat_per_collect,
|
repeat_per_collect=repeat_per_collect,
|
||||||
)
|
)
|
||||||
|
|
||||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
|
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=True)
|
||||||
|
|
||||||
experiment = (
|
experiment = (
|
||||||
A2CExperimentBuilder(env_factory, experiment_config, sampling_config)
|
A2CExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
import datetime
|
|
||||||
import os
|
import os
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
|
|
||||||
@ -15,6 +14,7 @@ from tianshou.highlevel.experiment import (
|
|||||||
from tianshou.highlevel.params.noise import MaxActionScaledGaussian
|
from tianshou.highlevel.params.noise import MaxActionScaledGaussian
|
||||||
from tianshou.highlevel.params.policy_params import DDPGParams
|
from tianshou.highlevel.params.policy_params import DDPGParams
|
||||||
from tianshou.utils import logging
|
from tianshou.utils import logging
|
||||||
|
from tianshou.utils.logging import datetime_tag
|
||||||
|
|
||||||
|
|
||||||
def main(
|
def main(
|
||||||
@ -37,8 +37,7 @@ def main(
|
|||||||
training_num: int = 1,
|
training_num: int = 1,
|
||||||
test_num: int = 10,
|
test_num: int = 10,
|
||||||
):
|
):
|
||||||
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
|
log_name = os.path.join(task, "ddpg", str(experiment_config.seed), datetime_tag())
|
||||||
log_name = os.path.join(task, "ddpg", str(experiment_config.seed), now)
|
|
||||||
|
|
||||||
sampling_config = SamplingConfig(
|
sampling_config = SamplingConfig(
|
||||||
num_epochs=epoch,
|
num_epochs=epoch,
|
||||||
@ -54,7 +53,7 @@ def main(
|
|||||||
start_timesteps_random=True,
|
start_timesteps_random=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
|
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=False)
|
||||||
|
|
||||||
experiment = (
|
experiment = (
|
||||||
DDPGExperimentBuilder(env_factory, experiment_config, sampling_config)
|
DDPGExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||||
|
@ -70,10 +70,11 @@ class MujocoEnvObsRmsPersistence(Persistence):
|
|||||||
|
|
||||||
|
|
||||||
class MujocoEnvFactory(EnvFactory):
|
class MujocoEnvFactory(EnvFactory):
|
||||||
def __init__(self, task: str, seed: int, sampling_config: SamplingConfig):
|
def __init__(self, task: str, seed: int, sampling_config: SamplingConfig, obs_norm=True):
|
||||||
self.task = task
|
self.task = task
|
||||||
self.sampling_config = sampling_config
|
self.sampling_config = sampling_config
|
||||||
self.seed = seed
|
self.seed = seed
|
||||||
|
self.obs_norm = obs_norm
|
||||||
|
|
||||||
def create_envs(self, config=None) -> ContinuousEnvironments:
|
def create_envs(self, config=None) -> ContinuousEnvironments:
|
||||||
env, train_envs, test_envs = make_mujoco_env(
|
env, train_envs, test_envs = make_mujoco_env(
|
||||||
@ -81,7 +82,7 @@ class MujocoEnvFactory(EnvFactory):
|
|||||||
seed=self.seed,
|
seed=self.seed,
|
||||||
num_train_envs=self.sampling_config.num_train_envs,
|
num_train_envs=self.sampling_config.num_train_envs,
|
||||||
num_test_envs=self.sampling_config.num_test_envs,
|
num_test_envs=self.sampling_config.num_test_envs,
|
||||||
obs_norm=True,
|
obs_norm=self.obs_norm,
|
||||||
)
|
)
|
||||||
envs = ContinuousEnvironments(env=env, train_envs=train_envs, test_envs=test_envs)
|
envs = ContinuousEnvironments(env=env, train_envs=train_envs, test_envs=test_envs)
|
||||||
envs.set_persistence(MujocoEnvObsRmsPersistence())
|
envs.set_persistence(MujocoEnvObsRmsPersistence())
|
||||||
|
@ -56,7 +56,7 @@ def main(
|
|||||||
repeat_per_collect=repeat_per_collect,
|
repeat_per_collect=repeat_per_collect,
|
||||||
)
|
)
|
||||||
|
|
||||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
|
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=True)
|
||||||
|
|
||||||
experiment = (
|
experiment = (
|
||||||
NPGExperimentBuilder(env_factory, experiment_config, sampling_config)
|
NPGExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
import datetime
|
|
||||||
import os
|
import os
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
@ -19,6 +18,7 @@ from tianshou.highlevel.params.dist_fn import (
|
|||||||
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
|
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
|
||||||
from tianshou.highlevel.params.policy_params import PPOParams
|
from tianshou.highlevel.params.policy_params import PPOParams
|
||||||
from tianshou.utils import logging
|
from tianshou.utils import logging
|
||||||
|
from tianshou.utils.logging import datetime_tag
|
||||||
|
|
||||||
|
|
||||||
def main(
|
def main(
|
||||||
@ -48,8 +48,7 @@ def main(
|
|||||||
norm_adv: bool = False,
|
norm_adv: bool = False,
|
||||||
recompute_adv: bool = True,
|
recompute_adv: bool = True,
|
||||||
):
|
):
|
||||||
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
|
log_name = os.path.join(task, "ppo", str(experiment_config.seed), datetime_tag())
|
||||||
log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
|
|
||||||
|
|
||||||
sampling_config = SamplingConfig(
|
sampling_config = SamplingConfig(
|
||||||
num_epochs=epoch,
|
num_epochs=epoch,
|
||||||
@ -62,7 +61,7 @@ def main(
|
|||||||
repeat_per_collect=repeat_per_collect,
|
repeat_per_collect=repeat_per_collect,
|
||||||
)
|
)
|
||||||
|
|
||||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
|
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=True)
|
||||||
|
|
||||||
experiment = (
|
experiment = (
|
||||||
PPOExperimentBuilder(env_factory, experiment_config, sampling_config)
|
PPOExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||||
|
@ -59,7 +59,7 @@ def main(
|
|||||||
start_timesteps_random=True,
|
start_timesteps_random=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
|
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=False)
|
||||||
|
|
||||||
experiment = (
|
experiment = (
|
||||||
REDQExperimentBuilder(env_factory, experiment_config, sampling_config)
|
REDQExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||||
|
@ -49,7 +49,7 @@ def main(
|
|||||||
repeat_per_collect=repeat_per_collect,
|
repeat_per_collect=repeat_per_collect,
|
||||||
)
|
)
|
||||||
|
|
||||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
|
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=True)
|
||||||
|
|
||||||
experiment = (
|
experiment = (
|
||||||
PGExperimentBuilder(env_factory, experiment_config, sampling_config)
|
PGExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||||
|
@ -55,7 +55,7 @@ def main(
|
|||||||
start_timesteps_random=True,
|
start_timesteps_random=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
|
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=False)
|
||||||
|
|
||||||
experiment = (
|
experiment = (
|
||||||
SACExperimentBuilder(env_factory, experiment_config, sampling_config)
|
SACExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
import datetime
|
|
||||||
import os
|
import os
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
|
|
||||||
@ -18,6 +17,7 @@ from tianshou.highlevel.params.noise import (
|
|||||||
)
|
)
|
||||||
from tianshou.highlevel.params.policy_params import TD3Params
|
from tianshou.highlevel.params.policy_params import TD3Params
|
||||||
from tianshou.utils import logging
|
from tianshou.utils import logging
|
||||||
|
from tianshou.utils.logging import datetime_tag
|
||||||
|
|
||||||
|
|
||||||
def main(
|
def main(
|
||||||
@ -43,8 +43,7 @@ def main(
|
|||||||
training_num: int = 1,
|
training_num: int = 1,
|
||||||
test_num: int = 10,
|
test_num: int = 10,
|
||||||
):
|
):
|
||||||
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
|
log_name = os.path.join(task, "td3", str(experiment_config.seed), datetime_tag())
|
||||||
log_name = os.path.join(task, "td3", str(experiment_config.seed), now)
|
|
||||||
|
|
||||||
sampling_config = SamplingConfig(
|
sampling_config = SamplingConfig(
|
||||||
num_epochs=epoch,
|
num_epochs=epoch,
|
||||||
@ -59,7 +58,7 @@ def main(
|
|||||||
start_timesteps_random=True,
|
start_timesteps_random=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
|
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=False)
|
||||||
|
|
||||||
experiment = (
|
experiment = (
|
||||||
TD3ExperimentBuilder(env_factory, experiment_config, sampling_config)
|
TD3ExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||||
|
@ -58,7 +58,7 @@ def main(
|
|||||||
repeat_per_collect=repeat_per_collect,
|
repeat_per_collect=repeat_per_collect,
|
||||||
)
|
)
|
||||||
|
|
||||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
|
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=True)
|
||||||
|
|
||||||
experiment = (
|
experiment = (
|
||||||
TRPOExperimentBuilder(env_factory, experiment_config, sampling_config)
|
TRPOExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user