"""
Helpers for scripts like run_atari.py.
"""

import argparse
import os

from mpi4py import MPI
import gym
from gym.wrappers import FlattenDictWrapper
from baselines import logger
from baselines.bench import Monitor
from baselines.common import set_global_seeds
from baselines.common.atari_wrappers import make_atari, wrap_deepmind
from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv


def _monitor_path(rank):
    """
    Return the Monitor log path for ``rank``.

    The path is ``<logger dir>/<rank>``. When no logger directory is configured,
    the falsy value of ``logger.get_dir()`` (e.g. None) is returned unchanged,
    so Monitor is given no file path.
    """
    log_dir = logger.get_dir()
    return log_dir and os.path.join(log_dir, str(rank))


def make_atari_env(env_id, num_env, seed, wrapper_kwargs=None, start_index=0):
    """
    Create a wrapped, monitored SubprocVecEnv for Atari.

    :param env_id: environment id passed to ``make_atari``.
    :param num_env: number of environments, each built inside its own subprocess.
    :param seed: base seed; the environment with rank r is seeded with ``seed + r``.
    :param wrapper_kwargs: extra keyword arguments forwarded to ``wrap_deepmind``.
    :param start_index: offset added to each environment's rank (and thus its seed
        and Monitor file name).
    :return: a ``SubprocVecEnv`` over ``num_env`` environments.
    """
    if wrapper_kwargs is None:
        wrapper_kwargs = {}

    def make_env(rank):  # pylint: disable=C0111
        def _thunk():
            env = make_atari(env_id)
            env.seed(seed + rank)
            # Monitor wraps the gym env and mainly adds recording of
            # information at the end of each episode; each rank gets its own
            # log file under the logger directory, if one is configured.
            env = Monitor(env, _monitor_path(rank))
            return wrap_deepmind(env, **wrapper_kwargs)
        return _thunk

    set_global_seeds(seed)
    # SubprocVecEnv runs each _thunk created above in its own subprocess;
    # offsetting by start_index gives every environment a distinct seed.
    return SubprocVecEnv([make_env(i + start_index) for i in range(num_env)])


def make_mujoco_env(env_id, seed):
    """
    Create a wrapped, monitored gym.Env for MuJoCo.

    The global RNG seed is offset by ``10000 * rank`` (the MPI rank of this
    process), and the Monitor log file is named after that rank.
    """
    rank = MPI.COMM_WORLD.Get_rank()
    set_global_seeds(seed + 10000 * rank)
    env = gym.make(env_id)
    # Tolerate an unconfigured logger directory like the other helpers do,
    # instead of crashing inside os.path.join(None, ...).
    env = Monitor(env, _monitor_path(rank))
    # NOTE(review): unlike the global seed above, the env itself is seeded with
    # the same un-offset seed on every MPI rank — confirm this is intended.
    env.seed(seed)
    return env


def make_robotics_env(env_id, seed, rank=0):
    """
    Create a wrapped, monitored gym.Env for goal-based robotics tasks.

    Dict observations are flattened with ``FlattenDictWrapper``, keeping only the
    'observation' and 'desired_goal' keys. The Monitor is additionally asked to
    log the 'is_success' info keyword.
    """
    set_global_seeds(seed)
    env = gym.make(env_id)
    env = FlattenDictWrapper(env, ['observation', 'desired_goal'])
    env = Monitor(env, _monitor_path(rank), info_keywords=('is_success',))
    env.seed(seed)
    return env


def arg_parser():
    """
    Create an empty argparse.ArgumentParser.

    Defaults are shown in ``--help`` output via ArgumentDefaultsHelpFormatter.
    """
    return argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)


def _add_seed_and_timesteps(parser, default_timesteps):
    """Add the ``--seed`` and ``--num-timesteps`` options shared by all parsers."""
    parser.add_argument('--seed', help='RNG seed', type=int, default=0)
    parser.add_argument('--num-timesteps', type=int, default=default_timesteps)


def atari_arg_parser():
    """
    Create an argparse.ArgumentParser for run_atari.py.
    """
    parser = arg_parser()
    parser.add_argument('--env', help='environment ID', default='BreakoutNoFrameskip-v4')
    _add_seed_and_timesteps(parser, int(10e6))
    return parser


def mujoco_arg_parser():
    """
    Create an argparse.ArgumentParser for run_mujoco.py.

    Besides the common options, it adds a ``--play`` flag and an integer
    ``--force_write`` option (default 0).
    """
    parser = arg_parser()
    parser.add_argument('--env', help='environment ID', type=str, default='Swimmer-v2')
    _add_seed_and_timesteps(parser, int(1e6))
    parser.add_argument('--play', default=False, action='store_true')
    parser.add_argument('--force_write', default=0, type=int)
    return parser


def robotics_arg_parser():
    """
    Create an argparse.ArgumentParser for the robotics (goal-based) environments.
    """
    parser = arg_parser()
    parser.add_argument('--env', help='environment ID', type=str, default='FetchReach-v0')
    _add_seed_and_timesteps(parser, int(1e6))
    return parser