Tianshou/test/highlevel/test_continuous.py
Dominik Jain 76e870207d Improve persistence handling
* Add persistence/restoration of Experiment instance
* Add file logging in experiment
* Allow all persistence/logging to be disabled
* Disable persistence in tests
2023-10-18 20:44:18 +02:00

45 lines
1.2 KiB
Python

from test.highlevel.env_factory import ContinuousTestEnvFactory
import pytest
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import (
A2CExperimentBuilder,
DDPGExperimentBuilder,
ExperimentConfig,
PGExperimentBuilder,
PPOExperimentBuilder,
REDQExperimentBuilder,
SACExperimentBuilder,
TD3ExperimentBuilder,
TRPOExperimentBuilder,
)
@pytest.mark.parametrize(
    "builder_cls",
    [
        PPOExperimentBuilder,
        A2CExperimentBuilder,
        SACExperimentBuilder,
        DDPGExperimentBuilder,
        TD3ExperimentBuilder,
        # NPGExperimentBuilder,  # TODO test fails non-deterministically
        REDQExperimentBuilder,
        TRPOExperimentBuilder,
        PGExperimentBuilder,
    ],
)
def test_experiment_builder_continuous_default_params(builder_cls):
    """Build and run an experiment from ``builder_cls`` without algorithm-specific overrides.

    The builder only receives a ``ContinuousTestEnvFactory``, a
    ``SamplingConfig`` limited to one epoch with 100 steps per epoch, and an
    ``ExperimentConfig`` with persistence disabled; no further builder
    customisation is applied before ``build()``.

    :param builder_cls: the experiment builder class under test, supplied by
        the parametrisation above.
    """
    # Collaborators are created in the same order as before and passed by
    # keyword, so the builder sees exactly the same arguments.
    builder_kwargs = {
        "env_factory": ContinuousTestEnvFactory(),
        "sampling_config": SamplingConfig(num_epochs=1, step_per_epoch=100),
        "experiment_config": ExperimentConfig(persistence_enabled=False),
    }
    exp = builder_cls(**builder_kwargs).build()
    exp.run("test")
    # Printing also exercises the experiment's string representation.
    print(exp)