108 lines
3.5 KiB
Python
108 lines
3.5 KiB
Python
"""
|
|
Helpers for scripts like run_atari.py.
|
|
"""
|
|
|
|
import os
|
|
from mpi4py import MPI
|
|
import gym
|
|
from gym.wrappers import FlattenDictWrapper
|
|
from baselines import logger
|
|
from baselines.bench import Monitor
|
|
from baselines.common import set_global_seeds
|
|
from baselines.common.atari_wrappers import make_atari, wrap_deepmind
|
|
from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
|
|
|
|
|
|
def make_atari_env(env_id, num_env, seed, wrapper_kwargs=None, start_index=0):
    """
    Build a SubprocVecEnv of ``num_env`` wrapped and monitored Atari envs.

    Worker ``i`` (0 <= i < num_env) is given rank ``start_index + i``. Its
    environment is seeded with ``seed + rank``. Its Monitor path is a
    rank-named subdirectory of the logger directory, or the falsy logger
    directory value itself when no directory is configured.

    :param env_id: gym environment id, passed to ``make_atari``.
    :param num_env: number of environments (subprocesses) to create.
    :param seed: base seed; passed to ``set_global_seeds`` and offset by
        rank for each environment.
    :param wrapper_kwargs: keyword arguments forwarded to ``wrap_deepmind``;
        ``None`` means no extra arguments.
    :param start_index: offset added to each worker index to form its rank.
    :return: a ``SubprocVecEnv`` built from one constructor per worker.
    """
    deepmind_kwargs = {} if wrapper_kwargs is None else wrapper_kwargs

    def build_thunk(rank):
        # Returns a zero-argument constructor rather than an env: SubprocVecEnv
        # runs each of these constructors inside its own subprocess.
        def create_env():
            env = make_atari(env_id)
            env.seed(seed + rank)
            log_dir = logger.get_dir()
            # Monitor wraps the env and records information at episode end.
            monitor_path = log_dir and os.path.join(log_dir, str(rank))
            env = Monitor(env, monitor_path)
            return wrap_deepmind(env, **deepmind_kwargs)

        return create_env

    set_global_seeds(seed)
    # Offsetting by start_index gives each subprocess a distinct rank, hence a
    # distinct environment seed and Monitor subdirectory.
    thunks = [build_thunk(start_index + i) for i in range(num_env)]
    return SubprocVecEnv(thunks)
|
|
|
|
|
|
def make_mujoco_env(env_id, seed):
    """
    Create a wrapped, monitored gym.Env for MuJoCo.

    The global RNGs are seeded with ``seed + 10000 * rank``, where ``rank`` is
    this process's MPI rank. Monitor output goes to a rank-named subdirectory
    of the logger directory. When the logger has no directory configured
    (``logger.get_dir()`` is None), Monitor receives ``None`` instead of
    crashing.

    :param env_id: gym environment id, passed to ``gym.make``.
    :param seed: base seed for the global RNGs; the environment itself is
        seeded with this unmodified value.
    :return: the Monitor-wrapped environment.
    """
    rank = MPI.COMM_WORLD.Get_rank()
    # Offset by rank so each MPI worker gets distinct global RNG streams.
    set_global_seeds(seed + 10000 * rank)
    env = gym.make(env_id)
    log_dir = logger.get_dir()
    # os.path.join(None, ...) raises TypeError, so handle an unconfigured
    # logger explicitly instead of crashing, as the other env factories do.
    monitor_path = None if log_dir is None else os.path.join(log_dir, str(rank))
    env = Monitor(env, monitor_path)
    # NOTE(review): the env is seeded with the base seed on every MPI rank,
    # unlike the global RNGs above; kept as-is to preserve existing behavior —
    # confirm this is intended.
    env.seed(seed)
    return env
|
|
|
|
|
|
def make_robotics_env(env_id, seed, rank=0):
    """
    Create a wrapped, monitored gym.Env for goal-based robotics tasks.

    The raw environment's dict observation is flattened by FlattenDictWrapper
    using only the 'observation' and 'desired_goal' keys. Monitor is asked to
    also log the 'is_success' info keyword.

    :param env_id: gym environment id, passed to ``gym.make``.
    :param seed: seed for both the global RNGs and the environment.
    :param rank: worker index; used only to name the Monitor subdirectory.
    :return: the Monitor-wrapped, flattened environment.
    """
    set_global_seeds(seed)
    flat_env = FlattenDictWrapper(gym.make(env_id), ['observation', 'desired_goal'])
    log_dir = logger.get_dir()
    # A falsy logger directory is passed through unchanged as the Monitor path.
    monitored = Monitor(flat_env,
                        log_dir and os.path.join(log_dir, str(rank)),
                        info_keywords=('is_success',))
    monitored.seed(seed)
    return monitored
|
|
|
|
|
|
def arg_parser():
    """
    Return a fresh argparse.ArgumentParser with no arguments registered.

    The parser uses ArgumentDefaultsHelpFormatter, so ``--help`` output shows
    each option's default value.
    """
    from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
    return parser
|
|
|
|
|
|
def atari_arg_parser():
    """
    Create an argparse.ArgumentParser for run_atari.py.

    Registers ``--env`` (default 'BreakoutNoFrameskip-v4'), ``--seed``
    (int, default 0) and ``--num-timesteps`` (int, default 10,000,000).
    """
    parser = arg_parser()
    # Another option for --env is e.g. 'PongNoFrameskip-v4'.
    option_specs = (
        ('--env', dict(help='environment ID', default='BreakoutNoFrameskip-v4')),
        ('--seed', dict(help='RNG seed', type=int, default=0)),
        ('--num-timesteps', dict(type=int, default=10 ** 7)),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser
|
|
|
|
|
|
def mujoco_arg_parser():
    """
    Create an argparse.ArgumentParser for run_mujoco.py.

    Registers ``--env`` (str, default 'Swimmer-v2'), ``--seed`` (int, 0),
    ``--num-timesteps`` (int, 1,000,000), the ``--play`` flag (default False)
    and ``--force_write`` (int, default 0).
    """
    parser = arg_parser()
    # Another option for --env is e.g. 'Reacher-v2'.
    option_specs = (
        ('--env', dict(help='environment ID', type=str, default='Swimmer-v2')),
        ('--seed', dict(help='RNG seed', type=int, default=0)),
        ('--num-timesteps', dict(type=int, default=10 ** 6)),
        ('--play', dict(default=False, action='store_true')),
        ('--force_write', dict(default=0, type=int)),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser
|
|
|
|
|
|
def robotics_arg_parser():
    """
    Create an argparse.ArgumentParser for robotics training scripts.

    Registers ``--env`` (str, default 'FetchReach-v0'), ``--seed`` (int,
    default 0) and ``--num-timesteps`` (int, default 1,000,000).
    """
    parser = arg_parser()
    option_specs = (
        ('--env', dict(help='environment ID', type=str, default='FetchReach-v0')),
        ('--seed', dict(help='RNG seed', type=int, default=0)),
        ('--num-timesteps', dict(type=int, default=10 ** 6)),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser
|