EfficientZeroV2/ez/envs/__init__.py
“Shengjiewang-Jason” 1367bca203 first commit
2024-06-07 16:02:01 +08:00

183 lines
5.2 KiB
Python

import os
import dmc2gym
from gym.wrappers import Monitor
from .gym import GymWrapper
from .atari import AtariWrapper
from .dmc import DMCWrapper
from .wrapper import *
import random
from dm_env import specs
from ez.utils.format import arr_to_str
def make_envs(game_setting, game_name, num_envs, seed, save_path=None, **kwargs):
    """Build a list of ``num_envs`` environments for one benchmark family.

    Parameters
    ----------
    game_setting: str
        one of 'Atari', 'DMC' or 'Gym'; selects make_atari / make_dmc / make_gym
    game_name: str
        forwarded to the selected factory as its first argument
    num_envs: int
        number of environments to build
    seed: int
        base seed; environment number k receives ``k + seed``. For 'DMC' the
        value passed in is replaced by ``random.randint(1, 1000)`` first.
    save_path: str or None
        forwarded unchanged to every factory call
    kwargs: dict
        forwarded unchanged to every factory call

    Returns
    -------
    list
        the environments, ordered by their seed offset
    """
    assert game_setting in ['Atari', 'DMC', 'Gym']
    factories = {
        'Atari': make_atari,
        'Gym': make_gym,
        'DMC': make_dmc,
    }
    if game_setting not in factories:
        # only reachable when asserts are stripped (python -O)
        raise NotImplementedError()
    build = factories[game_setting]

    if game_setting == 'DMC':
        # NOTE(review): the caller's seed is discarded for DMC, so DMC runs
        # are not reproducible from `seed` -- confirm this is intended.
        seed = random.randint(1, 1000)

    return [
        build(game_name, seed=offset + seed, save_path=save_path, **kwargs)
        for offset in range(num_envs)
    ]
def make_env(game_setting, game_name, num_envs, seed, save_path=None, **kwargs):
    """Build a single environment for one benchmark family.

    Parameters
    ----------
    game_setting: str
        one of 'Atari', 'DMC' or 'Gym'; selects make_atari / make_dmc / make_gym
    game_name: str
        forwarded to the selected factory as its first argument
    num_envs: int
        not used by this function
    seed: int
        not used -- a fresh seed is always drawn with ``random.randint(1, 1000)``
    save_path: str or None
        forwarded unchanged to the factory
    kwargs: dict
        forwarded unchanged to the factory

    Returns
    -------
    the environment returned by the selected factory
    """
    assert game_setting in ['Atari', 'DMC', 'Gym']
    factory = {'Atari': make_atari, 'Gym': make_gym, 'DMC': make_dmc}.get(game_setting)
    if factory is None:
        # only reachable when asserts are stripped (python -O)
        raise NotImplementedError()
    # NOTE(review): the `seed` argument is overwritten unconditionally, so the
    # caller cannot control seeding here -- confirm this is intended.
    fresh_seed = random.randint(1, 1000)
    return factory(game_name, seed=fresh_seed, save_path=save_path, **kwargs)
def make_atari(game_name, seed, save_path=None, **kwargs):
    """Make an Atari environment wrapped in the preprocessing stack.

    Wrapper order (inner to outer): NoopResetEnv -> MaxAndSkipEnv ->
    EpisodicLifeEnv (optional) -> WarpFrame -> TimeLimit -> Monitor
    (optional) -> AtariWrapper.

    Parameters
    ----------
    game_name: str
        name of game without suffix (such as Breakout, Pong); the gym id
        ``<game_name>NoFrameskip-v4`` is created
    seed: int
        seed passed to ``env.seed``
    save_path: str
        directory for recorded videos; no video is saved if falsy
    kwargs: dict
        n_skip: int
            frame skip; defaults to 4 when missing or falsy
        obs_shape: sequence of int
            observation shape, assumed (channels, height, width);
            defaults to [3, 96, 96] when missing or falsy
        gray_scale: bool
            use gray observation or rgb observation
        max_episode_steps: int
            max moves for an episode; defaults to 108000 // n_skip
        episodic_life: bool
            wrap with EpisodicLifeEnv when truthy
        clip_reward: bool
            forwarded to AtariWrapper
        obs_to_string: bool
            forwarded to AtariWrapper

    Returns
    -------
    AtariWrapper
        the fully wrapped environment
    """
    env_id = game_name + 'NoFrameskip-v4'
    gray_scale = kwargs.get('gray_scale')
    obs_to_string = kwargs.get('obs_to_string')
    # falsy values (None, 0, empty) fall back to the defaults
    skip = kwargs.get('n_skip') or 4
    obs_shape = kwargs.get('obs_shape') or [3, 96, 96]
    max_episode_steps = kwargs.get('max_episode_steps') or 108000 // skip
    episodic_life = kwargs.get('episodic_life')
    clip_reward = kwargs.get('clip_reward')

    env = gym.make(env_id)
    env.seed(seed)
    # random restart
    env = NoopResetEnv(env, noop_max=30)
    env = MaxAndSkipEnv(env, skip=skip)
    if episodic_life:
        env = EpisodicLifeEnv(env)
    # obs_shape is channel-first (C, H, W): index 1 is the height and index 2
    # is the width; swapping them would distort non-square frames.
    env = WarpFrame(env, width=obs_shape[2], height=obs_shape[1], grayscale=gray_scale)
    env = TimeLimit(env, max_episode_steps=max_episode_steps)
    if save_path:
        env = Monitor(env, directory=save_path, force=True)
    env = AtariWrapper(env, obs_to_string=obs_to_string, clip_reward=clip_reward)
    return env
