Handle obs_norm setting in MuJoCo envs
This commit is contained in:
		
							parent
							
								
									80b1b1ff9d
								
							
						
					
					
						commit
						ed06ab7ff0
					
				| @ -1,6 +1,5 @@ | ||||
| #!/usr/bin/env python3 | ||||
| 
 | ||||
| import datetime | ||||
| import os | ||||
| from collections.abc import Sequence | ||||
| from typing import Literal | ||||
| @ -17,6 +16,7 @@ from tianshou.highlevel.optim import OptimizerFactoryRMSprop | ||||
| from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear | ||||
| from tianshou.highlevel.params.policy_params import A2CParams | ||||
| from tianshou.utils import logging | ||||
| from tianshou.utils.logging import datetime_tag | ||||
| 
 | ||||
| 
 | ||||
| def main( | ||||
| @ -41,8 +41,7 @@ def main( | ||||
|     lr_decay: bool = True, | ||||
|     max_grad_norm: float = 0.5, | ||||
| ): | ||||
|     now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") | ||||
|     log_name = os.path.join(task, "a2c", str(experiment_config.seed), now) | ||||
|     log_name = os.path.join(task, "a2c", str(experiment_config.seed), datetime_tag()) | ||||
| 
 | ||||
|     sampling_config = SamplingConfig( | ||||
|         num_epochs=epoch, | ||||
| @ -55,7 +54,7 @@ def main( | ||||
|         repeat_per_collect=repeat_per_collect, | ||||
|     ) | ||||
| 
 | ||||
|     env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config) | ||||
|     env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=True) | ||||
| 
 | ||||
|     experiment = ( | ||||
|         A2CExperimentBuilder(env_factory, experiment_config, sampling_config) | ||||
|  | ||||
| @ -1,6 +1,5 @@ | ||||
| #!/usr/bin/env python3 | ||||
| 
 | ||||
| import datetime | ||||
| import os | ||||
| from collections.abc import Sequence | ||||
| 
 | ||||
| @ -15,6 +14,7 @@ from tianshou.highlevel.experiment import ( | ||||
| from tianshou.highlevel.params.noise import MaxActionScaledGaussian | ||||
| from tianshou.highlevel.params.policy_params import DDPGParams | ||||
| from tianshou.utils import logging | ||||
| from tianshou.utils.logging import datetime_tag | ||||
| 
 | ||||
| 
 | ||||
| def main( | ||||
| @ -37,8 +37,7 @@ def main( | ||||
|     training_num: int = 1, | ||||
|     test_num: int = 10, | ||||
| ): | ||||
|     now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") | ||||
|     log_name = os.path.join(task, "ddpg", str(experiment_config.seed), now) | ||||
|     log_name = os.path.join(task, "ddpg", str(experiment_config.seed), datetime_tag()) | ||||
| 
 | ||||
|     sampling_config = SamplingConfig( | ||||
|         num_epochs=epoch, | ||||
| @ -54,7 +53,7 @@ def main( | ||||
|         start_timesteps_random=True, | ||||
|     ) | ||||
| 
 | ||||
|     env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config) | ||||
|     env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=False) | ||||
| 
 | ||||
|     experiment = ( | ||||
|         DDPGExperimentBuilder(env_factory, experiment_config, sampling_config) | ||||
|  | ||||
| @ -70,10 +70,11 @@ class MujocoEnvObsRmsPersistence(Persistence): | ||||
| 
 | ||||
| 
 | ||||
| class MujocoEnvFactory(EnvFactory): | ||||
|     def __init__(self, task: str, seed: int, sampling_config: SamplingConfig): | ||||
|     def __init__(self, task: str, seed: int, sampling_config: SamplingConfig, obs_norm=True): | ||||
|         self.task = task | ||||
|         self.sampling_config = sampling_config | ||||
|         self.seed = seed | ||||
|         self.obs_norm = obs_norm | ||||
| 
 | ||||
|     def create_envs(self, config=None) -> ContinuousEnvironments: | ||||
|         env, train_envs, test_envs = make_mujoco_env( | ||||
| @ -81,7 +82,7 @@ class MujocoEnvFactory(EnvFactory): | ||||
|             seed=self.seed, | ||||
|             num_train_envs=self.sampling_config.num_train_envs, | ||||
|             num_test_envs=self.sampling_config.num_test_envs, | ||||
|             obs_norm=True, | ||||
|             obs_norm=self.obs_norm, | ||||
|         ) | ||||
|         envs = ContinuousEnvironments(env=env, train_envs=train_envs, test_envs=test_envs) | ||||
|         envs.set_persistence(MujocoEnvObsRmsPersistence()) | ||||
|  | ||||
| @ -56,7 +56,7 @@ def main( | ||||
|         repeat_per_collect=repeat_per_collect, | ||||
|     ) | ||||
| 
 | ||||
|     env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config) | ||||
|     env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=True) | ||||
| 
 | ||||
|     experiment = ( | ||||
|         NPGExperimentBuilder(env_factory, experiment_config, sampling_config) | ||||
|  | ||||
| @ -1,6 +1,5 @@ | ||||
| #!/usr/bin/env python3 | ||||
| 
 | ||||
| import datetime | ||||
| import os | ||||
| from collections.abc import Sequence | ||||
| from typing import Literal | ||||
| @ -19,6 +18,7 @@ from tianshou.highlevel.params.dist_fn import ( | ||||
| from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear | ||||
| from tianshou.highlevel.params.policy_params import PPOParams | ||||
| from tianshou.utils import logging | ||||
| from tianshou.utils.logging import datetime_tag | ||||
| 
 | ||||
| 
 | ||||
| def main( | ||||
| @ -48,8 +48,7 @@ def main( | ||||
|     norm_adv: bool = False, | ||||
|     recompute_adv: bool = True, | ||||
| ): | ||||
|     now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") | ||||
|     log_name = os.path.join(task, "ppo", str(experiment_config.seed), now) | ||||
|     log_name = os.path.join(task, "ppo", str(experiment_config.seed), datetime_tag()) | ||||
| 
 | ||||
|     sampling_config = SamplingConfig( | ||||
|         num_epochs=epoch, | ||||
| @ -62,7 +61,7 @@ def main( | ||||
|         repeat_per_collect=repeat_per_collect, | ||||
|     ) | ||||
| 
 | ||||
|     env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config) | ||||
|     env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=True) | ||||
| 
 | ||||
|     experiment = ( | ||||
|         PPOExperimentBuilder(env_factory, experiment_config, sampling_config) | ||||
|  | ||||
| @ -59,7 +59,7 @@ def main( | ||||
|         start_timesteps_random=True, | ||||
|     ) | ||||
| 
 | ||||
|     env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config) | ||||
|     env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=False) | ||||
| 
 | ||||
|     experiment = ( | ||||
|         REDQExperimentBuilder(env_factory, experiment_config, sampling_config) | ||||
|  | ||||
| @ -49,7 +49,7 @@ def main( | ||||
|         repeat_per_collect=repeat_per_collect, | ||||
|     ) | ||||
| 
 | ||||
|     env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config) | ||||
|     env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=True) | ||||
| 
 | ||||
|     experiment = ( | ||||
|         PGExperimentBuilder(env_factory, experiment_config, sampling_config) | ||||
|  | ||||
| @ -55,7 +55,7 @@ def main( | ||||
|         start_timesteps_random=True, | ||||
|     ) | ||||
| 
 | ||||
|     env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config) | ||||
|     env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=False) | ||||
| 
 | ||||
