28 lines
1.1 KiB
Python
28 lines
1.1 KiB
Python
import gymnasium as gym
|
|
|
|
from tianshou.env import DummyVectorEnv
|
|
from tianshou.highlevel.env import (
|
|
ContinuousEnvironments,
|
|
DiscreteEnvironments,
|
|
EnvFactory,
|
|
Environments,
|
|
)
|
|
|
|
|
|
class DiscreteTestEnvFactory(EnvFactory):
|
|
def create_envs(self, num_training_envs: int, num_test_envs: int) -> Environments:
|
|
task = "CartPole-v0"
|
|
env = gym.make(task)
|
|
train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_training_envs)])
|
|
test_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_test_envs)])
|
|
return DiscreteEnvironments(env, train_envs, test_envs)
|
|
|
|
|
|
class ContinuousTestEnvFactory(EnvFactory):
|
|
def create_envs(self, num_training_envs: int, num_test_envs: int) -> Environments:
|
|
task = "Pendulum-v1"
|
|
env = gym.make(task)
|
|
train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_training_envs)])
|
|
test_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_test_envs)])
|
|
return ContinuousEnvironments(env, train_envs, test_envs)
|