of the number of environments in SamplingConfig is used (the values are now passed to the factory method). This is clearer and removes the need to pass otherwise unnecessary configuration to environment factories at construction.
42 lines
1.0 KiB
Python
42 lines
1.0 KiB
Python
from test.highlevel.env_factory import DiscreteTestEnvFactory
|
|
|
|
import pytest
|
|
|
|
from tianshou.highlevel.config import SamplingConfig
|
|
from tianshou.highlevel.experiment import (
|
|
A2CExperimentBuilder,
|
|
DiscreteSACExperimentBuilder,
|
|
DQNExperimentBuilder,
|
|
ExperimentConfig,
|
|
IQNExperimentBuilder,
|
|
PPOExperimentBuilder,
|
|
)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "builder_cls",
    [
        PPOExperimentBuilder,
        A2CExperimentBuilder,
        DQNExperimentBuilder,
        DiscreteSACExperimentBuilder,
        IQNExperimentBuilder,
    ],
)
def test_experiment_builder_discrete_default_params(builder_cls):
    """Smoke-test each discrete-action experiment builder with default parameters.

    For every parametrized builder class, an experiment is assembled from a
    discrete test environment factory and a deliberately tiny sampling budget,
    then built and run once under the name ``"test"``. The test passes as long
    as building and running complete without raising; nothing is asserted
    about the outcome.

    :param builder_cls: one of the experiment builder classes listed in the
        ``parametrize`` decorator; it is called with ``experiment_config``,
        ``env_factory`` and ``sampling_config`` keyword arguments.
    """
    test_env_factory = DiscreteTestEnvFactory()

    # A minimal budget keeps the smoke test quick: a single epoch of 100 steps,
    # with two training and two test environments.
    tiny_sampling = SamplingConfig(
        num_epochs=1,
        step_per_epoch=100,
        num_train_envs=2,
        num_test_envs=2,
    )

    # Persistence is switched off for this run via the experiment config.
    experiment = builder_cls(
        experiment_config=ExperimentConfig(persistence_enabled=False),
        env_factory=test_env_factory,
        sampling_config=tiny_sampling,
    ).build()

    experiment.run("test")
    print(experiment)
|