Tianshou/test/highlevel/env_factory.py

37 lines
1.4 KiB
Python
Raw Normal View History

import gymnasium as gym
from tianshou.env import DummyVectorEnv
from tianshou.highlevel.env import (
ContinuousEnvironments,
DiscreteEnvironments,
EnvFactory,
Environments,
)
from tianshou.highlevel.persistence import PersistableConfigProtocol
class DiscreteTestEnvFactory(EnvFactory):
    """Environment factory for tests that builds "CartPole-v0" environments.

    The created environments are returned wrapped in ``DiscreteEnvironments``.
    """

    def __init__(self, test_num=10, train_num=10):
        """Store how many environments to create.

        :param test_num: number of environments in the test vector env.
        :param train_num: number of environments in the training vector env.
        """
        self.test_num = test_num
        self.train_num = train_num

    def create_envs(self, config: PersistableConfigProtocol | None = None) -> Environments:
        """Create one single env plus the training and test vector envs.

        :param config: accepted for interface compatibility; not used here.
        :return: a ``DiscreteEnvironments`` holding the single environment,
            the training vector env and the test vector env.
        """
        env_id = "CartPole-v0"

        def make_env():
            # Each call yields a fresh environment instance.
            return gym.make(env_id)

        # Creation order is: single env, then training envs, then test envs.
        single_env = make_env()

        train_constructors = [make_env for _ in range(self.train_num)]
        training = DummyVectorEnv(train_constructors)

        test_constructors = [make_env for _ in range(self.test_num)]
        testing = DummyVectorEnv(test_constructors)

        return DiscreteEnvironments(single_env, training, testing)
class ContinuousTestEnvFactory(EnvFactory):
    """Environment factory for tests that builds "Pendulum-v1" environments.

    The created environments are returned wrapped in ``ContinuousEnvironments``.
    """

    def __init__(self, test_num=10, train_num=10):
        """Store how many environments to create.

        :param test_num: number of environments in the test vector env.
        :param train_num: number of environments in the training vector env.
        """
        self.test_num = test_num
        self.train_num = train_num

    def create_envs(self, config: PersistableConfigProtocol | None = None) -> Environments:
        """Create one single env plus the training and test vector envs.

        :param config: accepted for interface compatibility; not used here.
        :return: a ``ContinuousEnvironments`` holding the single environment,
            the training vector env and the test vector env.
        """
        env_id = "Pendulum-v1"

        def make_env():
            # Each call yields a fresh environment instance.
            return gym.make(env_id)

        # Creation order is: single env, then training envs, then test envs.
        single_env = make_env()

        train_constructors = [make_env for _ in range(self.train_num)]
        training = DummyVectorEnv(train_constructors)

        test_constructors = [make_env for _ in range(self.test_num)]
        testing = DummyVectorEnv(test_constructors)

        return ContinuousEnvironments(single_env, training, testing)