|     experiment = ( | ||||
|         SACExperimentBuilder(env_factory, experiment_config, sampling_config) | ||||
|  | ||||
| @ -1,6 +1,5 @@ | ||||
| #!/usr/bin/env python3 | ||||
| 
 | ||||
| import datetime | ||||
| import os | ||||
| from collections.abc import Sequence | ||||
| 
 | ||||
| @ -18,6 +17,7 @@ from tianshou.highlevel.params.noise import ( | ||||
| ) | ||||
| from tianshou.highlevel.params.policy_params import TD3Params | ||||
| from tianshou.utils import logging | ||||
| from tianshou.utils.logging import datetime_tag | ||||
| 
 | ||||
| 
 | ||||
| def main( | ||||
| @ -43,8 +43,7 @@ def main( | ||||
|     training_num: int = 1, | ||||
|     test_num: int = 10, | ||||
| ): | ||||
|     now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") | ||||
|     log_name = os.path.join(task, "td3", str(experiment_config.seed), now) | ||||
|     log_name = os.path.join(task, "td3", str(experiment_config.seed), datetime_tag()) | ||||
| 
 | ||||
|     sampling_config = SamplingConfig( | ||||
|         num_epochs=epoch, | ||||
| @ -59,7 +58,7 @@ def main( | ||||
|         start_timesteps_random=True, | ||||
|     ) | ||||
| 
 | ||||
|     env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config) | ||||
|     env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=False) | ||||
| 
 | ||||
|     experiment = ( | ||||
|         TD3ExperimentBuilder(env_factory, experiment_config, sampling_config) | ||||
|  | ||||
| @ -58,7 +58,7 @@ def main( | ||||
|         repeat_per_collect=repeat_per_collect, | ||||
|     ) | ||||
| 
 | ||||
|     env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config) | ||||
|     env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=True) | ||||
| 
 | ||||
|     experiment = ( | ||||
|         TRPOExperimentBuilder(env_factory, experiment_config, sampling_config) | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user