2023-10-05 19:22:04 +02:00
|
|
|
import gymnasium as gym
|
|
|
|
|
|
|
|
from tianshou.env import DummyVectorEnv
|
|
|
|
from tianshou.highlevel.env import (
|
|
|
|
ContinuousEnvironments,
|
|
|
|
DiscreteEnvironments,
|
|
|
|
EnvFactory,
|
|
|
|
Environments,
|
|
|
|
)
|
|
|
|
from tianshou.highlevel.persistence import PersistableConfigProtocol
|
|
|
|
|
|
|
|
|
|
|
|
class DiscreteTestEnvFactory(EnvFactory):
    """Test environment factory built on the discrete-action ``CartPole-v0`` task."""

    def create_envs(
        self,
        num_training_envs: int,
        num_test_envs: int,
        config: PersistableConfigProtocol | None = None,
    ) -> Environments:
        """Create one standalone CartPole env plus vectorized training and test envs.

        :param num_training_envs: number of sub-environments in the training ``DummyVectorEnv``
        :param num_test_envs: number of sub-environments in the test ``DummyVectorEnv``
        :param config: accepted for interface compatibility; not used by this factory
        :return: a ``DiscreteEnvironments`` holding the standalone env and both vector envs
        """
        env_id = "CartPole-v0"

        def build() -> gym.Env:
            return gym.make(env_id)

        # The standalone env is created first, then the training set, then the test set.
        single_env = build()
        training = DummyVectorEnv([build] * num_training_envs)
        testing = DummyVectorEnv([build] * num_test_envs)
        return DiscreteEnvironments(single_env, training, testing)
|
|
|
|
|
|
|
|
|
|
|
|
class ContinuousTestEnvFactory(EnvFactory):
    """Test environment factory built on the continuous-action ``Pendulum-v1`` task."""

    def create_envs(
        self,
        num_training_envs: int,
        num_test_envs: int,
        config: PersistableConfigProtocol | None = None,
    ) -> Environments:
        """Create one standalone Pendulum env plus vectorized training and test envs.

        :param num_training_envs: number of sub-environments in the training ``DummyVectorEnv``
        :param num_test_envs: number of sub-environments in the test ``DummyVectorEnv``
        :param config: accepted for interface compatibility; not used by this factory
        :return: a ``ContinuousEnvironments`` holding the standalone env and both vector envs
        """
        env_id = "Pendulum-v1"

        def spawn() -> gym.Env:
            return gym.make(env_id)

        # Creation order matches the original: standalone env, then training, then test.
        standalone = spawn()
        train_vec = DummyVectorEnv([spawn] * num_training_envs)
        test_vec = DummyVectorEnv([spawn] * num_test_envs)
        return ContinuousEnvironments(standalone, train_vec, test_vec)
|