Add some basic tests for high-level experiment builder API
This commit is contained in:
parent
b54fcd12cb
commit
50ac385321
36
test/highlevel/env_factory.py
Normal file
36
test/highlevel/env_factory.py
Normal file
@ -0,0 +1,36 @@
|
||||
import gymnasium as gym
|
||||
|
||||
from tianshou.env import DummyVectorEnv
|
||||
from tianshou.highlevel.env import (
|
||||
ContinuousEnvironments,
|
||||
DiscreteEnvironments,
|
||||
EnvFactory,
|
||||
Environments,
|
||||
)
|
||||
from tianshou.highlevel.persistence import PersistableConfigProtocol
|
||||
|
||||
|
||||
class DiscreteTestEnvFactory(EnvFactory):
|
||||
def __init__(self, test_num=10, train_num=10):
|
||||
self.test_num = test_num
|
||||
self.train_num = train_num
|
||||
|
||||
def create_envs(self, config: PersistableConfigProtocol | None = None) -> Environments:
|
||||
task = "CartPole-v0"
|
||||
env = gym.make(task)
|
||||
train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(self.train_num)])
|
||||
test_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(self.test_num)])
|
||||
return DiscreteEnvironments(env, train_envs, test_envs)
|
||||
|
||||
|
||||
class ContinuousTestEnvFactory(EnvFactory):
|
||||
def __init__(self, test_num=10, train_num=10):
|
||||
self.test_num = test_num
|
||||
self.train_num = train_num
|
||||
|
||||
def create_envs(self, config: PersistableConfigProtocol | None = None) -> Environments:
|
||||
task = "Pendulum-v1"
|
||||
env = gym.make(task)
|
||||
train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(self.train_num)])
|
||||
test_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(self.test_num)])
|
||||
return ContinuousEnvironments(env, train_envs, test_envs)
|
37
test/highlevel/test_continuous.py
Normal file
37
test/highlevel/test_continuous.py
Normal file
@ -0,0 +1,37 @@
|
||||
from test.highlevel.env_factory import ContinuousTestEnvFactory
|
||||
|
||||
import pytest
|
||||
|
||||
from tianshou.highlevel.config import RLSamplingConfig
|
||||
from tianshou.highlevel.experiment import (
|
||||
A2CExperimentBuilder,
|
||||
DDPGExperimentBuilder,
|
||||
PPOExperimentBuilder,
|
||||
RLExperimentConfig,
|
||||
SACExperimentBuilder,
|
||||
TD3ExperimentBuilder,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"builder_cls",
|
||||
[
|
||||
PPOExperimentBuilder,
|
||||
A2CExperimentBuilder,
|
||||
SACExperimentBuilder,
|
||||
DDPGExperimentBuilder,
|
||||
TD3ExperimentBuilder,
|
||||
],
|
||||
)
|
||||
def test_experiment_builder_continuous_default_params(builder_cls):
|
||||
env_factory = ContinuousTestEnvFactory()
|
||||
sampling_config = RLSamplingConfig(num_epochs=1, step_per_epoch=100)
|
||||
experiment_config = RLExperimentConfig()
|
||||
builder = builder_cls(
|
||||
experiment_config=experiment_config,
|
||||
env_factory=env_factory,
|
||||
sampling_config=sampling_config,
|
||||
)
|
||||
experiment = builder.build()
|
||||
experiment.run("test")
|
||||
print(experiment)
|
28
test/highlevel/test_discrete.py
Normal file
28
test/highlevel/test_discrete.py
Normal file
@ -0,0 +1,28 @@
|
||||
from test.highlevel.env_factory import DiscreteTestEnvFactory
|
||||
|
||||
import pytest
|
||||
|
||||
from tianshou.highlevel.config import RLSamplingConfig
|
||||
from tianshou.highlevel.experiment import (
|
||||
A2CExperimentBuilder,
|
||||
DQNExperimentBuilder,
|
||||
PPOExperimentBuilder,
|
||||
RLExperimentConfig,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"builder_cls",
|
||||
[PPOExperimentBuilder, A2CExperimentBuilder, DQNExperimentBuilder],
|
||||
)
|
||||
def test_experiment_builder_discrete_default_params(builder_cls):
|
||||
env_factory = DiscreteTestEnvFactory()
|
||||
sampling_config = RLSamplingConfig(num_epochs=1, step_per_epoch=100)
|
||||
builder = builder_cls(
|
||||
experiment_config=RLExperimentConfig(),
|
||||
env_factory=env_factory,
|
||||
sampling_config=sampling_config,
|
||||
)
|
||||
experiment = builder.build()
|
||||
experiment.run("test")
|
||||
print(experiment)
|
Loading…
x
Reference in New Issue
Block a user