37 lines
1.4 KiB
Python
37 lines
1.4 KiB
Python
|
import gymnasium as gym
|
||
|
|
||
|
from tianshou.env import DummyVectorEnv
|
||
|
from tianshou.highlevel.env import (
|
||
|
ContinuousEnvironments,
|
||
|
DiscreteEnvironments,
|
||
|
EnvFactory,
|
||
|
Environments,
|
||
|
)
|
||
|
from tianshou.highlevel.persistence import PersistableConfigProtocol
|
||
|
|
||
|
|
||
|
class DiscreteTestEnvFactory(EnvFactory):
    """Environment factory for tests that builds ``CartPole-v0`` environments.

    :param test_num: number of environments placed in the test vector env.
    :param train_num: number of environments placed in the training vector env.
    """

    def __init__(self, test_num=10, train_num=10):
        self.train_num = train_num
        self.test_num = test_num

    def create_envs(self, config: PersistableConfigProtocol | None = None) -> Environments:
        """Build a single env plus the training and test vector envs.

        :param config: accepted for interface compatibility; not used here.
        :return: a :class:`DiscreteEnvironments` wrapping one standalone env,
            ``train_num`` training envs and ``test_num`` test envs.
        """
        env_name = "CartPole-v0"

        def build_env():
            # Every call yields a fresh, independent environment instance.
            return gym.make(env_name)

        # Keep the construction order: single env first, then train, then test.
        single_env = build_env()
        train_constructors = [build_env for _ in range(self.train_num)]
        training_envs = DummyVectorEnv(train_constructors)
        test_constructors = [build_env for _ in range(self.test_num)]
        testing_envs = DummyVectorEnv(test_constructors)
        return DiscreteEnvironments(single_env, training_envs, testing_envs)
|
||
|
|
||
|
|
||
|
class ContinuousTestEnvFactory(EnvFactory):
    """Environment factory for tests that builds ``Pendulum-v1`` environments.

    :param test_num: number of environments placed in the test vector env.
    :param train_num: number of environments placed in the training vector env.
    """

    def __init__(self, test_num=10, train_num=10):
        self.train_num = train_num
        self.test_num = test_num

    def create_envs(self, config: PersistableConfigProtocol | None = None) -> Environments:
        """Build a single env plus the training and test vector envs.

        :param config: accepted for interface compatibility; not used here.
        :return: a :class:`ContinuousEnvironments` wrapping one standalone env,
            ``train_num`` training envs and ``test_num`` test envs.
        """
        env_name = "Pendulum-v1"

        def build_env():
            # Every call yields a fresh, independent environment instance.
            return gym.make(env_name)

        # Keep the construction order: single env first, then train, then test.
        single_env = build_env()
        train_constructors = [build_env for _ in range(self.train_num)]
        training_envs = DummyVectorEnv(train_constructors)
        test_constructors = [build_env for _ in range(self.test_num)]
        testing_envs = DummyVectorEnv(test_constructors)
        return ContinuousEnvironments(single_env, training_envs, testing_envs)
|