Tianshou/test/highlevel/env_factory.py
Dominik Jain 6cbee188b8 Change interface of EnvFactory to ensure that the
number-of-environments configuration in SamplingConfig is actually used
(the values are now passed to the factory method)

This is clearer and removes the need to pass otherwise
unnecessary configuration to environment factories at
construction time.
2023-10-19 11:37:20 +02:00

39 lines
1.3 KiB
Python

import gymnasium as gym
from tianshou.env import DummyVectorEnv
from tianshou.highlevel.env import (
ContinuousEnvironments,
DiscreteEnvironments,
EnvFactory,
Environments,
)
from tianshou.highlevel.persistence import PersistableConfigProtocol
class DiscreteTestEnvFactory(EnvFactory):
    """Environment factory used by tests that need a discrete-action task (CartPole-v0)."""

    def create_envs(
        self,
        num_training_envs: int,
        num_test_envs: int,
        config: PersistableConfigProtocol | None = None,
    ) -> Environments:
        """Create the CartPole-v0 environments for a discrete-action test run.

        :param num_training_envs: number of env instances in the training vector env
        :param num_test_envs: number of env instances in the test vector env
        :param config: accepted to satisfy the factory interface; not used here
        :return: a ``DiscreteEnvironments`` holding one standalone env followed by
            the training and test ``DummyVectorEnv`` instances
        """
        env_name = "CartPole-v0"

        def make_env():
            # Each call yields a fresh, independent environment instance.
            return gym.make(env_name)

        # Arguments are evaluated left to right, so the standalone env is
        # created first, then the training envs, then the test envs.
        return DiscreteEnvironments(
            make_env(),
            DummyVectorEnv([make_env] * num_training_envs),
            DummyVectorEnv([make_env] * num_test_envs),
        )
class ContinuousTestEnvFactory(EnvFactory):
    """Environment factory used by tests that need a continuous-action task (Pendulum-v1)."""

    def create_envs(
        self,
        num_training_envs: int,
        num_test_envs: int,
        config: PersistableConfigProtocol | None = None,
    ) -> Environments:
        """Create the Pendulum-v1 environments for a continuous-action test run.

        :param num_training_envs: number of env instances in the training vector env
        :param num_test_envs: number of env instances in the test vector env
        :param config: accepted to satisfy the factory interface; not used here
        :return: a ``ContinuousEnvironments`` holding one standalone env followed by
            the training and test ``DummyVectorEnv`` instances
        """
        env_name = "Pendulum-v1"

        def make_env():
            # Each call yields a fresh, independent environment instance.
            return gym.make(env_name)

        # Arguments are evaluated left to right, so the standalone env is
        # created first, then the training envs, then the test envs.
        return ContinuousEnvironments(
            make_env(),
            DummyVectorEnv([make_env] * num_training_envs),
            DummyVectorEnv([make_env] * num_test_envs),
        )