From 63269fe198e50755c30c2e4b2aa97010bf849405 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Thu, 11 Jan 2024 12:43:05 +0100 Subject: [PATCH] Implement make_atari_env via AtariEnvFactory, eliminating duplication --- examples/atari/atari_wrapper.py | 68 ++++++--------------------------- 1 file changed, 11 insertions(+), 57 deletions(-) diff --git a/examples/atari/atari_wrapper.py b/examples/atari/atari_wrapper.py index a7d7ea5..1e94068 100644 --- a/examples/atari/atari_wrapper.py +++ b/examples/atari/atari_wrapper.py @@ -9,7 +9,6 @@ import gymnasium as gym import numpy as np from gymnasium import Env -from tianshou.env import ShmemVectorEnv from tianshou.highlevel.env import ( EnvFactoryGymnasium, EnvMode, @@ -324,68 +323,23 @@ def wrap_deepmind( return env -def make_atari_env(task, seed, training_num, test_num, **kwargs): +def make_atari_env( + task, + seed, + training_num, + test_num, + scale: int | bool = False, + frame_stack: int = 4, +): """Wrapper function for Atari env. If EnvPool is installed, it will automatically switch to EnvPool's Atari env. :return: a tuple of (single env, training envs, test envs). """ - if envpool is not None: - if kwargs.get("scale", 0): - warnings.warn( - "EnvPool does not include ScaledFloatFrame wrapper, " - "please set `x = x / 255.0` inside CNN network's forward function.", - ) - # parameters convertion - train_envs = env = envpool.make_gymnasium( - task.replace("NoFrameskip-v4", "-v5"), - num_envs=training_num, - seed=seed, - episodic_life=True, - reward_clip=True, - stack_num=kwargs.get("frame_stack", 4), - ) - test_envs = envpool.make_gymnasium( - task.replace("NoFrameskip-v4", "-v5"), - num_envs=test_num, - seed=seed, - episodic_life=False, - reward_clip=False, - stack_num=kwargs.get("frame_stack", 4), - ) - else: - assert "NoFrameskip" in task - warnings.warn( - "Recommend using envpool (pip install envpool) to run Atari games more efficiently.", - ) - env = wrap_deepmind(task, **kwargs) - train_envs = ShmemVectorEnv( - [ - lambda: wrap_deepmind( - gym.make(task), - episode_life=True, - clip_rewards=True, - **kwargs, - ) - for _ in range(training_num) - ], - ) - test_envs = ShmemVectorEnv( - [ - lambda: wrap_deepmind( - gym.make(task), - episode_life=False, - clip_rewards=False, - **kwargs, - ) - for _ in range(test_num) - ], - ) - env.seed(seed) - train_envs.seed(seed) - test_envs.seed(seed) - return env, train_envs, test_envs + env_factory = AtariEnvFactory(task, seed, frame_stack, scale=bool(scale)) + envs = env_factory.create_envs(training_num, test_num) + return envs.env, envs.train_envs, envs.test_envs class AtariEnvFactory(EnvFactoryGymnasium):