2023-10-05 19:22:04 +02:00
|
|
|
from tianshou.highlevel.env import (
|
2024-01-16 12:22:07 +01:00
|
|
|
EnvFactoryRegistered,
|
2024-01-10 15:37:58 +01:00
|
|
|
VectorEnvType,
|
2023-10-05 19:22:04 +02:00
|
|
|
)
|
|
|
|
|
|
|
|
|
2024-01-16 12:22:07 +01:00
|
|
|
class DiscreteTestEnvFactory(EnvFactoryRegistered):
    """Environment factory for the discrete-action test task ``CartPole-v0``.

    All configuration is fixed: the base factory receives the task id,
    a training seed of 42, a test seed of 1337 and ``VectorEnvType.DUMMY``
    as the vectorised-environment type.
    """

    def __init__(self) -> None:
        task_id = "CartPole-v0"
        # Distinct, hard-coded seeds so train and test environments are reproducible.
        train_seed, test_seed = 42, 1337
        super().__init__(
            task=task_id,
            venv_type=VectorEnvType.DUMMY,
            train_seed=train_seed,
            test_seed=test_seed,
        )
|
2023-10-05 19:22:04 +02:00
|
|
|
|
|
|
|
|
2024-01-16 12:22:07 +01:00
|
|
|
class ContinuousTestEnvFactory(EnvFactoryRegistered):
    """Environment factory for the continuous-action test task ``Pendulum-v1``.

    All configuration is fixed: the base factory receives the task id,
    a training seed of 42, a test seed of 1337 and ``VectorEnvType.DUMMY``
    as the vectorised-environment type.
    """

    def __init__(self) -> None:
        task_id = "Pendulum-v1"
        # Distinct, hard-coded seeds so train and test environments are reproducible.
        train_seed, test_seed = 42, 1337
        super().__init__(
            task=task_id,
            venv_type=VectorEnvType.DUMMY,
            train_seed=train_seed,
            test_seed=test_seed,
        )
|