Tianshou/test/highlevel/test_discrete.py
Dominik Jain 6cbee188b8 Change interface of EnvFactory to ensure that configuration
of number of environments in SamplingConfig is used
(values are now passed to factory method)

This is clearer and removes the need to pass otherwise
unnecessary configuration to environment factories at
construction
2023-10-19 11:37:20 +02:00

42 lines
1.0 KiB
Python

from test.highlevel.env_factory import DiscreteTestEnvFactory
import pytest
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import (
A2CExperimentBuilder,
DiscreteSACExperimentBuilder,
DQNExperimentBuilder,
ExperimentConfig,
IQNExperimentBuilder,
PPOExperimentBuilder,
)
@pytest.mark.parametrize(
    "builder_cls",
    [
        PPOExperimentBuilder,
        A2CExperimentBuilder,
        DQNExperimentBuilder,
        DiscreteSACExperimentBuilder,
        IQNExperimentBuilder,
    ],
)
def test_experiment_builder_discrete_default_params(builder_cls):
    """Smoke-test an experiment builder without customising its parameters.

    For each parametrized builder class, an experiment is built from a
    ``DiscreteTestEnvFactory``, a small ``SamplingConfig`` (1 epoch,
    100 steps per epoch, 2 train and 2 test environments) and an
    ``ExperimentConfig`` with persistence disabled. The experiment is then
    run with the argument ``"test"`` and printed.

    The test makes no explicit assertions; it passes as long as building,
    running and printing the experiment raise no exception.

    :param builder_cls: the experiment builder class under test, supplied
        by ``pytest.mark.parametrize``.
    """
    # Keep the run tiny: this only checks that the pipeline works end to end.
    factory = DiscreteTestEnvFactory()
    sampling = SamplingConfig(
        num_epochs=1,
        step_per_epoch=100,
        num_train_envs=2,
        num_test_envs=2,
    )
    config = ExperimentConfig(persistence_enabled=False)

    exp = builder_cls(
        experiment_config=config,
        env_factory=factory,
        sampling_config=sampling,
    ).build()

    exp.run("test")
    print(exp)