Implement make_atari_env via AtariEnvFactory, eliminating duplication
This commit is contained in:
parent
19a98c3b2a
commit
63269fe198
@ -9,7 +9,6 @@ import gymnasium as gym
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from gymnasium import Env
|
from gymnasium import Env
|
||||||
|
|
||||||
from tianshou.env import ShmemVectorEnv
|
|
||||||
from tianshou.highlevel.env import (
|
from tianshou.highlevel.env import (
|
||||||
EnvFactoryGymnasium,
|
EnvFactoryGymnasium,
|
||||||
EnvMode,
|
EnvMode,
|
||||||
@ -324,68 +323,23 @@ def wrap_deepmind(
|
|||||||
return env
|
return env
|
||||||
|
|
||||||
|
|
||||||
def make_atari_env(task, seed, training_num, test_num, **kwargs):
|
def make_atari_env(
|
||||||
|
task,
|
||||||
|
seed,
|
||||||
|
training_num,
|
||||||
|
test_num,
|
||||||
|
scale: int | bool = False,
|
||||||
|
frame_stack: int = 4,
|
||||||
|
):
|
||||||
"""Wrapper function for Atari env.
|
"""Wrapper function for Atari env.
|
||||||
|
|
||||||
If EnvPool is installed, it will automatically switch to EnvPool's Atari env.
|
If EnvPool is installed, it will automatically switch to EnvPool's Atari env.
|
||||||
|
|
||||||
:return: a tuple of (single env, training envs, test envs).
|
:return: a tuple of (single env, training envs, test envs).
|
||||||
"""
|
"""
|
||||||
if envpool is not None:
|
env_factory = AtariEnvFactory(task, seed, frame_stack, scale=bool(scale))
|
||||||
if kwargs.get("scale", 0):
|
envs = env_factory.create_envs(training_num, test_num)
|
||||||
warnings.warn(
|
return envs.env, envs.train_envs, envs.test_envs
|
||||||
"EnvPool does not include ScaledFloatFrame wrapper, "
|
|
||||||
"please set `x = x / 255.0` inside CNN network's forward function.",
|
|
||||||
)
|
|
||||||
# parameters conversion
|
|
||||||
train_envs = env = envpool.make_gymnasium(
|
|
||||||
task.replace("NoFrameskip-v4", "-v5"),
|
|
||||||
num_envs=training_num,
|
|
||||||
seed=seed,
|
|
||||||
episodic_life=True,
|
|
||||||
reward_clip=True,
|
|
||||||
stack_num=kwargs.get("frame_stack", 4),
|
|
||||||
)
|
|
||||||
test_envs = envpool.make_gymnasium(
|
|
||||||
task.replace("NoFrameskip-v4", "-v5"),
|
|
||||||
num_envs=test_num,
|
|
||||||
seed=seed,
|
|
||||||
episodic_life=False,
|
|
||||||
reward_clip=False,
|
|
||||||
stack_num=kwargs.get("frame_stack", 4),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
assert "NoFrameskip" in task
|
|
||||||
warnings.warn(
|
|
||||||
"Recommend using envpool (pip install envpool) to run Atari games more efficiently.",
|
|
||||||
)
|
|
||||||
env = wrap_deepmind(task, **kwargs)
|
|
||||||
train_envs = ShmemVectorEnv(
|
|
||||||
[
|
|
||||||
lambda: wrap_deepmind(
|
|
||||||
gym.make(task),
|
|
||||||
episode_life=True,
|
|
||||||
clip_rewards=True,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
for _ in range(training_num)
|
|
||||||
],
|
|
||||||
)
|
|
||||||
test_envs = ShmemVectorEnv(
|
|
||||||
[
|
|
||||||
lambda: wrap_deepmind(
|
|
||||||
gym.make(task),
|
|
||||||
episode_life=False,
|
|
||||||
clip_rewards=False,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
for _ in range(test_num)
|
|
||||||
],
|
|
||||||
)
|
|
||||||
env.seed(seed)
|
|
||||||
train_envs.seed(seed)
|
|
||||||
test_envs.seed(seed)
|
|
||||||
return env, train_envs, test_envs
|
|
||||||
|
|
||||||
|
|
||||||
class AtariEnvFactory(EnvFactoryGymnasium):
|
class AtariEnvFactory(EnvFactoryGymnasium):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user