From 7879c6cfe7b2f9b2085630f2eefdd71a1214bd0f Mon Sep 17 00:00:00 2001
From: 张德祥
Date: Tue, 13 Jun 2023 09:58:03 +0800
Subject: [PATCH] env v01

---
 configs.yaml    |  17 ++
 dreamer.py      |  12 +
 envs/memmaze.py | 639 ++++++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 668 insertions(+)
 create mode 100644 envs/memmaze.py

diff --git a/configs.yaml b/configs.yaml
index 332a977..6074feb 100644
--- a/configs.yaml
+++ b/configs.yaml
@@ -148,9 +148,26 @@ atari100k:
   imag_gradient: 'reinforce'
   time_limit: 108000

+mazed:
+  task: "memmaze_9_9"
+  steps: 5e4
+  action_repeat: 2
+
 debug:
   debug: True
   pretrain: 1
   prefill: 1
   batch_size: 10
   batch_length: 20
+
+mazegym:
+  task: "mazegym_9_9"  # routed to memory_maze:MemoryMaze-9x9-v0 in dreamer.py
+  steps: 5e4
+  action_repeat: 2
+
+mazedeepm:
+  task: "memmaze_9_9"
+  steps: 5e4
+  action_repeat: 2
diff --git a/dreamer.py b/dreamer.py
index 36eb633..3b98259 100644
--- a/dreamer.py
+++ b/dreamer.py
@@ -210,6 +210,18 @@ def make_env(config, logger, mode, train_eps, eval_eps):
             task, mode if "train" in mode else "test", config.action_repeat
         )
         env = wrappers.OneHotAction(env)
+    elif suite == "mazegym":
+        import gym
+        from envs.memmaze import GymWrapper
+        env = gym.make('memory_maze:MemoryMaze-9x9-v0')
+        env = GymWrapper(env)
+        env = wrappers.OneHotAction(env)
+    elif suite == "memmaze":
+        # TODO: the dm_env returned by memory_maze.tasks still needs a
+        # dreamer-compatible wrapper before this branch is usable.
+        from memory_maze import tasks
+        env = tasks.memory_maze_9x9()
+        env = wrappers.OneHotAction(env)
     else:
         raise NotImplementedError(suite)
     env = wrappers.TimeLimit(env, config.time_limit)
diff --git a/envs/memmaze.py b/envs/memmaze.py
new file mode 100644
index 0000000..9de34c4
--- /dev/null
+++ b/envs/memmaze.py
@@ -0,0 +1,639 @@
+import atexit
+import os
+import sys
+import threading
+import traceback
+
+import cloudpickle
+import gym
+import numpy as np
+
+
+class GymWrapper:
+
+  def __init__(self, env, obs_key='image', act_key='action'):
+    self._env = env
+    self._obs_is_dict = hasattr(self._env.observation_space, 'spaces')
+    self._act_is_dict = hasattr(self._env.action_space, 'spaces')
+    self._obs_key = obs_key
+    self._act_key = act_key
+
+  def __getattr__(self, name):
+    if name.startswith('__'):
+      raise AttributeError(name)
+    try:
+      return getattr(self._env, name)
+    except AttributeError:
+      raise ValueError(name)
+
+  @property
+  def obs_space(self):
+    if self._obs_is_dict:
+      spaces = self._env.observation_space.spaces.copy()
+    else:
+      spaces = {self._obs_key: self._env.observation_space}
+    return {
+        **spaces,
+        'reward': gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32),
+        'is_first': gym.spaces.Box(0, 1, (), dtype=bool),
+        'is_last': gym.spaces.Box(0, 1, (), dtype=bool),
+        'is_terminal': gym.spaces.Box(0, 1, (), dtype=bool),
+    }
+
+  @property
+  def act_space(self):
+    if self._act_is_dict:
+      return self._env.action_space.spaces.copy()
+    else:
+      return {self._act_key: self._env.action_space}
+
+  def step(self, action):
+    if not self._act_is_dict:
+      action = action[self._act_key]
+    obs, reward, done, info = self._env.step(action)
+    if not self._obs_is_dict:
+      obs = {self._obs_key: obs}
+    obs['reward'] = float(reward)
+    obs['is_first'] = False
+    obs['is_last'] = done
+    obs['is_terminal'] = info.get('is_terminal', done)
+    return obs
+
+  def reset(self):
+    obs = self._env.reset()
+    if not self._obs_is_dict:
+      obs = {self._obs_key: obs}
+    obs['reward'] = 0.0
+    obs['is_first'] = True
+    obs['is_last'] = False
+    obs['is_terminal'] = False
+    return obs
+
+
+class DMC:
+
+  def __init__(self, name, action_repeat=1, size=(64, 64), camera=None):
+    os.environ['MUJOCO_GL'] = 'egl'
+    domain, task = name.split('_', 1)
+    if domain == 'cup':  # Only domain with multiple words.
+      domain = 'ball_in_cup'
+    if domain == 'manip':
+      from dm_control import manipulation
+      self._env = manipulation.load(task + '_vision')
+    elif domain == 'locom':
+      from dm_control.locomotion.examples import basic_rodent_2020
+      self._env = getattr(basic_rodent_2020, task)()
+    else:
+      from dm_control import suite
+      self._env = suite.load(domain, task)
+    self._action_repeat = action_repeat
+    self._size = size
+    if camera in (-1, None):
+      camera = dict(
+          quadruped_walk=2, quadruped_run=2, quadruped_escape=2,
+          quadruped_fetch=2, locom_rodent_maze_forage=1,
+          locom_rodent_two_touch=1,
+      ).get(name, 0)
+    self._camera = camera
+    self._ignored_keys = []
+    for key, value in self._env.observation_spec().items():
+      if value.shape == (0,):
+        print(f"Ignoring empty observation key '{key}'.")
+        self._ignored_keys.append(key)
+
+  @property
+  def obs_space(self):
+    spaces = {
+        'image': gym.spaces.Box(0, 255, self._size + (3,), dtype=np.uint8),
+        'reward': gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32),
+        'is_first': gym.spaces.Box(0, 1, (), dtype=bool),
+        'is_last': gym.spaces.Box(0, 1, (), dtype=bool),
+        'is_terminal': gym.spaces.Box(0, 1, (), dtype=bool),
+    }
+    for key, value in self._env.observation_spec().items():
+      if key in self._ignored_keys:
+        continue
+      if value.dtype == np.float64:
+        spaces[key] = gym.spaces.Box(-np.inf, np.inf, value.shape, np.float32)
+      elif value.dtype == np.uint8:
+        spaces[key] = gym.spaces.Box(0, 255, value.shape, np.uint8)
+      else:
+        raise NotImplementedError(value.dtype)
+    return spaces
+
+  @property
+  def act_space(self):
+    spec = self._env.action_spec()
+    action = gym.spaces.Box(spec.minimum, spec.maximum, dtype=np.float32)
+    return {'action': action}
+
+  def step(self, action):
+    assert np.isfinite(action['action']).all(), action['action']
+    reward = 0.0
+    for _ in range(self._action_repeat):
+      time_step = self._env.step(action['action'])
+      reward += time_step.reward or 0.0
+      if time_step.last():
+        break
+    assert time_step.discount in (0, 1)
+    obs = {
+        'reward': reward,
+        'is_first': False,
+        'is_last': time_step.last(),
+        'is_terminal': time_step.discount == 0,
+        'image': self._env.physics.render(*self._size, camera_id=self._camera),
+    }
+    obs.update({
+        k: v for k, v in dict(time_step.observation).items()
+        if k not in self._ignored_keys})
+    return obs
+
+  def reset(self):
+    time_step = self._env.reset()
+    obs = {
+        'reward': 0.0,
+        'is_first': True,
+        'is_last': False,
+        'is_terminal': False,
+        'image': self._env.physics.render(*self._size, camera_id=self._camera),
+    }
+    obs.update({
+        k: v for k, v in dict(time_step.observation).items()
+        if k not in self._ignored_keys})
+    return obs
+
+
+class Atari:
+
+  LOCK = threading.Lock()
+
+  def __init__(
+      self, name, action_repeat=4, size=(84, 84), grayscale=True, noops=30,
+      life_done=False, sticky=True, all_actions=False):
+    assert size[0] == size[1]
+    import gym.wrappers
+    import gym.envs.atari
+    if name == 'james_bond':
+      name = 'jamesbond'
+    with self.LOCK:
+      env = gym.envs.atari.AtariEnv(
+          game=name, obs_type='image', frameskip=1,
+          repeat_action_probability=0.25 if sticky else 0.0,
+          full_action_space=all_actions)
+    # Avoid unnecessary rendering in inner env.
+    env._get_obs = lambda: None
+    # Tell wrapper that the inner env has no action repeat.
+    env.spec = gym.envs.registration.EnvSpec('NoFrameskip-v0')
+    self._env = gym.wrappers.AtariPreprocessing(
+        env, noops, action_repeat, size[0], life_done, grayscale)
+    self._size = size
+    self._grayscale = grayscale
+
+  @property
+  def obs_space(self):
+    shape = self._size + (1 if self._grayscale else 3,)
+    return {
+        'image': gym.spaces.Box(0, 255, shape, np.uint8),
+        'ram': gym.spaces.Box(0, 255, (128,), np.uint8),
+        'reward': gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32),
+        'is_first': gym.spaces.Box(0, 1, (), dtype=bool),
+        'is_last': gym.spaces.Box(0, 1, (), dtype=bool),
+        'is_terminal': gym.spaces.Box(0, 1, (), dtype=bool),
+    }
+
+  @property
+  def act_space(self):
+    return {'action': self._env.action_space}
+
+  def step(self, action):
+    image, reward, done, info = self._env.step(action['action'])
+    if self._grayscale:
+      image = image[..., None]
+    return {
+        'image': image,
+        'ram': self._env.env._get_ram(),
+        'reward': reward,
+        'is_first': False,
+        'is_last': done,
+        'is_terminal': done,
+    }
+
+  def reset(self):
+    with self.LOCK:
+      image = self._env.reset()
+    if self._grayscale:
+      image = image[..., None]
+    return {
+        'image': image,
+        'ram': self._env.env._get_ram(),
+        'reward': 0.0,
+        'is_first': True,
+        'is_last': False,
+        'is_terminal': False,
+    }
+
+  def close(self):
+    return self._env.close()
+
+
+class Crafter:
+
+  def __init__(self, outdir=None, reward=True, seed=None):
+    import crafter
+    self._env = crafter.Env(reward=reward, seed=seed)
+    self._env = crafter.Recorder(
+        self._env, outdir,
+        save_stats=True,
+        save_video=False,
+        save_episode=False,
+    )
+    self._achievements = crafter.constants.achievements.copy()
+
+  @property
+  def obs_space(self):
+    spaces = {
+        'image': self._env.observation_space,
+        'reward': gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32),
+        'is_first': gym.spaces.Box(0, 1, (), dtype=bool),
+        'is_last': gym.spaces.Box(0, 1, (), dtype=bool),
+        'is_terminal': gym.spaces.Box(0, 1, (), dtype=bool),
+        'log_reward': gym.spaces.Box(-np.inf, np.inf, (), np.float32),
+    }
+    spaces.update({
+        f'log_achievement_{k}': gym.spaces.Box(0, 2 ** 31 - 1, (), np.int32)
+        for k in self._achievements})
+    return spaces
+
+  @property
+  def act_space(self):
+    return {'action': self._env.action_space}
+
+  def step(self, action):
+    image, reward, done, info = self._env.step(action['action'])
+    obs = {
+        'image': image,
+        'reward': reward,
+        'is_first': False,
+        'is_last': done,
+        'is_terminal': info['discount'] == 0,
+        'log_reward': info['reward'],
+    }
+    obs.update({
+        f'log_achievement_{k}': v
+        for k, v in info['achievements'].items()})
+    return obs
+
+  def reset(self):
+    obs = {
+        'image': self._env.reset(),
+        'reward': 0.0,
+        'is_first': True,
+        'is_last': False,
+        'is_terminal': False,
+        'log_reward': 0.0,
+    }
+    obs.update({
+        f'log_achievement_{k}': 0
+        for k in self._achievements})
+    return obs
+
+
+class Dummy:
+
+  def __init__(self):
+    pass
+
+  @property
+  def obs_space(self):
+    return {
+        'image': gym.spaces.Box(0, 255, (64, 64, 3), dtype=np.uint8),
+        'reward': gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32),
+        'is_first': gym.spaces.Box(0, 1, (), dtype=bool),
+        'is_last': gym.spaces.Box(0, 1, (), dtype=bool),
+        'is_terminal': gym.spaces.Box(0, 1, (), dtype=bool),
+    }
+
+  @property
+  def act_space(self):
+    return {'action': gym.spaces.Box(-1, 1, (6,), dtype=np.float32)}
+
+  def step(self, action):
+    return {
+        'image': np.zeros((64, 64, 3)),
+        'reward': 0.0,
+        'is_first': False,
+        'is_last': False,
+        'is_terminal': False,
+    }
+
+  def reset(self):
+    return {
+        'image': np.zeros((64, 64, 3)),
+        'reward': 0.0,
+        'is_first': True,
+        'is_last': False,
+        'is_terminal': False,
+    }
+
+
+class TimeLimit:
+
+  def __init__(self, env, duration):
+    self._env = env
+    self._duration = duration
+    self._step = None
+
+  def __getattr__(self, name):
+    if name.startswith('__'):
+      raise AttributeError(name)
+    try:
+      return getattr(self._env, name)
+    except AttributeError:
+      raise ValueError(name)
+
+  def step(self, action):
+    assert self._step is not None, 'Must reset environment.'
+    obs = self._env.step(action)
+    self._step += 1
+    if self._duration and self._step >= self._duration:
+      obs['is_last'] = True
+      self._step = None
+    return obs
+
+  def reset(self):
+    self._step = 0
+    return self._env.reset()
+
+
+class NormalizeAction:
+
+  def __init__(self, env, key='action'):
+    self._env = env
+    self._key = key
+    space = env.act_space[key]
+    self._mask = np.isfinite(space.low) & np.isfinite(space.high)
+    self._low = np.where(self._mask, space.low, -1)
+    self._high = np.where(self._mask, space.high, 1)
+
+  def __getattr__(self, name):
+    if name.startswith('__'):
+      raise AttributeError(name)
+    try:
+      return getattr(self._env, name)
+    except AttributeError:
+      raise ValueError(name)
+
+  @property
+  def act_space(self):
+    low = np.where(self._mask, -np.ones_like(self._low), self._low)
+    high = np.where(self._mask, np.ones_like(self._low), self._high)
+    space = gym.spaces.Box(low, high, dtype=np.float32)
+    return {**self._env.act_space, self._key: space}
+
+  def step(self, action):
+    orig = (action[self._key] + 1) / 2 * (self._high - self._low) + self._low
+    orig = np.where(self._mask, orig, action[self._key])
+    return self._env.step({**action, self._key: orig})
+
+
+class OneHotAction:
+
+  def __init__(self, env, key='action'):
+    assert hasattr(env.act_space[key], 'n')
+    self._env = env
+    self._key = key
+    self._random = np.random.RandomState()
+
+  def __getattr__(self, name):
+    if name.startswith('__'):
+      raise AttributeError(name)
+    try:
+      return getattr(self._env, name)
+    except AttributeError:
+      raise ValueError(name)
+
+  @property
+  def act_space(self):
+    shape = (self._env.act_space[self._key].n,)
+    space = gym.spaces.Box(low=0, high=1, shape=shape, dtype=np.float32)
+    space.sample = self._sample_action
+    space.n = shape[0]
+    return {**self._env.act_space, self._key: space}
+
+  def step(self, action):
+    index = np.argmax(action[self._key]).astype(int)
+    reference = np.zeros_like(action[self._key])
+    reference[index] = 1
+    if not np.allclose(reference, action[self._key]):
+      raise ValueError(f'Invalid one-hot action:\n{action}')
+    return self._env.step({**action, self._key: index})
+
+  def reset(self):
+    return self._env.reset()
+
+  def _sample_action(self):
+    actions = self._env.act_space[self._key].n
+    index = self._random.randint(0, actions)
+    reference = np.zeros(actions, dtype=np.float32)
+    reference[index] = 1.0
+    return reference
+
+
+class ResizeImage:
+
+  def __init__(self, env, size=(64, 64)):
+    self._env = env
+    self._size = size
+    self._keys = [
+        k for k, v in env.obs_space.items()
+        if len(v.shape) > 1 and v.shape[:2] != size]
+    print(f'Resizing keys {",".join(self._keys)} to {self._size}.')
+    if self._keys:
+      from PIL import Image
+      self._Image = Image
+
+  def __getattr__(self, name):
+    if name.startswith('__'):
+      raise AttributeError(name)
+    try:
+      return getattr(self._env, name)
+    except AttributeError:
+      raise ValueError(name)
+
+  @property
+  def obs_space(self):
+    spaces = self._env.obs_space
+    for key in self._keys:
+      shape = self._size + spaces[key].shape[2:]
+      spaces[key] = gym.spaces.Box(0, 255, shape, np.uint8)
+    return spaces
+
+  def step(self, action):
+    obs = self._env.step(action)
+    for key in self._keys:
+      obs[key] = self._resize(obs[key])
+    return obs
+
+  def reset(self):
+    obs = self._env.reset()
+    for key in self._keys:
+      obs[key] = self._resize(obs[key])
+    return obs
+
+  def _resize(self, image):
+    image = self._Image.fromarray(image)
+    image = image.resize(self._size, self._Image.NEAREST)
+    image = np.array(image)
+    return image
+
+
+class RenderImage:
+
+  def __init__(self, env, key='image'):
+    self._env = env
+    self._key = key
+    self._shape = self._env.render().shape
+
+  def __getattr__(self, name):
+    if name.startswith('__'):
+      raise AttributeError(name)
+    try:
+      return getattr(self._env, name)
+    except AttributeError:
+      raise ValueError(name)
+
+  @property
+  def obs_space(self):
+    spaces = self._env.obs_space
+    spaces[self._key] = gym.spaces.Box(0, 255, self._shape, np.uint8)
+    return spaces
+
+  def step(self, action):
+    obs = self._env.step(action)
+    obs[self._key] = self._env.render('rgb_array')
+    return obs
+
+  def reset(self):
+    obs = self._env.reset()
+    obs[self._key] = self._env.render('rgb_array')
+    return obs
+
+
+class Async:
+
+  # Message types for communication via the pipe.
+  _ACCESS = 1
+  _CALL = 2
+  _RESULT = 3
+  _CLOSE = 4
+  _EXCEPTION = 5
+
+  def __init__(self, constructor, strategy='thread'):
+    self._pickled_ctor = cloudpickle.dumps(constructor)
+    if strategy == 'process':
+      import multiprocessing as mp
+      context = mp.get_context('spawn')
+    elif strategy == 'thread':
+      import multiprocessing.dummy as context
+    else:
+      raise NotImplementedError(strategy)
+    self._strategy = strategy
+    self._conn, conn = context.Pipe()
+    self._process = context.Process(target=self._worker, args=(conn,))
+    atexit.register(self.close)
+    self._process.start()
+    self._receive()  # Ready.
+    self._obs_space = None
+    self._act_space = None
+
+  def access(self, name):
+    self._conn.send((self._ACCESS, name))
+    return self._receive
+
+  def call(self, name, *args, **kwargs):
+    payload = name, args, kwargs
+    self._conn.send((self._CALL, payload))
+    return self._receive
+
+  def close(self):
+    try:
+      self._conn.send((self._CLOSE, None))
+      self._conn.close()
+    except IOError:
+      pass  # The connection was already closed.
+    self._process.join(5)
+
+  @property
+  def obs_space(self):
+    if not self._obs_space:
+      self._obs_space = self.access('obs_space')()
+    return self._obs_space
+
+  @property
+  def act_space(self):
+    if not self._act_space:
+      self._act_space = self.access('act_space')()
+    return self._act_space
+
+  def step(self, action, blocking=False):
+    promise = self.call('step', action)
+    if blocking:
+      return promise()
+    else:
+      return promise
+
+  def reset(self, blocking=False):
+    promise = self.call('reset')
+    if blocking:
+      return promise()
+    else:
+      return promise
+
+  def _receive(self):
+    try:
+      message, payload = self._conn.recv()
+    except (OSError, EOFError):
+      raise RuntimeError('Lost connection to environment worker.')
+    # Re-raise exceptions in the main process.
+    if message == self._EXCEPTION:
+      stacktrace = payload
+      raise Exception(stacktrace)
+    if message == self._RESULT:
+      return payload
+    raise KeyError('Received message of unexpected type {}'.format(message))
+
+  def _worker(self, conn):
+    try:
+      ctor = cloudpickle.loads(self._pickled_ctor)
+      env = ctor()
+      conn.send((self._RESULT, None))  # Ready.
+      while True:
+        try:
+          # Only block for short times to have keyboard exceptions be raised.
+          if not conn.poll(0.1):
+            continue
+          message, payload = conn.recv()
+        except (EOFError, KeyboardInterrupt):
+          break
+        if message == self._ACCESS:
+          name = payload
+          result = getattr(env, name)
+          conn.send((self._RESULT, result))
+          continue
+        if message == self._CALL:
+          name, args, kwargs = payload
+          result = getattr(env, name)(*args, **kwargs)
+          conn.send((self._RESULT, result))
+          continue
+        if message == self._CLOSE:
+          break
+        raise KeyError('Received message of unknown type {}'.format(message))
+    except Exception:
+      stacktrace = ''.join(traceback.format_exception(*sys.exc_info()))
+      print('Error in environment process: {}'.format(stacktrace))
+      conn.send((self._EXCEPTION, stacktrace))
+    finally:
+      try:
+        conn.close()
+      except IOError:
+        pass  # The connection was already closed.
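+
+
+# Illustrative smoke test: a minimal sketch, not part of the training
+# pipeline, using only the Dummy and TimeLimit classes defined above. It
+# shows the dict-based obs_space/act_space protocol that dreamer.py relies
+# on: actions and observations are dicts, with 'is_first'/'is_last' flags
+# marking episode boundaries.
+if __name__ == '__main__':
+  env = TimeLimit(Dummy(), duration=10)
+  obs = env.reset()
+  assert obs['is_first'] and not obs['is_last']
+  action = {'action': env.act_space['action'].sample()}
+  obs = env.step(action)
+  print('image shape:', obs['image'].shape, 'reward:', obs['reward'])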