Implement make_atari_env via AtariEnvFactory, eliminating duplication
This commit is contained in:
parent
19a98c3b2a
commit
63269fe198
@ -9,7 +9,6 @@ import gymnasium as gym
|
||||
import numpy as np
|
||||
from gymnasium import Env
|
||||
|
||||
from tianshou.env import ShmemVectorEnv
|
||||
from tianshou.highlevel.env import (
|
||||
EnvFactoryGymnasium,
|
||||
EnvMode,
|
||||
@ -324,68 +323,23 @@ def wrap_deepmind(
|
||||
return env
|
||||
|
||||
|
||||
def make_atari_env(task, seed, training_num, test_num, **kwargs):
|
||||
def make_atari_env(
|
||||
task,
|
||||
seed,
|
||||
training_num,
|
||||
test_num,
|
||||
scale: int | bool = False,
|
||||
frame_stack: int = 4,
|
||||
):
|
||||
"""Wrapper function for Atari env.
|
||||
|
||||
If EnvPool is installed, it will automatically switch to EnvPool's Atari env.
|
||||
|
||||
:return: a tuple of (single env, training envs, test envs).
|
||||
"""
|
||||
if envpool is not None:
|
||||
if kwargs.get("scale", 0):
|
||||
warnings.warn(
|
||||
"EnvPool does not include ScaledFloatFrame wrapper, "
|
||||
"please set `x = x / 255.0` inside CNN network's forward function.",
|
||||
)
|
||||
# parameters convertion
|
||||
train_envs = env = envpool.make_gymnasium(
|
||||
task.replace("NoFrameskip-v4", "-v5"),
|
||||
num_envs=training_num,
|
||||
seed=seed,
|
||||
episodic_life=True,
|
||||
reward_clip=True,
|
||||
stack_num=kwargs.get("frame_stack", 4),
|
||||
)
|
||||
test_envs = envpool.make_gymnasium(
|
||||
task.replace("NoFrameskip-v4", "-v5"),
|
||||
num_envs=test_num,
|
||||
seed=seed,
|
||||
episodic_life=False,
|
||||
reward_clip=False,
|
||||
stack_num=kwargs.get("frame_stack", 4),
|
||||
)
|
||||
else:
|
||||
assert "NoFrameskip" in task
|
||||
warnings.warn(
|
||||
"Recommend using envpool (pip install envpool) to run Atari games more efficiently.",
|
||||
)
|
||||
env = wrap_deepmind(task, **kwargs)
|
||||
train_envs = ShmemVectorEnv(
|
||||
[
|
||||
lambda: wrap_deepmind(
|
||||
gym.make(task),
|
||||
episode_life=True,
|
||||
clip_rewards=True,
|
||||
**kwargs,
|
||||
)
|
||||
for _ in range(training_num)
|
||||
],
|
||||
)
|
||||
test_envs = ShmemVectorEnv(
|
||||
[
|
||||
lambda: wrap_deepmind(
|
||||
gym.make(task),
|
||||
episode_life=False,
|
||||
clip_rewards=False,
|
||||
**kwargs,
|
||||
)
|
||||
for _ in range(test_num)
|
||||
],
|
||||
)
|
||||
env.seed(seed)
|
||||
train_envs.seed(seed)
|
||||
test_envs.seed(seed)
|
||||
return env, train_envs, test_envs
|
||||
env_factory = AtariEnvFactory(task, seed, frame_stack, scale=bool(scale))
|
||||
envs = env_factory.create_envs(training_num, test_num)
|
||||
return envs.env, envs.train_envs, envs.test_envs
|
||||
|
||||
|
||||
class AtariEnvFactory(EnvFactoryGymnasium):
|
||||
|
Loading…
x
Reference in New Issue
Block a user