diff --git a/test/highlevel/env_factory.py b/test/highlevel/env_factory.py new file mode 100644 index 0000000..342848d --- /dev/null +++ b/test/highlevel/env_factory.py @@ -0,0 +1,36 @@ +import gymnasium as gym + +from tianshou.env import DummyVectorEnv +from tianshou.highlevel.env import ( + ContinuousEnvironments, + DiscreteEnvironments, + EnvFactory, + Environments, +) +from tianshou.highlevel.persistence import PersistableConfigProtocol + + +class DiscreteTestEnvFactory(EnvFactory): + def __init__(self, test_num=10, train_num=10): + self.test_num = test_num + self.train_num = train_num + + def create_envs(self, config: PersistableConfigProtocol | None = None) -> Environments: + task = "CartPole-v0" + env = gym.make(task) + train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(self.train_num)]) + test_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(self.test_num)]) + return DiscreteEnvironments(env, train_envs, test_envs) + + +class ContinuousTestEnvFactory(EnvFactory): + def __init__(self, test_num=10, train_num=10): + self.test_num = test_num + self.train_num = train_num + + def create_envs(self, config: PersistableConfigProtocol | None = None) -> Environments: + task = "Pendulum-v1" + env = gym.make(task) + train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(self.train_num)]) + test_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(self.test_num)]) + return ContinuousEnvironments(env, train_envs, test_envs) diff --git a/test/highlevel/test_continuous.py b/test/highlevel/test_continuous.py new file mode 100644 index 0000000..bc394d6 --- /dev/null +++ b/test/highlevel/test_continuous.py @@ -0,0 +1,37 @@ +from test.highlevel.env_factory import ContinuousTestEnvFactory + +import pytest + +from tianshou.highlevel.config import RLSamplingConfig +from tianshou.highlevel.experiment import ( + A2CExperimentBuilder, + DDPGExperimentBuilder, + PPOExperimentBuilder, + RLExperimentConfig, + SACExperimentBuilder, + TD3ExperimentBuilder, +) + + +@pytest.mark.parametrize( + "builder_cls", + [ + PPOExperimentBuilder, + A2CExperimentBuilder, + SACExperimentBuilder, + DDPGExperimentBuilder, + TD3ExperimentBuilder, + ], +) +def test_experiment_builder_continuous_default_params(builder_cls): + env_factory = ContinuousTestEnvFactory() + sampling_config = RLSamplingConfig(num_epochs=1, step_per_epoch=100) + experiment_config = RLExperimentConfig() + builder = builder_cls( + experiment_config=experiment_config, + env_factory=env_factory, + sampling_config=sampling_config, + ) + experiment = builder.build() + experiment.run("test") + print(experiment) diff --git a/test/highlevel/test_discrete.py b/test/highlevel/test_discrete.py new file mode 100644 index 0000000..b242b85 --- /dev/null +++ b/test/highlevel/test_discrete.py @@ -0,0 +1,28 @@ +from test.highlevel.env_factory import DiscreteTestEnvFactory + +import pytest + +from tianshou.highlevel.config import RLSamplingConfig +from tianshou.highlevel.experiment import ( + A2CExperimentBuilder, + DQNExperimentBuilder, + PPOExperimentBuilder, + RLExperimentConfig, +) + + +@pytest.mark.parametrize( + "builder_cls", + [PPOExperimentBuilder, A2CExperimentBuilder, DQNExperimentBuilder], +) +def test_experiment_builder_discrete_default_params(builder_cls): + env_factory = DiscreteTestEnvFactory() + sampling_config = RLSamplingConfig(num_epochs=1, step_per_epoch=100) + builder = builder_cls( + experiment_config=RLExperimentConfig(), + env_factory=env_factory, + sampling_config=sampling_config, + ) + experiment = builder.build() + experiment.run("test") + print(experiment)