2023-10-24 12:12:38 +02:00
|
|
|
from test.highlevel.env_factory import ContinuousTestEnvFactory, DiscreteTestEnvFactory
|
2023-10-05 19:22:04 +02:00
|
|
|
|
|
|
|
import pytest
|
|
|
|
|
2023-10-06 13:50:23 +02:00
|
|
|
from tianshou.highlevel.config import SamplingConfig
|
2023-10-05 19:22:04 +02:00
|
|
|
from tianshou.highlevel.experiment import (
|
|
|
|
A2CExperimentBuilder,
|
|
|
|
DDPGExperimentBuilder,
|
2023-10-24 12:12:38 +02:00
|
|
|
DiscreteSACExperimentBuilder,
|
|
|
|
DQNExperimentBuilder,
|
2024-02-06 14:24:30 +01:00
|
|
|
ExperimentBuilder,
|
2023-10-06 13:50:23 +02:00
|
|
|
ExperimentConfig,
|
2023-10-24 12:12:38 +02:00
|
|
|
IQNExperimentBuilder,
|
2023-10-11 16:07:34 +02:00
|
|
|
PGExperimentBuilder,
|
2023-10-06 13:53:45 +02:00
|
|
|
PPOExperimentBuilder,
|
2023-10-11 16:07:34 +02:00
|
|
|
REDQExperimentBuilder,
|
2023-10-05 19:22:04 +02:00
|
|
|
SACExperimentBuilder,
|
|
|
|
TD3ExperimentBuilder,
|
2023-10-11 16:07:34 +02:00
|
|
|
TRPOExperimentBuilder,
|
2023-10-05 19:22:04 +02:00
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize(
    "builder_cls",
    [
        PPOExperimentBuilder,
        A2CExperimentBuilder,
        SACExperimentBuilder,
        DDPGExperimentBuilder,
        TD3ExperimentBuilder,
        # NPGExperimentBuilder,  # TODO test fails non-deterministically
        REDQExperimentBuilder,
        TRPOExperimentBuilder,
        PGExperimentBuilder,
    ],
)
def test_experiment_builder_continuous_default_params(builder_cls: type[ExperimentBuilder]) -> None:
    """Smoke-test a continuous-action experiment builder with its default parameters.

    Each parametrized builder class is built with a ``ContinuousTestEnvFactory``.
    Training is a tiny budget: 1 epoch, 100 steps per epoch, 2 train envs and
    2 test envs, with persistence disabled. The built experiment is then run by
    passing ``"test"`` to ``run``. There are no assertions; the test passes when
    no exception is raised.
    """
    # Objects are created in the same order as before: env factory,
    # then sampling config, then experiment config. The keyword order in the
    # call below affects only readability, not evaluation.
    continuous_envs = ContinuousTestEnvFactory()
    tiny_budget = SamplingConfig(
        num_epochs=1,
        step_per_epoch=100,
        num_train_envs=2,
        num_test_envs=2,
    )
    no_persistence = ExperimentConfig(persistence_enabled=False)

    exp = builder_cls(
        env_factory=continuous_envs,
        sampling_config=tiny_budget,
        experiment_config=no_persistence,
    ).build()
    exp.run("test")
    print(exp)
|
2023-10-24 12:12:38 +02:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize(
    "builder_cls",
    [
        PPOExperimentBuilder,
        A2CExperimentBuilder,
        DQNExperimentBuilder,
        DiscreteSACExperimentBuilder,
        IQNExperimentBuilder,
    ],
)
def test_experiment_builder_discrete_default_params(builder_cls: type[ExperimentBuilder]) -> None:
    """Smoke-test a discrete-action experiment builder with its default parameters.

    Each parametrized builder class is built with a ``DiscreteTestEnvFactory``.
    Training is a tiny budget: 1 epoch, 100 steps per epoch, 2 train envs and
    2 test envs, with persistence disabled. The built experiment is then run by
    passing ``"test"`` to ``run``. There are no assertions; the test passes when
    no exception is raised.
    """
    # Same construction order as before: env factory, sampling config,
    # then the (non-persisting) experiment config.
    discrete_envs = DiscreteTestEnvFactory()
    tiny_budget = SamplingConfig(
        num_epochs=1,
        step_per_epoch=100,
        num_train_envs=2,
        num_test_envs=2,
    )
    exp_builder = builder_cls(
        env_factory=discrete_envs,
        sampling_config=tiny_budget,
        experiment_config=ExperimentConfig(persistence_enabled=False),
    )

    exp = exp_builder.build()
    exp.run("test")
    print(exp)
|