Add some basic tests for high-level experiment builder API

This commit is contained in:
Dominik Jain 2023-10-05 19:22:04 +02:00
parent b54fcd12cb
commit 50ac385321
3 changed files with 101 additions and 0 deletions

View File

@ -0,0 +1,36 @@
import gymnasium as gym
from tianshou.env import DummyVectorEnv
from tianshou.highlevel.env import (
ContinuousEnvironments,
DiscreteEnvironments,
EnvFactory,
Environments,
)
from tianshou.highlevel.persistence import PersistableConfigProtocol
class DiscreteTestEnvFactory(EnvFactory):
def __init__(self, test_num=10, train_num=10):
self.test_num = test_num
self.train_num = train_num
def create_envs(self, config: PersistableConfigProtocol | None = None) -> Environments:
task = "CartPole-v0"
env = gym.make(task)
train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(self.train_num)])
test_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(self.test_num)])
return DiscreteEnvironments(env, train_envs, test_envs)
class ContinuousTestEnvFactory(EnvFactory):
def __init__(self, test_num=10, train_num=10):
self.test_num = test_num
self.train_num = train_num
def create_envs(self, config: PersistableConfigProtocol | None = None) -> Environments:
task = "Pendulum-v1"
env = gym.make(task)
train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(self.train_num)])
test_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(self.test_num)])
return ContinuousEnvironments(env, train_envs, test_envs)

View File

@ -0,0 +1,37 @@
from test.highlevel.env_factory import ContinuousTestEnvFactory
import pytest
from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.experiment import (
A2CExperimentBuilder,
DDPGExperimentBuilder,
PPOExperimentBuilder,
RLExperimentConfig,
SACExperimentBuilder,
TD3ExperimentBuilder,
)
@pytest.mark.parametrize(
"builder_cls",
[
PPOExperimentBuilder,
A2CExperimentBuilder,
SACExperimentBuilder,
DDPGExperimentBuilder,
TD3ExperimentBuilder,
],
)
def test_experiment_builder_continuous_default_params(builder_cls):
env_factory = ContinuousTestEnvFactory()
sampling_config = RLSamplingConfig(num_epochs=1, step_per_epoch=100)
experiment_config = RLExperimentConfig()
builder = builder_cls(
experiment_config=experiment_config,
env_factory=env_factory,
sampling_config=sampling_config,
)
experiment = builder.build()
experiment.run("test")
print(experiment)

View File

@ -0,0 +1,28 @@
from test.highlevel.env_factory import DiscreteTestEnvFactory
import pytest
from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.experiment import (
A2CExperimentBuilder,
DQNExperimentBuilder,
PPOExperimentBuilder,
RLExperimentConfig,
)
@pytest.mark.parametrize(
"builder_cls",
[PPOExperimentBuilder, A2CExperimentBuilder, DQNExperimentBuilder],
)
def test_experiment_builder_discrete_default_params(builder_cls):
env_factory = DiscreteTestEnvFactory()
sampling_config = RLSamplingConfig(num_epochs=1, step_per_epoch=100)
builder = builder_cls(
experiment_config=RLExperimentConfig(),
env_factory=env_factory,
sampling_config=sampling_config,
)
experiment = builder.build()
experiment.run("test")
print(experiment)