From 58466ebf5d94c043c555a4293b45406d3abcaa5c Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 24 Oct 2023 12:12:38 +0200 Subject: [PATCH] Keep all ExperimentBuilder tests in one place --- test/highlevel/test_discrete.py | 41 ------------------- ...ntinuous.py => test_experiment_builder.py} | 33 ++++++++++++++- 2 files changed, 32 insertions(+), 42 deletions(-) delete mode 100644 test/highlevel/test_discrete.py rename test/highlevel/{test_continuous.py => test_experiment_builder.py} (58%) diff --git a/test/highlevel/test_discrete.py b/test/highlevel/test_discrete.py deleted file mode 100644 index 53f1dd8..0000000 --- a/test/highlevel/test_discrete.py +++ /dev/null @@ -1,41 +0,0 @@ -from test.highlevel.env_factory import DiscreteTestEnvFactory - -import pytest - -from tianshou.highlevel.config import SamplingConfig -from tianshou.highlevel.experiment import ( - A2CExperimentBuilder, - DiscreteSACExperimentBuilder, - DQNExperimentBuilder, - ExperimentConfig, - IQNExperimentBuilder, - PPOExperimentBuilder, -) - - -@pytest.mark.parametrize( - "builder_cls", - [ - PPOExperimentBuilder, - A2CExperimentBuilder, - DQNExperimentBuilder, - DiscreteSACExperimentBuilder, - IQNExperimentBuilder, - ], -) -def test_experiment_builder_discrete_default_params(builder_cls): - env_factory = DiscreteTestEnvFactory() - sampling_config = SamplingConfig( - num_epochs=1, - step_per_epoch=100, - num_train_envs=2, - num_test_envs=2, - ) - builder = builder_cls( - experiment_config=ExperimentConfig(persistence_enabled=False), - env_factory=env_factory, - sampling_config=sampling_config, - ) - experiment = builder.build() - experiment.run("test") - print(experiment) diff --git a/test/highlevel/test_continuous.py b/test/highlevel/test_experiment_builder.py similarity index 58% rename from test/highlevel/test_continuous.py rename to test/highlevel/test_experiment_builder.py index 32de680..e53c0f7 100644 --- a/test/highlevel/test_continuous.py +++ b/test/highlevel/test_experiment_builder.py @@ -1,4 +1,4 @@ -from test.highlevel.env_factory import ContinuousTestEnvFactory +from test.highlevel.env_factory import ContinuousTestEnvFactory, DiscreteTestEnvFactory import pytest @@ -6,7 +6,10 @@ from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.experiment import ( A2CExperimentBuilder, DDPGExperimentBuilder, + DiscreteSACExperimentBuilder, + DQNExperimentBuilder, ExperimentConfig, + IQNExperimentBuilder, PGExperimentBuilder, PPOExperimentBuilder, REDQExperimentBuilder, @@ -47,3 +50,31 @@ def test_experiment_builder_continuous_default_params(builder_cls): experiment = builder.build() experiment.run("test") print(experiment) + + +@pytest.mark.parametrize( + "builder_cls", + [ + PPOExperimentBuilder, + A2CExperimentBuilder, + DQNExperimentBuilder, + DiscreteSACExperimentBuilder, + IQNExperimentBuilder, + ], +) +def test_experiment_builder_discrete_default_params(builder_cls): + env_factory = DiscreteTestEnvFactory() + sampling_config = SamplingConfig( + num_epochs=1, + step_per_epoch=100, + num_train_envs=2, + num_test_envs=2, + ) + builder = builder_cls( + experiment_config=ExperimentConfig(persistence_enabled=False), + env_factory=env_factory, + sampling_config=sampling_config, + ) + experiment = builder.build() + experiment.run("test") + print(experiment)