| 
									
										
										
										
											2023-10-24 12:12:38 +02:00
										 |  |  | from test.highlevel.env_factory import ContinuousTestEnvFactory, DiscreteTestEnvFactory | 
					
						
							| 
									
										
										
										
											2023-10-05 19:22:04 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  | import pytest | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-10-06 13:50:23 +02:00
										 |  |  | from tianshou.highlevel.config import SamplingConfig | 
					
						
							| 
									
										
										
										
											2023-10-05 19:22:04 +02:00
										 |  |  | from tianshou.highlevel.experiment import ( | 
					
						
							|  |  |  |     A2CExperimentBuilder, | 
					
						
							|  |  |  |     DDPGExperimentBuilder, | 
					
						
							| 
									
										
										
										
											2023-10-24 12:12:38 +02:00
										 |  |  |     DiscreteSACExperimentBuilder, | 
					
						
							|  |  |  |     DQNExperimentBuilder, | 
					
						
							| 
									
										
										
										
											2024-02-06 14:24:30 +01:00
										 |  |  |     ExperimentBuilder, | 
					
						
							| 
									
										
										
										
											2023-10-06 13:50:23 +02:00
										 |  |  |     ExperimentConfig, | 
					
						
							| 
									
										
										
										
											2023-10-24 12:12:38 +02:00
										 |  |  |     IQNExperimentBuilder, | 
					
						
							| 
									
										
										
										
											2023-10-11 16:07:34 +02:00
										 |  |  |     PGExperimentBuilder, | 
					
						
							| 
									
										
										
										
											2023-10-06 13:53:45 +02:00
										 |  |  |     PPOExperimentBuilder, | 
					
						
							| 
									
										
										
										
											2023-10-11 16:07:34 +02:00
										 |  |  |     REDQExperimentBuilder, | 
					
						
							| 
									
										
										
										
											2023-10-05 19:22:04 +02:00
										 |  |  |     SACExperimentBuilder, | 
					
						
							|  |  |  |     TD3ExperimentBuilder, | 
					
						
							| 
									
										
										
										
											2023-10-11 16:07:34 +02:00
										 |  |  |     TRPOExperimentBuilder, | 
					
						
							| 
									
										
										
										
											2023-10-05 19:22:04 +02:00
										 |  |  | ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @pytest.mark.parametrize( | 
					
						
							|  |  |  |     "builder_cls", | 
					
						
							|  |  |  |     [ | 
					
						
							|  |  |  |         PPOExperimentBuilder, | 
					
						
							|  |  |  |         A2CExperimentBuilder, | 
					
						
							|  |  |  |         SACExperimentBuilder, | 
					
						
							|  |  |  |         DDPGExperimentBuilder, | 
					
						
							|  |  |  |         TD3ExperimentBuilder, | 
					
						
							| 
									
										
										
										
											2023-10-11 16:07:34 +02:00
										 |  |  |         # NPGExperimentBuilder,  # TODO test fails non-deterministically | 
					
						
							|  |  |  |         REDQExperimentBuilder, | 
					
						
							|  |  |  |         TRPOExperimentBuilder, | 
					
						
							|  |  |  |         PGExperimentBuilder, | 
					
						
							| 
									
										
										
										
											2023-10-05 19:22:04 +02:00
										 |  |  |     ], | 
					
						
							|  |  |  | ) | 
					
						
							| 
									
										
										
										
											2024-02-06 14:24:30 +01:00
										 |  |  | def test_experiment_builder_continuous_default_params(builder_cls: type[ExperimentBuilder]) -> None: | 
					
						
							| 
									
										
										
										
											2023-10-05 19:22:04 +02:00
										 |  |  |     env_factory = ContinuousTestEnvFactory() | 
					
						
							| 
									
										
										
										
											2023-10-18 23:55:23 +02:00
										 |  |  |     sampling_config = SamplingConfig( | 
					
						
							|  |  |  |         num_epochs=1, | 
					
						
							|  |  |  |         step_per_epoch=100, | 
					
						
							|  |  |  |         num_train_envs=2, | 
					
						
							|  |  |  |         num_test_envs=2, | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2023-10-12 17:40:16 +02:00
										 |  |  |     experiment_config = ExperimentConfig(persistence_enabled=False) | 
					
						
							| 
									
										
										
										
											2023-10-05 19:22:04 +02:00
										 |  |  |     builder = builder_cls( | 
					
						
							|  |  |  |         experiment_config=experiment_config, | 
					
						
							|  |  |  |         env_factory=env_factory, | 
					
						
							|  |  |  |         sampling_config=sampling_config, | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     experiment = builder.build() | 
					
						
							|  |  |  |     experiment.run("test") | 
					
						
							|  |  |  |     print(experiment) | 
					
						
							| 
									
										
										
										
											2023-10-24 12:12:38 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @pytest.mark.parametrize( | 
					
						
							|  |  |  |     "builder_cls", | 
					
						
							|  |  |  |     [ | 
					
						
							|  |  |  |         PPOExperimentBuilder, | 
					
						
							|  |  |  |         A2CExperimentBuilder, | 
					
						
							|  |  |  |         DQNExperimentBuilder, | 
					
						
							|  |  |  |         DiscreteSACExperimentBuilder, | 
					
						
							|  |  |  |         IQNExperimentBuilder, | 
					
						
							|  |  |  |     ], | 
					
						
							|  |  |  | ) | 
					
						
							| 
									
										
										
										
											2024-02-06 14:24:30 +01:00
										 |  |  | def test_experiment_builder_discrete_default_params(builder_cls: type[ExperimentBuilder]) -> None: | 
					
						
							| 
									
										
										
										
											2023-10-24 12:12:38 +02:00
										 |  |  |     env_factory = DiscreteTestEnvFactory() | 
					
						
							|  |  |  |     sampling_config = SamplingConfig( | 
					
						
							|  |  |  |         num_epochs=1, | 
					
						
							|  |  |  |         step_per_epoch=100, | 
					
						
							|  |  |  |         num_train_envs=2, | 
					
						
							|  |  |  |         num_test_envs=2, | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     builder = builder_cls( | 
					
						
							|  |  |  |         experiment_config=ExperimentConfig(persistence_enabled=False), | 
					
						
							|  |  |  |         env_factory=env_factory, | 
					
						
							|  |  |  |         sampling_config=sampling_config, | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     experiment = builder.build() | 
					
						
							|  |  |  |     experiment.run("test") | 
					
						
							|  |  |  |     print(experiment) |