Keep all ExperimentBuilder tests in one place

This commit is contained in:
Dominik Jain 2023-10-24 12:12:38 +02:00
parent b5a891557f
commit 58466ebf5d
2 changed files with 32 additions and 42 deletions

View File

@ -1,41 +0,0 @@
from test.highlevel.env_factory import DiscreteTestEnvFactory
import pytest
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import (
A2CExperimentBuilder,
DiscreteSACExperimentBuilder,
DQNExperimentBuilder,
ExperimentConfig,
IQNExperimentBuilder,
PPOExperimentBuilder,
)
@pytest.mark.parametrize(
"builder_cls",
[
PPOExperimentBuilder,
A2CExperimentBuilder,
DQNExperimentBuilder,
DiscreteSACExperimentBuilder,
IQNExperimentBuilder,
],
)
def test_experiment_builder_discrete_default_params(builder_cls):
env_factory = DiscreteTestEnvFactory()
sampling_config = SamplingConfig(
num_epochs=1,
step_per_epoch=100,
num_train_envs=2,
num_test_envs=2,
)
builder = builder_cls(
experiment_config=ExperimentConfig(persistence_enabled=False),
env_factory=env_factory,
sampling_config=sampling_config,
)
experiment = builder.build()
experiment.run("test")
print(experiment)

View File

@ -1,4 +1,4 @@
from test.highlevel.env_factory import ContinuousTestEnvFactory
from test.highlevel.env_factory import ContinuousTestEnvFactory, DiscreteTestEnvFactory
import pytest
@ -6,7 +6,10 @@ from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import (
A2CExperimentBuilder,
DDPGExperimentBuilder,
DiscreteSACExperimentBuilder,
DQNExperimentBuilder,
ExperimentConfig,
IQNExperimentBuilder,
PGExperimentBuilder,
PPOExperimentBuilder,
REDQExperimentBuilder,
@ -47,3 +50,31 @@ def test_experiment_builder_continuous_default_params(builder_cls):
experiment = builder.build()
experiment.run("test")
print(experiment)
@pytest.mark.parametrize(
"builder_cls",
[
PPOExperimentBuilder,
A2CExperimentBuilder,
DQNExperimentBuilder,
DiscreteSACExperimentBuilder,
IQNExperimentBuilder,
],
)
def test_experiment_builder_discrete_default_params(builder_cls):
env_factory = DiscreteTestEnvFactory()
sampling_config = SamplingConfig(
num_epochs=1,
step_per_epoch=100,
num_train_envs=2,
num_test_envs=2,
)
builder = builder_cls(
experiment_config=ExperimentConfig(persistence_enabled=False),
env_factory=env_factory,
sampling_config=sampling_config,
)
experiment = builder.build()
experiment.run("test")
print(experiment)