| 
									
										
										
										
											2023-10-10 12:55:25 +02:00
										 |  |  | #!/usr/bin/env python3 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-02-06 14:24:30 +01:00
										 |  |  | import functools | 
					
						
							| 
									
										
										
										
											2023-10-10 12:55:25 +02:00
										 |  |  | import os | 
					
						
							|  |  |  | from collections.abc import Sequence | 
					
						
							|  |  |  | from typing import Literal | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-10-18 13:57:36 +02:00
										 |  |  | import torch | 
					
						
							| 
									
										
										
										
											2023-10-10 12:55:25 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  | from examples.mujoco.mujoco_env import MujocoEnvFactory | 
					
						
							|  |  |  | from tianshou.highlevel.config import SamplingConfig | 
					
						
							|  |  |  | from tianshou.highlevel.experiment import ( | 
					
						
							|  |  |  |     ExperimentConfig, | 
					
						
							|  |  |  |     PGExperimentBuilder, | 
					
						
							|  |  |  | ) | 
					
						
							|  |  |  | from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear | 
					
						
							|  |  |  | from tianshou.highlevel.params.policy_params import PGParams | 
					
						
							| 
									
										
										
										
											2023-11-07 10:54:22 +01:00
										 |  |  | from tianshou.utils import logging | 
					
						
							|  |  |  | from tianshou.utils.logging import datetime_tag | 
					
						
							| 
									
										
										
										
											2023-10-10 12:55:25 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def main( | 
					
						
							|  |  |  |     experiment_config: ExperimentConfig, | 
					
						
							| 
									
										
										
										
											2024-01-10 15:39:53 +01:00
										 |  |  |     task: str = "Ant-v4", | 
					
						
							| 
									
										
										
										
											2023-10-10 12:55:25 +02:00
										 |  |  |     buffer_size: int = 4096, | 
					
						
							|  |  |  |     hidden_sizes: Sequence[int] = (64, 64), | 
					
						
							|  |  |  |     lr: float = 1e-3, | 
					
						
							|  |  |  |     gamma: float = 0.99, | 
					
						
							|  |  |  |     epoch: int = 100, | 
					
						
							|  |  |  |     step_per_epoch: int = 30000, | 
					
						
							|  |  |  |     step_per_collect: int = 2048, | 
					
						
							|  |  |  |     repeat_per_collect: int = 1, | 
					
						
							| 
									
										
										
										
											2023-11-24 19:13:10 +01:00
										 |  |  |     batch_size: int | None = None, | 
					
						
							| 
									
										
										
										
											2024-02-06 17:06:38 +01:00
										 |  |  |     training_num: int = 10, | 
					
						
							| 
									
										
										
										
											2023-10-10 12:55:25 +02:00
										 |  |  |     test_num: int = 10, | 
					
						
							|  |  |  |     rew_norm: bool = True, | 
					
						
							|  |  |  |     action_bound_method: Literal["clip", "tanh"] = "tanh", | 
					
						
							|  |  |  |     lr_decay: bool = True, | 
					
						
							| 
									
										
										
										
											2024-02-06 14:24:30 +01:00
										 |  |  | ) -> None: | 
					
						
							| 
									
										
										
										
											2023-10-10 12:55:25 +02:00
										 |  |  |     log_name = os.path.join(task, "reinforce", str(experiment_config.seed), datetime_tag()) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     sampling_config = SamplingConfig( | 
					
						
							|  |  |  |         num_epochs=epoch, | 
					
						
							|  |  |  |         step_per_epoch=step_per_epoch, | 
					
						
							|  |  |  |         batch_size=batch_size, | 
					
						
							|  |  |  |         num_train_envs=training_num, | 
					
						
							|  |  |  |         num_test_envs=test_num, | 
					
						
							|  |  |  |         buffer_size=buffer_size, | 
					
						
							|  |  |  |         step_per_collect=step_per_collect, | 
					
						
							|  |  |  |         repeat_per_collect=repeat_per_collect, | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-10-18 23:55:23 +02:00
										 |  |  |     env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=True) | 
					
						
							| 
									
										
										
										
											2023-10-10 12:55:25 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  |     experiment = ( | 
					
						
							|  |  |  |         PGExperimentBuilder(env_factory, experiment_config, sampling_config) | 
					
						
							|  |  |  |         .with_pg_params( | 
					
						
							|  |  |  |             PGParams( | 
					
						
							|  |  |  |                 discount_factor=gamma, | 
					
						
							|  |  |  |                 action_bound_method=action_bound_method, | 
					
						
							|  |  |  |                 reward_normalization=rew_norm, | 
					
						
							|  |  |  |                 lr=lr, | 
					
						
							|  |  |  |                 lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config) | 
					
						
							|  |  |  |                 if lr_decay | 
					
						
							|  |  |  |                 else None, | 
					
						
							|  |  |  |             ), | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2023-10-18 13:57:36 +02:00
										 |  |  |         .with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True) | 
					
						
							| 
									
										
										
										
											2023-10-10 12:55:25 +02:00
										 |  |  |         .build() | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     experiment.run(log_name) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | if __name__ == "__main__": | 
					
						
							| 
									
										
										
										
											2024-02-06 14:24:30 +01:00
										 |  |  |     run_with_default_config = functools.partial(main, experiment_config=ExperimentConfig()) | 
					
						
							|  |  |  |     logging.run_cli(run_with_default_config) |