Keep all ExperimentBuilder tests in one place
This commit is contained in:
parent
b5a891557f
commit
58466ebf5d
@ -1,41 +0,0 @@
|
|||||||
from test.highlevel.env_factory import DiscreteTestEnvFactory
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from tianshou.highlevel.config import SamplingConfig
|
|
||||||
from tianshou.highlevel.experiment import (
|
|
||||||
A2CExperimentBuilder,
|
|
||||||
DiscreteSACExperimentBuilder,
|
|
||||||
DQNExperimentBuilder,
|
|
||||||
ExperimentConfig,
|
|
||||||
IQNExperimentBuilder,
|
|
||||||
PPOExperimentBuilder,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"builder_cls",
|
|
||||||
[
|
|
||||||
PPOExperimentBuilder,
|
|
||||||
A2CExperimentBuilder,
|
|
||||||
DQNExperimentBuilder,
|
|
||||||
DiscreteSACExperimentBuilder,
|
|
||||||
IQNExperimentBuilder,
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_experiment_builder_discrete_default_params(builder_cls):
|
|
||||||
env_factory = DiscreteTestEnvFactory()
|
|
||||||
sampling_config = SamplingConfig(
|
|
||||||
num_epochs=1,
|
|
||||||
step_per_epoch=100,
|
|
||||||
num_train_envs=2,
|
|
||||||
num_test_envs=2,
|
|
||||||
)
|
|
||||||
builder = builder_cls(
|
|
||||||
experiment_config=ExperimentConfig(persistence_enabled=False),
|
|
||||||
env_factory=env_factory,
|
|
||||||
sampling_config=sampling_config,
|
|
||||||
)
|
|
||||||
experiment = builder.build()
|
|
||||||
experiment.run("test")
|
|
||||||
print(experiment)
|
|
@ -1,4 +1,4 @@
|
|||||||
from test.highlevel.env_factory import ContinuousTestEnvFactory
|
from test.highlevel.env_factory import ContinuousTestEnvFactory, DiscreteTestEnvFactory
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -6,7 +6,10 @@ from tianshou.highlevel.config import SamplingConfig
|
|||||||
from tianshou.highlevel.experiment import (
|
from tianshou.highlevel.experiment import (
|
||||||
A2CExperimentBuilder,
|
A2CExperimentBuilder,
|
||||||
DDPGExperimentBuilder,
|
DDPGExperimentBuilder,
|
||||||
|
DiscreteSACExperimentBuilder,
|
||||||
|
DQNExperimentBuilder,
|
||||||
ExperimentConfig,
|
ExperimentConfig,
|
||||||
|
IQNExperimentBuilder,
|
||||||
PGExperimentBuilder,
|
PGExperimentBuilder,
|
||||||
PPOExperimentBuilder,
|
PPOExperimentBuilder,
|
||||||
REDQExperimentBuilder,
|
REDQExperimentBuilder,
|
||||||
@ -47,3 +50,31 @@ def test_experiment_builder_continuous_default_params(builder_cls):
|
|||||||
experiment = builder.build()
|
experiment = builder.build()
|
||||||
experiment.run("test")
|
experiment.run("test")
|
||||||
print(experiment)
|
print(experiment)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"builder_cls",
|
||||||
|
[
|
||||||
|
PPOExperimentBuilder,
|
||||||
|
A2CExperimentBuilder,
|
||||||
|
DQNExperimentBuilder,
|
||||||
|
DiscreteSACExperimentBuilder,
|
||||||
|
IQNExperimentBuilder,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_experiment_builder_discrete_default_params(builder_cls):
|
||||||
|
env_factory = DiscreteTestEnvFactory()
|
||||||
|
sampling_config = SamplingConfig(
|
||||||
|
num_epochs=1,
|
||||||
|
step_per_epoch=100,
|
||||||
|
num_train_envs=2,
|
||||||
|
num_test_envs=2,
|
||||||
|
)
|
||||||
|
builder = builder_cls(
|
||||||
|
experiment_config=ExperimentConfig(persistence_enabled=False),
|
||||||
|
env_factory=env_factory,
|
||||||
|
sampling_config=sampling_config,
|
||||||
|
)
|
||||||
|
experiment = builder.build()
|
||||||
|
experiment.run("test")
|
||||||
|
print(experiment)
|
Loading…
x
Reference in New Issue
Block a user