def make_gym(game_name, seed, save_path=None, **kwargs):
    """Make a classic gym environment with frame skipping.

    Parameters
    ----------
    game_name: str
        gym environment id passed to ``gym.make``
    seed: int
        seed passed to ``env.seed``
    save_path: str
        directory for recorded videos; no video is saved if falsy
    kwargs: dict
        n_skip: int
            frame skip; defaults to 4 when missing or falsy
        obs_to_string: bool
            forwarded to GymWrapper

    Returns
    -------
    GymWrapper
        the fully wrapped environment
    """
    # `save_path` is used directly: as a named parameter it can never appear
    # in **kwargs, so looking it up there would always yield None and
    # silently disable video recording.
    obs_to_string = kwargs.get('obs_to_string')
    skip = kwargs.get('n_skip') or 4

    env = gym.make(game_name)
    env = GymWrapper(env, obs_to_string=obs_to_string)
    env = MaxAndSkipEnv(env, skip=skip)
    env.seed(seed)
    if save_path:
        env = Monitor(env, directory=save_path, force=True)
    # NOTE(review): GymWrapper is applied both before frame skipping and again
    # as the outermost layer; confirm the double wrap is intended.
    env = GymWrapper(env, obs_to_string=obs_to_string)
    return env
def make_dmc(game_name, seed, save_path=None, **kwargs):
    """Make a DeepMind Control Suite environment via dmc2gym.

    Parameters
    ----------
    game_name: str
        '<domain>_<task>' (such as cheetah_run, walker_walk). Names containing
        'CMU' are split at the last underscore, all others at the first.
    seed: int
        forwarded to ``dmc2gym.make``
    save_path: str
        directory for recorded videos; no video is saved if falsy
    kwargs: dict
        max_episode_steps: int
            required; divided by n_skip before being passed to TimeLimit
        image_based: bool
            observation is image (from_pixels) or state
        obs_shape: sequence of int
            observation shape, assumed (channels, height, width); defaults to
            [3, 96, 96]; its height/width are only used when image_based
        n_skip: int
            frame skip; defaults to 2 when missing or falsy
        clip_reward: bool
            forwarded to DMCWrapper
        obs_to_string: bool
            forwarded to DMCWrapper

    Returns
    -------
    DMCWrapper
        the fully wrapped environment
    """
    if 'CMU' in game_name:
        domain_name, task_name = game_name.rsplit('_', 1)
    else:
        domain_name, task_name = game_name.split('_', 1)
    image_based = kwargs.get('image_based')
    # falsy values (None, 0, empty) fall back to the defaults
    obs_shape = kwargs.get('obs_shape') or [3, 96, 96]
    skip = kwargs.get('n_skip') or 2
    max_episode_steps = kwargs['max_episode_steps'] // skip
    clip_reward = kwargs.get('clip_reward')
    obs_to_string = kwargs.get('obs_to_string')
    # quadruped uses camera 2 (bug fix taken from DrQ-v2)
    camera_id = 2 if 'quadruped' in domain_name else 0
    # obs_shape is channel-first (C, H, W): index 1 is height, index 2 width
    if image_based:
        height, width = obs_shape[1], obs_shape[2]
    else:
        height, width = 96, 96

    env = dmc2gym.make(
        domain_name=domain_name,
        task_name=task_name,
        seed=seed,
        visualize_reward=False,
        from_pixels=image_based,
        height=height,
        width=width,
        frame_skip=skip,
        channels_first=False,
        camera_id=camera_id,
    )
    # the episode limit is enforced by a TimeLimit wrapper
    env = TimeLimit(env, max_episode_steps=max_episode_steps)
    if save_path:
        env = Monitor(env, directory=save_path, force=True)
    env = DMCWrapper(env, obs_to_string=obs_to_string, clip_reward=clip_reward)
    return env