Implement make_atari_env via AtariEnvFactory, eliminating duplication

This commit is contained in:
Dominik Jain 2024-01-11 12:43:05 +01:00
parent 19a98c3b2a
commit 63269fe198

View File

@ -9,7 +9,6 @@ import gymnasium as gym
import numpy as np import numpy as np
from gymnasium import Env from gymnasium import Env
from tianshou.env import ShmemVectorEnv
from tianshou.highlevel.env import ( from tianshou.highlevel.env import (
EnvFactoryGymnasium, EnvFactoryGymnasium,
EnvMode, EnvMode,
@ -324,68 +323,23 @@ def wrap_deepmind(
return env return env
def make_atari_env(task, seed, training_num, test_num, **kwargs): def make_atari_env(
task,
seed,
training_num,
test_num,
scale: int | bool = False,
frame_stack: int = 4,
):
"""Wrapper function for Atari env. """Wrapper function for Atari env.
If EnvPool is installed, it will automatically switch to EnvPool's Atari env. If EnvPool is installed, it will automatically switch to EnvPool's Atari env.
:return: a tuple of (single env, training envs, test envs). :return: a tuple of (single env, training envs, test envs).
""" """
if envpool is not None: env_factory = AtariEnvFactory(task, seed, frame_stack, scale=bool(scale))
if kwargs.get("scale", 0): envs = env_factory.create_envs(training_num, test_num)
warnings.warn( return envs.env, envs.train_envs, envs.test_envs
"EnvPool does not include ScaledFloatFrame wrapper, "
"please set `x = x / 255.0` inside CNN network's forward function.",
)
# parameters convertion
train_envs = env = envpool.make_gymnasium(
task.replace("NoFrameskip-v4", "-v5"),
num_envs=training_num,
seed=seed,
episodic_life=True,
reward_clip=True,
stack_num=kwargs.get("frame_stack", 4),
)
test_envs = envpool.make_gymnasium(
task.replace("NoFrameskip-v4", "-v5"),
num_envs=test_num,
seed=seed,
episodic_life=False,
reward_clip=False,
stack_num=kwargs.get("frame_stack", 4),
)
else:
assert "NoFrameskip" in task
warnings.warn(
"Recommend using envpool (pip install envpool) to run Atari games more efficiently.",
)
env = wrap_deepmind(task, **kwargs)
train_envs = ShmemVectorEnv(
[
lambda: wrap_deepmind(
gym.make(task),
episode_life=True,
clip_rewards=True,
**kwargs,
)
for _ in range(training_num)
],
)
test_envs = ShmemVectorEnv(
[
lambda: wrap_deepmind(
gym.make(task),
episode_life=False,
clip_rewards=False,
**kwargs,
)
for _ in range(test_num)
],
)
env.seed(seed)
train_envs.seed(seed)
test_envs.seed(seed)
return env, train_envs, test_envs
class AtariEnvFactory(EnvFactoryGymnasium): class AtariEnvFactory(EnvFactoryGymnasium):