of number of environments in SamplingConfig is used (the values are now passed to the factory method). This is clearer and removes the need to pass otherwise unnecessary configuration to environment factories at construction.
39 lines
1.3 KiB
Python
39 lines
1.3 KiB
Python
import gymnasium as gym
|
|
|
|
from tianshou.env import DummyVectorEnv
|
|
from tianshou.highlevel.env import (
|
|
ContinuousEnvironments,
|
|
DiscreteEnvironments,
|
|
EnvFactory,
|
|
Environments,
|
|
)
|
|
from tianshou.highlevel.persistence import PersistableConfigProtocol
|
|
|
|
|
|
class DiscreteTestEnvFactory(EnvFactory):
    """Test environment factory built around the gymnasium ``CartPole-v0`` task.

    The resulting environments are wrapped in ``DiscreteEnvironments``.
    """

    def create_envs(
        self,
        num_training_envs: int,
        num_test_envs: int,
        config: PersistableConfigProtocol | None = None,
    ) -> Environments:
        """Build one standalone env plus vectorised training and test envs.

        :param num_training_envs: number of sub-envs in the training ``DummyVectorEnv``.
        :param num_test_envs: number of sub-envs in the test ``DummyVectorEnv``.
        :param config: accepted for interface compatibility; not used here.
        :return: a ``DiscreteEnvironments`` holding the standalone env, the
            training vector env and the test vector env, in that order.
        """
        env_name = "CartPole-v0"

        def make_env() -> gym.Env:
            # Every sub-env gets its own fresh instance of the same task.
            return gym.make(env_name)

        # Arguments are evaluated left to right, so the standalone env is
        # created first, then the training envs, then the test envs.
        return DiscreteEnvironments(
            make_env(),
            DummyVectorEnv([make_env for _ in range(num_training_envs)]),
            DummyVectorEnv([make_env for _ in range(num_test_envs)]),
        )
|
|
|
|
|
|
class ContinuousTestEnvFactory(EnvFactory):
    """Test environment factory built around the gymnasium ``Pendulum-v1`` task.

    The resulting environments are wrapped in ``ContinuousEnvironments``.
    """

    def create_envs(
        self,
        num_training_envs: int,
        num_test_envs: int,
        config: PersistableConfigProtocol | None = None,
    ) -> Environments:
        """Build one standalone env plus vectorised training and test envs.

        :param num_training_envs: number of sub-envs in the training ``DummyVectorEnv``.
        :param num_test_envs: number of sub-envs in the test ``DummyVectorEnv``.
        :param config: accepted for interface compatibility; not used here.
        :return: a ``ContinuousEnvironments`` holding the standalone env, the
            training vector env and the test vector env, in that order.
        """
        task_id = "Pendulum-v1"

        def build_vector_env(count: int) -> DummyVectorEnv:
            # One factory callable per sub-env; DummyVectorEnv invokes each.
            return DummyVectorEnv([lambda: gym.make(task_id) for _ in range(count)])

        single_env = gym.make(task_id)
        training = build_vector_env(num_training_envs)
        testing = build_vector_env(num_test_envs)
        return ContinuousEnvironments(single_env, training, testing)
|