Implement make_atari_env via AtariEnvFactory, eliminating duplication

This commit is contained in:
Dominik Jain 2024-01-11 12:43:05 +01:00
parent 19a98c3b2a
commit 63269fe198

View File

@ -9,7 +9,6 @@ import gymnasium as gym
import numpy as np
from gymnasium import Env
from tianshou.env import ShmemVectorEnv
from tianshou.highlevel.env import (
EnvFactoryGymnasium,
EnvMode,
@ -324,68 +323,23 @@ def wrap_deepmind(
return env
def make_atari_env(task, seed, training_num, test_num, **kwargs):
def make_atari_env(
task,
seed,
training_num,
test_num,
scale: int | bool = False,
frame_stack: int = 4,
):
"""Wrapper function for Atari env.
If EnvPool is installed, it will automatically switch to EnvPool's Atari env.
:return: a tuple of (single env, training envs, test envs).
"""
if envpool is not None:
if kwargs.get("scale", 0):
warnings.warn(
"EnvPool does not include ScaledFloatFrame wrapper, "
"please set `x = x / 255.0` inside CNN network's forward function.",
)
# parameters conversion
train_envs = env = envpool.make_gymnasium(
task.replace("NoFrameskip-v4", "-v5"),
num_envs=training_num,
seed=seed,
episodic_life=True,
reward_clip=True,
stack_num=kwargs.get("frame_stack", 4),
)
test_envs = envpool.make_gymnasium(
task.replace("NoFrameskip-v4", "-v5"),
num_envs=test_num,
seed=seed,
episodic_life=False,
reward_clip=False,
stack_num=kwargs.get("frame_stack", 4),
)
else:
assert "NoFrameskip" in task
warnings.warn(
"Recommend using envpool (pip install envpool) to run Atari games more efficiently.",
)
env = wrap_deepmind(task, **kwargs)
train_envs = ShmemVectorEnv(
[
lambda: wrap_deepmind(
gym.make(task),
episode_life=True,
clip_rewards=True,
**kwargs,
)
for _ in range(training_num)
],
)
test_envs = ShmemVectorEnv(
[
lambda: wrap_deepmind(
gym.make(task),
episode_life=False,
clip_rewards=False,
**kwargs,
)
for _ in range(test_num)
],
)
env.seed(seed)
train_envs.seed(seed)
test_envs.seed(seed)
return env, train_envs, test_envs
env_factory = AtariEnvFactory(task, seed, frame_stack, scale=bool(scale))
envs = env_factory.create_envs(training_num, test_num)
return envs.env, envs.train_envs, envs.test_envs
class AtariEnvFactory(EnvFactoryGymnasium):