From 7879c6cfe7b2f9b2085630f2eefdd71a1214bd0f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E5=BE=B7=E7=A5=A5?= Date: Tue, 13 Jun 2023 09:58:03 +0800 Subject: [PATCH 1/7] env v01 --- configs.yaml | 17 ++ dreamer.py | 12 + envs/memmaze.py | 639 ++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 668 insertions(+) create mode 100644 envs/memmaze.py diff --git a/configs.yaml b/configs.yaml index 332a977..6074feb 100644 --- a/configs.yaml +++ b/configs.yaml @@ -148,9 +148,26 @@ atari100k: imag_gradient: 'reinforce' time_limit: 108000 +mazed: + task: "memmaze_9_9" + steps: 5e4 + action_repeat: 2 + debug: debug: True pretrain: 1 prefill: 1 batch_size: 10 batch_length: 20 + +mazegym: + #task: "memory_maze:MemoryMaze-9x9-v0" + steps: 5e4 + action_repeat: 2 + +mazedeepm: + task: "memmaze_9_9" + steps: 5e4 + action_repeat: 2 + + diff --git a/dreamer.py b/dreamer.py index 36eb633..3b98259 100644 --- a/dreamer.py +++ b/dreamer.py @@ -210,6 +210,18 @@ def make_env(config, logger, mode, train_eps, eval_eps): task, mode if "train" in mode else "test", config.action_repeat ) env = wrappers.OneHotAction(env) + elif suite == "mazegym": + import gym + env = gym.make('memory_maze:MemoryMaze-9x9-v0') + from envs.memmaze import MZGymWrapper + env = MZGymWrapper(env) + + env = wrappers.OneHotAction(env) + elif suite == "---------mazed": + from memory_maze import tasks + ## !!!!!!!!!!!!!!!!!!!!!!!! + env = tasks.memory_maze_9x9() + env = wrappers.OneHotAction(env) else: raise NotImplementedError(suite) env = wrappers.TimeLimit(env, config.time_limit) diff --git a/envs/memmaze.py b/envs/memmaze.py new file mode 100644 index 0000000..9de34c4 --- /dev/null +++ b/envs/memmaze.py @@ -0,0 +1,639 @@ +import atexit +import os +import sys +import threading +import traceback + +import cloudpickle +import gym +import numpy as np + + +class GymWrapper: + + def __init__(self, env, obs_key='image', act_key='action'): + self._env = env + self._obs_is_dict = hasattr(self._env.observation_space, 'spaces') + self._act_is_dict = hasattr(self._env.action_space, 'spaces') + self._obs_key = obs_key + self._act_key = act_key + + def __getattr__(self, name): + if name.startswith('__'): + raise AttributeError(name) + try: + return getattr(self._env, name) + except AttributeError: + raise ValueError(name) + + @property + def obs_space(self): + if self._obs_is_dict: + spaces = self._env.observation_space.spaces.copy() + else: + spaces = {self._obs_key: self._env.observation_space} + return { + **spaces, + 'reward': gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32), + 'is_first': gym.spaces.Box(0, 1, (), dtype=np.bool), + 'is_last': gym.spaces.Box(0, 1, (), dtype=np.bool), + 'is_terminal': gym.spaces.Box(0, 1, (), dtype=np.bool), + } + + @property + def act_space(self): + if self._act_is_dict: + return self._env.action_space.spaces.copy() + else: + return {self._act_key: self._env.action_space} + + def step(self, action): + if not self._act_is_dict: + action = action[self._act_key] + obs, reward, done, info = self._env.step(action) + if not self._obs_is_dict: + obs = {self._obs_key: obs} + obs['reward'] = float(reward) + obs['is_first'] = False + obs['is_last'] = done + obs['is_terminal'] = info.get('is_terminal', done) + return obs + + def reset(self): + obs = self._env.reset() + if not self._obs_is_dict: + obs = {self._obs_key: obs} + obs['reward'] = 0.0 + obs['is_first'] = True + obs['is_last'] = False + obs['is_terminal'] = False + return obs + + +class DMC: + + def __init__(self, name, action_repeat=1, size=(64, 
64), camera=None): + os.environ['MUJOCO_GL'] = 'egl' + domain, task = name.split('_', 1) + if domain == 'cup': # Only domain with multiple words. + domain = 'ball_in_cup' + if domain == 'manip': + from dm_control import manipulation + self._env = manipulation.load(task + '_vision') + elif domain == 'locom': + from dm_control.locomotion.examples import basic_rodent_2020 + self._env = getattr(basic_rodent_2020, task)() + else: + from dm_control import suite + self._env = suite.load(domain, task) + self._action_repeat = action_repeat + self._size = size + if camera in (-1, None): + camera = dict( + quadruped_walk=2, quadruped_run=2, quadruped_escape=2, + quadruped_fetch=2, locom_rodent_maze_forage=1, + locom_rodent_two_touch=1, + ).get(name, 0) + self._camera = camera + self._ignored_keys = [] + for key, value in self._env.observation_spec().items(): + if value.shape == (0,): + print(f"Ignoring empty observation key '{key}'.") + self._ignored_keys.append(key) + + @property + def obs_space(self): + spaces = { + 'image': gym.spaces.Box(0, 255, self._size + (3,), dtype=np.uint8), + 'reward': gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32), + 'is_first': gym.spaces.Box(0, 1, (), dtype=np.bool), + 'is_last': gym.spaces.Box(0, 1, (), dtype=np.bool), + 'is_terminal': gym.spaces.Box(0, 1, (), dtype=np.bool), + } + for key, value in self._env.observation_spec().items(): + if key in self._ignored_keys: + continue + if value.dtype == np.float64: + spaces[key] = gym.spaces.Box(-np.inf, np.inf, value.shape, np.float32) + elif value.dtype == np.uint8: + spaces[key] = gym.spaces.Box(0, 255, value.shape, np.uint8) + else: + raise NotImplementedError(value.dtype) + return spaces + + @property + def act_space(self): + spec = self._env.action_spec() + action = gym.spaces.Box(spec.minimum, spec.maximum, dtype=np.float32) + return {'action': action} + + def step(self, action): + assert np.isfinite(action['action']).all(), action['action'] + reward = 0.0 + for _ in range(self._action_repeat): + time_step = self._env.step(action['action']) + reward += time_step.reward or 0.0 + if time_step.last(): + break + assert time_step.discount in (0, 1) + obs = { + 'reward': reward, + 'is_first': False, + 'is_last': time_step.last(), + 'is_terminal': time_step.discount == 0, + 'image': self._env.physics.render(*self._size, camera_id=self._camera), + } + obs.update({ + k: v for k, v in dict(time_step.observation).items() + if k not in self._ignored_keys}) + return obs + + def reset(self): + time_step = self._env.reset() + obs = { + 'reward': 0.0, + 'is_first': True, + 'is_last': False, + 'is_terminal': False, + 'image': self._env.physics.render(*self._size, camera_id=self._camera), + } + obs.update({ + k: v for k, v in dict(time_step.observation).items() + if k not in self._ignored_keys}) + return obs + + +class Atari: + + LOCK = threading.Lock() + + def __init__( + self, name, action_repeat=4, size=(84, 84), grayscale=True, noops=30, + life_done=False, sticky=True, all_actions=False): + assert size[0] == size[1] + import gym.wrappers + import gym.envs.atari + if name == 'james_bond': + name = 'jamesbond' + with self.LOCK: + env = gym.envs.atari.AtariEnv( + game=name, obs_type='image', frameskip=1, + repeat_action_probability=0.25 if sticky else 0.0, + full_action_space=all_actions) + # Avoid unnecessary rendering in inner env. + env._get_obs = lambda: None + # Tell wrapper that the inner env has no action repeat. 
+ env.spec = gym.envs.registration.EnvSpec('NoFrameskip-v0') + self._env = gym.wrappers.AtariPreprocessing( + env, noops, action_repeat, size[0], life_done, grayscale) + self._size = size + self._grayscale = grayscale + + @property + def obs_space(self): + shape = self._size + (1 if self._grayscale else 3,) + return { + 'image': gym.spaces.Box(0, 255, shape, np.uint8), + 'ram': gym.spaces.Box(0, 255, (128,), np.uint8), + 'reward': gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32), + 'is_first': gym.spaces.Box(0, 1, (), dtype=np.bool), + 'is_last': gym.spaces.Box(0, 1, (), dtype=np.bool), + 'is_terminal': gym.spaces.Box(0, 1, (), dtype=np.bool), + } + + @property + def act_space(self): + return {'action': self._env.action_space} + + def step(self, action): + image, reward, done, info = self._env.step(action['action']) + if self._grayscale: + image = image[..., None] + return { + 'image': image, + 'ram': self._env.env._get_ram(), + 'reward': reward, + 'is_first': False, + 'is_last': done, + 'is_terminal': done, + } + + def reset(self): + with self.LOCK: + image = self._env.reset() + if self._grayscale: + image = image[..., None] + return { + 'image': image, + 'ram': self._env.env._get_ram(), + 'reward': 0.0, + 'is_first': True, + 'is_last': False, + 'is_terminal': False, + } + + def close(self): + return self._env.close() + + +class Crafter: + + def __init__(self, outdir=None, reward=True, seed=None): + import crafter + self._env = crafter.Env(reward=reward, seed=seed) + self._env = crafter.Recorder( + self._env, outdir, + save_stats=True, + save_video=False, + save_episode=False, + ) + self._achievements = crafter.constants.achievements.copy() + + @property + def obs_space(self): + spaces = { + 'image': self._env.observation_space, + 'reward': gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32), + 'is_first': gym.spaces.Box(0, 1, (), dtype=np.bool), + 'is_last': gym.spaces.Box(0, 1, (), dtype=np.bool), + 'is_terminal': gym.spaces.Box(0, 1, (), dtype=np.bool), + 'log_reward': gym.spaces.Box(-np.inf, np.inf, (), np.float32), + } + spaces.update({ + f'log_achievement_{k}': gym.spaces.Box(0, 2 ** 31 - 1, (), np.int32) + for k in self._achievements}) + return spaces + + @property + def act_space(self): + return {'action': self._env.action_space} + + def step(self, action): + image, reward, done, info = self._env.step(action['action']) + obs = { + 'image': image, + 'reward': reward, + 'is_first': False, + 'is_last': done, + 'is_terminal': info['discount'] == 0, + 'log_reward': info['reward'], + } + obs.update({ + f'log_achievement_{k}': v + for k, v in info['achievements'].items()}) + return obs + + def reset(self): + obs = { + 'image': self._env.reset(), + 'reward': 0.0, + 'is_first': True, + 'is_last': False, + 'is_terminal': False, + 'log_reward': 0.0, + } + obs.update({ + f'log_achievement_{k}': 0 + for k in self._achievements}) + return obs + + +class Dummy: + + def __init__(self): + pass + + @property + def obs_space(self): + return { + 'image': gym.spaces.Box(0, 255, (64, 64, 3), dtype=np.uint8), + 'reward': gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32), + 'is_first': gym.spaces.Box(0, 1, (), dtype=np.bool), + 'is_last': gym.spaces.Box(0, 1, (), dtype=np.bool), + 'is_terminal': gym.spaces.Box(0, 1, (), dtype=np.bool), + } + + @property + def act_space(self): + return {'action': gym.spaces.Box(-1, 1, (6,), dtype=np.float32)} + + def step(self, action): + return { + 'image': np.zeros((64, 64, 3)), + 'reward': 0.0, + 'is_first': False, + 'is_last': False, + 'is_terminal': False, 
+ } + + def reset(self): + return { + 'image': np.zeros((64, 64, 3)), + 'reward': 0.0, + 'is_first': True, + 'is_last': False, + 'is_terminal': False, + } + + +class TimeLimit: + + def __init__(self, env, duration): + self._env = env + self._duration = duration + self._step = None + + def __getattr__(self, name): + if name.startswith('__'): + raise AttributeError(name) + try: + return getattr(self._env, name) + except AttributeError: + raise ValueError(name) + + def step(self, action): + assert self._step is not None, 'Must reset environment.' + obs = self._env.step(action) + self._step += 1 + if self._duration and self._step >= self._duration: + obs['is_last'] = True + self._step = None + return obs + + def reset(self): + self._step = 0 + return self._env.reset() + + +class NormalizeAction: + + def __init__(self, env, key='action'): + self._env = env + self._key = key + space = env.act_space[key] + self._mask = np.isfinite(space.low) & np.isfinite(space.high) + self._low = np.where(self._mask, space.low, -1) + self._high = np.where(self._mask, space.high, 1) + + def __getattr__(self, name): + if name.startswith('__'): + raise AttributeError(name) + try: + return getattr(self._env, name) + except AttributeError: + raise ValueError(name) + + @property + def act_space(self): + low = np.where(self._mask, -np.ones_like(self._low), self._low) + high = np.where(self._mask, np.ones_like(self._low), self._high) + space = gym.spaces.Box(low, high, dtype=np.float32) + return {**self._env.act_space, self._key: space} + + def step(self, action): + orig = (action[self._key] + 1) / 2 * (self._high - self._low) + self._low + orig = np.where(self._mask, orig, action[self._key]) + return self._env.step({**action, self._key: orig}) + + +class OneHotAction: + + def __init__(self, env, key='action'): + assert hasattr(env.act_space[key], 'n') + self._env = env + self._key = key + self._random = np.random.RandomState() + + def __getattr__(self, name): + if name.startswith('__'): + raise AttributeError(name) + try: + return getattr(self._env, name) + except AttributeError: + raise ValueError(name) + + @property + def act_space(self): + shape = (self._env.act_space[self._key].n,) + space = gym.spaces.Box(low=0, high=1, shape=shape, dtype=np.float32) + space.sample = self._sample_action + space.n = shape[0] + return {**self._env.act_space, self._key: space} + + def step(self, action): + index = np.argmax(action[self._key]).astype(int) + reference = np.zeros_like(action[self._key]) + reference[index] = 1 + if not np.allclose(reference, action[self._key]): + raise ValueError(f'Invalid one-hot action:\n{action}') + return self._env.step({**action, self._key: index}) + + def reset(self): + return self._env.reset() + + def _sample_action(self): + actions = self._env.act_space.n + index = self._random.randint(0, actions) + reference = np.zeros(actions, dtype=np.float32) + reference[index] = 1.0 + return reference + + +class ResizeImage: + + def __init__(self, env, size=(64, 64)): + self._env = env + self._size = size + self._keys = [ + k for k, v in env.obs_space.items() + if len(v.shape) > 1 and v.shape[:2] != size] + print(f'Resizing keys {",".join(self._keys)} to {self._size}.') + if self._keys: + from PIL import Image + self._Image = Image + + def __getattr__(self, name): + if name.startswith('__'): + raise AttributeError(name) + try: + return getattr(self._env, name) + except AttributeError: + raise ValueError(name) + + @property + def obs_space(self): + spaces = self._env.obs_space + for key in self._keys: + shape 
= self._size + spaces[key].shape[2:] + spaces[key] = gym.spaces.Box(0, 255, shape, np.uint8) + return spaces + + def step(self, action): + obs = self._env.step(action) + for key in self._keys: + obs[key] = self._resize(obs[key]) + return obs + + def reset(self): + obs = self._env.reset() + for key in self._keys: + obs[key] = self._resize(obs[key]) + return obs + + def _resize(self, image): + image = self._Image.fromarray(image) + image = image.resize(self._size, self._Image.NEAREST) + image = np.array(image) + return image + + +class RenderImage: + + def __init__(self, env, key='image'): + self._env = env + self._key = key + self._shape = self._env.render().shape + + def __getattr__(self, name): + if name.startswith('__'): + raise AttributeError(name) + try: + return getattr(self._env, name) + except AttributeError: + raise ValueError(name) + + @property + def obs_space(self): + spaces = self._env.obs_space + spaces[self._key] = gym.spaces.Box(0, 255, self._shape, np.uint8) + return spaces + + def step(self, action): + obs = self._env.step(action) + obs[self._key] = self._env.render('rgb_array') + return obs + + def reset(self): + obs = self._env.reset() + obs[self._key] = self._env.render('rgb_array') + return obs + + +class Async: + + # Message types for communication via the pipe. + _ACCESS = 1 + _CALL = 2 + _RESULT = 3 + _CLOSE = 4 + _EXCEPTION = 5 + + def __init__(self, constructor, strategy='thread'): + self._pickled_ctor = cloudpickle.dumps(constructor) + if strategy == 'process': + import multiprocessing as mp + context = mp.get_context('spawn') + elif strategy == 'thread': + import multiprocessing.dummy as context + else: + raise NotImplementedError(strategy) + self._strategy = strategy + self._conn, conn = context.Pipe() + self._process = context.Process(target=self._worker, args=(conn,)) + atexit.register(self.close) + self._process.start() + self._receive() # Ready. + self._obs_space = None + self._act_space = None + + def access(self, name): + self._conn.send((self._ACCESS, name)) + return self._receive + + def call(self, name, *args, **kwargs): + payload = name, args, kwargs + self._conn.send((self._CALL, payload)) + return self._receive + + def close(self): + try: + self._conn.send((self._CLOSE, None)) + self._conn.close() + except IOError: + pass # The connection was already closed. + self._process.join(5) + + @property + def obs_space(self): + if not self._obs_space: + self._obs_space = self.access('obs_space')() + return self._obs_space + + @property + def act_space(self): + if not self._act_space: + self._act_space = self.access('act_space')() + return self._act_space + + def step(self, action, blocking=False): + promise = self.call('step', action) + if blocking: + return promise() + else: + return promise + + def reset(self, blocking=False): + promise = self.call('reset') + if blocking: + return promise() + else: + return promise + + def _receive(self): + try: + message, payload = self._conn.recv() + except (OSError, EOFError): + raise RuntimeError('Lost connection to environment worker.') + # Re-raise exceptions in the main process. + if message == self._EXCEPTION: + stacktrace = payload + raise Exception(stacktrace) + if message == self._RESULT: + return payload + raise KeyError('Received message of unexpected type {}'.format(message)) + + def _worker(self, conn): + try: + ctor = cloudpickle.loads(self._pickled_ctor) + env = ctor() + conn.send((self._RESULT, None)) # Ready. + while True: + try: + # Only block for short times to have keyboard exceptions be raised. 
+ if not conn.poll(0.1): + continue + message, payload = conn.recv() + except (EOFError, KeyboardInterrupt): + break + if message == self._ACCESS: + name = payload + result = getattr(env, name) + conn.send((self._RESULT, result)) + continue + if message == self._CALL: + name, args, kwargs = payload + result = getattr(env, name)(*args, **kwargs) + conn.send((self._RESULT, result)) + continue + if message == self._CLOSE: + break + raise KeyError('Received message of unknown type {}'.format(message)) + except Exception: + stacktrace = ''.join(traceback.format_exception(*sys.exc_info())) + print('Error in environment process: {}'.format(stacktrace)) + conn.send((self._EXCEPTION, stacktrace)) + finally: + try: + conn.close() + except IOError: + pass # The connection was already closed. From 5038a91aad6d7bc885cc73dff8df9d7ae40a0801 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E5=BE=B7=E7=A5=A5?= Date: Tue, 13 Jun 2023 10:44:54 +0800 Subject: [PATCH 2/7] env v0.11 --- dreamer.py | 2 +- envs/{memmaze.py => memmazeEnv.py} | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) rename envs/{memmaze.py => memmazeEnv.py} (99%) diff --git a/dreamer.py b/dreamer.py index 3b98259..61ae538 100644 --- a/dreamer.py +++ b/dreamer.py @@ -213,7 +213,7 @@ def make_env(config, logger, mode, train_eps, eval_eps): elif suite == "mazegym": import gym env = gym.make('memory_maze:MemoryMaze-9x9-v0') - from envs.memmaze import MZGymWrapper + from envs.memmazeEnv import MZGymWrapper env = MZGymWrapper(env) env = wrappers.OneHotAction(env) diff --git a/envs/memmaze.py b/envs/memmazeEnv.py similarity index 99% rename from envs/memmaze.py rename to envs/memmazeEnv.py index 9de34c4..baabf90 100644 --- a/envs/memmaze.py +++ b/envs/memmazeEnv.py @@ -9,7 +9,7 @@ import gym import numpy as np -class GymWrapper: +class MZGymWrapper: def __init__(self, env, obs_key='image', act_key='action'): self._env = env From b9120a744087882507d23dda59532883d9bf032a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E5=BE=B7=E7=A5=A5?= Date: Tue, 13 Jun 2023 21:39:04 +0800 Subject: [PATCH 3/7] env v0.12 --- configs.yaml | 2 +- dreamer.py | 2 +- envs/memmazeEnv.py | 23 +++++++++++++++++++---- envs/wrappers.py | 5 +++-- 4 files changed, 24 insertions(+), 8 deletions(-) diff --git a/configs.yaml b/configs.yaml index 6074feb..7c246e1 100644 --- a/configs.yaml +++ b/configs.yaml @@ -25,7 +25,7 @@ defaults: action_repeat: 2 time_limit: 1000 grayscale: False - prefill: 2500 + prefill: 250 #0 eval_noise: 0.0 reward_EMA: True diff --git a/dreamer.py b/dreamer.py index 61ae538..04b46ab 100644 --- a/dreamer.py +++ b/dreamer.py @@ -215,7 +215,7 @@ def make_env(config, logger, mode, train_eps, eval_eps): env = gym.make('memory_maze:MemoryMaze-9x9-v0') from envs.memmazeEnv import MZGymWrapper env = MZGymWrapper(env) - + #from envs.memmazeEnv import OneHotAction as OneHotAction2 env = wrappers.OneHotAction(env) elif suite == "---------mazed": from memory_maze import tasks diff --git a/envs/memmazeEnv.py b/envs/memmazeEnv.py index baabf90..37a835f 100644 --- a/envs/memmazeEnv.py +++ b/envs/memmazeEnv.py @@ -47,17 +47,32 @@ class MZGymWrapper: else: return {self._act_key: self._env.action_space} + @property + def observation_space(self): + img_shape = self._size + ((1,) if self._gray else (3,)) + return gym.spaces.Dict( + { + "image": gym.spaces.Box(0, 255, img_shape, np.uint8), + } + ) + + @property + def action_space(self): + space = self._env.action_space + space.discrete = True + return space + def step(self, action): - if not self._act_is_dict: - 
action = action[self._act_key] + # if not self._act_is_dict: + # action = action[self._act_key] obs, reward, done, info = self._env.step(action) if not self._obs_is_dict: obs = {self._obs_key: obs} - obs['reward'] = float(reward) + # obs['reward'] = float(reward) obs['is_first'] = False obs['is_last'] = done obs['is_terminal'] = info.get('is_terminal', done) - return obs + return obs, reward, done, info def reset(self): obs = self._env.reset() diff --git a/envs/wrappers.py b/envs/wrappers.py index 1a4a58b..9769fc9 100644 --- a/envs/wrappers.py +++ b/envs/wrappers.py @@ -77,8 +77,8 @@ class CollectDataset: dtype = {16: np.int16, 32: np.int32, 64: np.int64}[self._precision] elif np.issubdtype(value.dtype, np.uint8): dtype = np.uint8 - elif np.issubdtype(value.dtype, np.bool): - dtype = np.bool + elif np.issubdtype(value.dtype, np.bool_): + dtype = np.bool_ else: raise NotImplementedError(value.dtype) return value.astype(dtype) @@ -96,6 +96,7 @@ class TimeLimit: def step(self, action): assert self._step is not None, "Must reset environment." obs, reward, done, info = self._env.step(action) + # teets = self._env.step(action) self._step += 1 if self._step >= self._duration: done = True From 1cf0149c10a63c88765bc8b397e6932f1bedfa96 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E5=BE=B7=E7=A5=A5?= Date: Wed, 14 Jun 2023 20:22:17 +0800 Subject: [PATCH 4/7] env v0.13 --- configs.yaml | 8 ++++---- envs/memmazeEnv.py | 4 +++- envs/wrappers.py | 4 ++-- tools.py | 2 +- 4 files changed, 10 insertions(+), 8 deletions(-) diff --git a/configs.yaml b/configs.yaml index 7c246e1..65cf26e 100644 --- a/configs.yaml +++ b/configs.yaml @@ -25,7 +25,7 @@ defaults: action_repeat: 2 time_limit: 1000 grayscale: False - prefill: 250 #0 + prefill: 3500 eval_noise: 0.0 reward_EMA: True @@ -150,7 +150,7 @@ atari100k: mazed: task: "memmaze_9_9" - steps: 5e4 + steps: 1e6 action_repeat: 2 debug: @@ -162,12 +162,12 @@ debug: mazegym: #task: "memory_maze:MemoryMaze-9x9-v0" - steps: 5e4 + steps: 1e6 action_repeat: 2 mazedeepm: task: "memmaze_9_9" - steps: 5e4 + steps: 1e6 action_repeat: 2 diff --git a/envs/memmazeEnv.py b/envs/memmazeEnv.py index 37a835f..81914fb 100644 --- a/envs/memmazeEnv.py +++ b/envs/memmazeEnv.py @@ -11,12 +11,14 @@ import numpy as np class MZGymWrapper: - def __init__(self, env, obs_key='image', act_key='action'): + def __init__(self, env, obs_key='image', act_key='action', size=(64, 64)): self._env = env self._obs_is_dict = hasattr(self._env.observation_space, 'spaces') self._act_is_dict = hasattr(self._env.action_space, 'spaces') self._obs_key = obs_key self._act_key = act_key + self._size = size + self._gray = False def __getattr__(self, name): if name.startswith('__'): diff --git a/envs/wrappers.py b/envs/wrappers.py index 9769fc9..eacb173 100644 --- a/envs/wrappers.py +++ b/envs/wrappers.py @@ -155,8 +155,8 @@ class OneHotAction: index = np.argmax(action).astype(int) reference = np.zeros_like(action) reference[index] = 1 - if not np.allclose(reference, action): - raise ValueError(f"Invalid one-hot action:\n{action}") + # if not np.allclose(reference, action): + # raise ValueError(f"Invalid one-hot action:\n{action}") return self._env.step(index) def reset(self): diff --git a/tools.py b/tools.py index bc46903..752b786 100644 --- a/tools.py +++ b/tools.py @@ -127,7 +127,7 @@ def simulate(agent, envs, steps=0, episodes=0, state=None): # Initialize or unpack simulation state. 
if state is None: step, episode = 0, 0 - done = np.ones(len(envs), np.bool) + done = np.ones(len(envs), np.bool_) length = np.zeros(len(envs), np.int32) obs = [None] * len(envs) agent_state = None From ea446adaf46bc4268c1ebfce5c9aeedde490aff3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E5=BE=B7=E7=A5=A5?= Date: Sat, 17 Jun 2023 23:29:53 +0800 Subject: [PATCH 5/7] mem maze env ok 1 --- configs.yaml | 15 +- dreamer.py | 15 +- envs/memmazeEnv.py | 569 +-------------------------------------------- envs/wrappers.py | 37 ++- 4 files changed, 48 insertions(+), 588 deletions(-) diff --git a/configs.yaml b/configs.yaml index 65cf26e..4db75dc 100644 --- a/configs.yaml +++ b/configs.yaml @@ -25,7 +25,7 @@ defaults: action_repeat: 2 time_limit: 1000 grayscale: False - prefill: 3500 + prefill: 2500 eval_noise: 0.0 reward_EMA: True @@ -148,10 +148,6 @@ atari100k: imag_gradient: 'reinforce' time_limit: 108000 -mazed: - task: "memmaze_9_9" - steps: 1e6 - action_repeat: 2 debug: debug: True @@ -161,13 +157,10 @@ debug: batch_length: 20 mazegym: - #task: "memory_maze:MemoryMaze-9x9-v0" - steps: 1e6 - action_repeat: 2 - -mazedeepm: - task: "memmaze_9_9" + task: 9 steps: 1e6 action_repeat: 2 + + diff --git a/dreamer.py b/dreamer.py index 04b46ab..c20e8ca 100644 --- a/dreamer.py +++ b/dreamer.py @@ -212,16 +212,17 @@ def make_env(config, logger, mode, train_eps, eval_eps): env = wrappers.OneHotAction(env) elif suite == "mazegym": import gym - env = gym.make('memory_maze:MemoryMaze-9x9-v0') + if task == 9: + env = gym.make('memory_maze:MemoryMaze-9x9-v0') + elif task == 15: + env = gym.make('memory_maze:MemoryMaze-15x15-v0') + else: + raise NotImplementedError(suite) from envs.memmazeEnv import MZGymWrapper env = MZGymWrapper(env) #from envs.memmazeEnv import OneHotAction as OneHotAction2 - env = wrappers.OneHotAction(env) - elif suite == "---------mazed": - from memory_maze import tasks - ## !!!!!!!!!!!!!!!!!!!!!!!! - env = tasks.memory_maze_9x9() - env = wrappers.OneHotAction(env) + env = wrappers.OneHotAction2(env) + else: raise NotImplementedError(suite) env = wrappers.TimeLimit(env, config.time_limit) diff --git a/envs/memmazeEnv.py b/envs/memmazeEnv.py index 81914fb..980f805 100644 --- a/envs/memmazeEnv.py +++ b/envs/memmazeEnv.py @@ -8,6 +8,7 @@ import cloudpickle import gym import numpy as np +###from tf dreamerv2 code class MZGymWrapper: @@ -86,571 +87,3 @@ class MZGymWrapper: obs['is_terminal'] = False return obs - -class DMC: - - def __init__(self, name, action_repeat=1, size=(64, 64), camera=None): - os.environ['MUJOCO_GL'] = 'egl' - domain, task = name.split('_', 1) - if domain == 'cup': # Only domain with multiple words. 
- domain = 'ball_in_cup' - if domain == 'manip': - from dm_control import manipulation - self._env = manipulation.load(task + '_vision') - elif domain == 'locom': - from dm_control.locomotion.examples import basic_rodent_2020 - self._env = getattr(basic_rodent_2020, task)() - else: - from dm_control import suite - self._env = suite.load(domain, task) - self._action_repeat = action_repeat - self._size = size - if camera in (-1, None): - camera = dict( - quadruped_walk=2, quadruped_run=2, quadruped_escape=2, - quadruped_fetch=2, locom_rodent_maze_forage=1, - locom_rodent_two_touch=1, - ).get(name, 0) - self._camera = camera - self._ignored_keys = [] - for key, value in self._env.observation_spec().items(): - if value.shape == (0,): - print(f"Ignoring empty observation key '{key}'.") - self._ignored_keys.append(key) - - @property - def obs_space(self): - spaces = { - 'image': gym.spaces.Box(0, 255, self._size + (3,), dtype=np.uint8), - 'reward': gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32), - 'is_first': gym.spaces.Box(0, 1, (), dtype=np.bool), - 'is_last': gym.spaces.Box(0, 1, (), dtype=np.bool), - 'is_terminal': gym.spaces.Box(0, 1, (), dtype=np.bool), - } - for key, value in self._env.observation_spec().items(): - if key in self._ignored_keys: - continue - if value.dtype == np.float64: - spaces[key] = gym.spaces.Box(-np.inf, np.inf, value.shape, np.float32) - elif value.dtype == np.uint8: - spaces[key] = gym.spaces.Box(0, 255, value.shape, np.uint8) - else: - raise NotImplementedError(value.dtype) - return spaces - - @property - def act_space(self): - spec = self._env.action_spec() - action = gym.spaces.Box(spec.minimum, spec.maximum, dtype=np.float32) - return {'action': action} - - def step(self, action): - assert np.isfinite(action['action']).all(), action['action'] - reward = 0.0 - for _ in range(self._action_repeat): - time_step = self._env.step(action['action']) - reward += time_step.reward or 0.0 - if time_step.last(): - break - assert time_step.discount in (0, 1) - obs = { - 'reward': reward, - 'is_first': False, - 'is_last': time_step.last(), - 'is_terminal': time_step.discount == 0, - 'image': self._env.physics.render(*self._size, camera_id=self._camera), - } - obs.update({ - k: v for k, v in dict(time_step.observation).items() - if k not in self._ignored_keys}) - return obs - - def reset(self): - time_step = self._env.reset() - obs = { - 'reward': 0.0, - 'is_first': True, - 'is_last': False, - 'is_terminal': False, - 'image': self._env.physics.render(*self._size, camera_id=self._camera), - } - obs.update({ - k: v for k, v in dict(time_step.observation).items() - if k not in self._ignored_keys}) - return obs - - -class Atari: - - LOCK = threading.Lock() - - def __init__( - self, name, action_repeat=4, size=(84, 84), grayscale=True, noops=30, - life_done=False, sticky=True, all_actions=False): - assert size[0] == size[1] - import gym.wrappers - import gym.envs.atari - if name == 'james_bond': - name = 'jamesbond' - with self.LOCK: - env = gym.envs.atari.AtariEnv( - game=name, obs_type='image', frameskip=1, - repeat_action_probability=0.25 if sticky else 0.0, - full_action_space=all_actions) - # Avoid unnecessary rendering in inner env. - env._get_obs = lambda: None - # Tell wrapper that the inner env has no action repeat. 
- env.spec = gym.envs.registration.EnvSpec('NoFrameskip-v0') - self._env = gym.wrappers.AtariPreprocessing( - env, noops, action_repeat, size[0], life_done, grayscale) - self._size = size - self._grayscale = grayscale - - @property - def obs_space(self): - shape = self._size + (1 if self._grayscale else 3,) - return { - 'image': gym.spaces.Box(0, 255, shape, np.uint8), - 'ram': gym.spaces.Box(0, 255, (128,), np.uint8), - 'reward': gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32), - 'is_first': gym.spaces.Box(0, 1, (), dtype=np.bool), - 'is_last': gym.spaces.Box(0, 1, (), dtype=np.bool), - 'is_terminal': gym.spaces.Box(0, 1, (), dtype=np.bool), - } - - @property - def act_space(self): - return {'action': self._env.action_space} - - def step(self, action): - image, reward, done, info = self._env.step(action['action']) - if self._grayscale: - image = image[..., None] - return { - 'image': image, - 'ram': self._env.env._get_ram(), - 'reward': reward, - 'is_first': False, - 'is_last': done, - 'is_terminal': done, - } - - def reset(self): - with self.LOCK: - image = self._env.reset() - if self._grayscale: - image = image[..., None] - return { - 'image': image, - 'ram': self._env.env._get_ram(), - 'reward': 0.0, - 'is_first': True, - 'is_last': False, - 'is_terminal': False, - } - - def close(self): - return self._env.close() - - -class Crafter: - - def __init__(self, outdir=None, reward=True, seed=None): - import crafter - self._env = crafter.Env(reward=reward, seed=seed) - self._env = crafter.Recorder( - self._env, outdir, - save_stats=True, - save_video=False, - save_episode=False, - ) - self._achievements = crafter.constants.achievements.copy() - - @property - def obs_space(self): - spaces = { - 'image': self._env.observation_space, - 'reward': gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32), - 'is_first': gym.spaces.Box(0, 1, (), dtype=np.bool), - 'is_last': gym.spaces.Box(0, 1, (), dtype=np.bool), - 'is_terminal': gym.spaces.Box(0, 1, (), dtype=np.bool), - 'log_reward': gym.spaces.Box(-np.inf, np.inf, (), np.float32), - } - spaces.update({ - f'log_achievement_{k}': gym.spaces.Box(0, 2 ** 31 - 1, (), np.int32) - for k in self._achievements}) - return spaces - - @property - def act_space(self): - return {'action': self._env.action_space} - - def step(self, action): - image, reward, done, info = self._env.step(action['action']) - obs = { - 'image': image, - 'reward': reward, - 'is_first': False, - 'is_last': done, - 'is_terminal': info['discount'] == 0, - 'log_reward': info['reward'], - } - obs.update({ - f'log_achievement_{k}': v - for k, v in info['achievements'].items()}) - return obs - - def reset(self): - obs = { - 'image': self._env.reset(), - 'reward': 0.0, - 'is_first': True, - 'is_last': False, - 'is_terminal': False, - 'log_reward': 0.0, - } - obs.update({ - f'log_achievement_{k}': 0 - for k in self._achievements}) - return obs - - -class Dummy: - - def __init__(self): - pass - - @property - def obs_space(self): - return { - 'image': gym.spaces.Box(0, 255, (64, 64, 3), dtype=np.uint8), - 'reward': gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32), - 'is_first': gym.spaces.Box(0, 1, (), dtype=np.bool), - 'is_last': gym.spaces.Box(0, 1, (), dtype=np.bool), - 'is_terminal': gym.spaces.Box(0, 1, (), dtype=np.bool), - } - - @property - def act_space(self): - return {'action': gym.spaces.Box(-1, 1, (6,), dtype=np.float32)} - - def step(self, action): - return { - 'image': np.zeros((64, 64, 3)), - 'reward': 0.0, - 'is_first': False, - 'is_last': False, - 'is_terminal': False, 
- } - - def reset(self): - return { - 'image': np.zeros((64, 64, 3)), - 'reward': 0.0, - 'is_first': True, - 'is_last': False, - 'is_terminal': False, - } - - -class TimeLimit: - - def __init__(self, env, duration): - self._env = env - self._duration = duration - self._step = None - - def __getattr__(self, name): - if name.startswith('__'): - raise AttributeError(name) - try: - return getattr(self._env, name) - except AttributeError: - raise ValueError(name) - - def step(self, action): - assert self._step is not None, 'Must reset environment.' - obs = self._env.step(action) - self._step += 1 - if self._duration and self._step >= self._duration: - obs['is_last'] = True - self._step = None - return obs - - def reset(self): - self._step = 0 - return self._env.reset() - - -class NormalizeAction: - - def __init__(self, env, key='action'): - self._env = env - self._key = key - space = env.act_space[key] - self._mask = np.isfinite(space.low) & np.isfinite(space.high) - self._low = np.where(self._mask, space.low, -1) - self._high = np.where(self._mask, space.high, 1) - - def __getattr__(self, name): - if name.startswith('__'): - raise AttributeError(name) - try: - return getattr(self._env, name) - except AttributeError: - raise ValueError(name) - - @property - def act_space(self): - low = np.where(self._mask, -np.ones_like(self._low), self._low) - high = np.where(self._mask, np.ones_like(self._low), self._high) - space = gym.spaces.Box(low, high, dtype=np.float32) - return {**self._env.act_space, self._key: space} - - def step(self, action): - orig = (action[self._key] + 1) / 2 * (self._high - self._low) + self._low - orig = np.where(self._mask, orig, action[self._key]) - return self._env.step({**action, self._key: orig}) - - -class OneHotAction: - - def __init__(self, env, key='action'): - assert hasattr(env.act_space[key], 'n') - self._env = env - self._key = key - self._random = np.random.RandomState() - - def __getattr__(self, name): - if name.startswith('__'): - raise AttributeError(name) - try: - return getattr(self._env, name) - except AttributeError: - raise ValueError(name) - - @property - def act_space(self): - shape = (self._env.act_space[self._key].n,) - space = gym.spaces.Box(low=0, high=1, shape=shape, dtype=np.float32) - space.sample = self._sample_action - space.n = shape[0] - return {**self._env.act_space, self._key: space} - - def step(self, action): - index = np.argmax(action[self._key]).astype(int) - reference = np.zeros_like(action[self._key]) - reference[index] = 1 - if not np.allclose(reference, action[self._key]): - raise ValueError(f'Invalid one-hot action:\n{action}') - return self._env.step({**action, self._key: index}) - - def reset(self): - return self._env.reset() - - def _sample_action(self): - actions = self._env.act_space.n - index = self._random.randint(0, actions) - reference = np.zeros(actions, dtype=np.float32) - reference[index] = 1.0 - return reference - - -class ResizeImage: - - def __init__(self, env, size=(64, 64)): - self._env = env - self._size = size - self._keys = [ - k for k, v in env.obs_space.items() - if len(v.shape) > 1 and v.shape[:2] != size] - print(f'Resizing keys {",".join(self._keys)} to {self._size}.') - if self._keys: - from PIL import Image - self._Image = Image - - def __getattr__(self, name): - if name.startswith('__'): - raise AttributeError(name) - try: - return getattr(self._env, name) - except AttributeError: - raise ValueError(name) - - @property - def obs_space(self): - spaces = self._env.obs_space - for key in self._keys: - shape 
= self._size + spaces[key].shape[2:] - spaces[key] = gym.spaces.Box(0, 255, shape, np.uint8) - return spaces - - def step(self, action): - obs = self._env.step(action) - for key in self._keys: - obs[key] = self._resize(obs[key]) - return obs - - def reset(self): - obs = self._env.reset() - for key in self._keys: - obs[key] = self._resize(obs[key]) - return obs - - def _resize(self, image): - image = self._Image.fromarray(image) - image = image.resize(self._size, self._Image.NEAREST) - image = np.array(image) - return image - - -class RenderImage: - - def __init__(self, env, key='image'): - self._env = env - self._key = key - self._shape = self._env.render().shape - - def __getattr__(self, name): - if name.startswith('__'): - raise AttributeError(name) - try: - return getattr(self._env, name) - except AttributeError: - raise ValueError(name) - - @property - def obs_space(self): - spaces = self._env.obs_space - spaces[self._key] = gym.spaces.Box(0, 255, self._shape, np.uint8) - return spaces - - def step(self, action): - obs = self._env.step(action) - obs[self._key] = self._env.render('rgb_array') - return obs - - def reset(self): - obs = self._env.reset() - obs[self._key] = self._env.render('rgb_array') - return obs - - -class Async: - - # Message types for communication via the pipe. - _ACCESS = 1 - _CALL = 2 - _RESULT = 3 - _CLOSE = 4 - _EXCEPTION = 5 - - def __init__(self, constructor, strategy='thread'): - self._pickled_ctor = cloudpickle.dumps(constructor) - if strategy == 'process': - import multiprocessing as mp - context = mp.get_context('spawn') - elif strategy == 'thread': - import multiprocessing.dummy as context - else: - raise NotImplementedError(strategy) - self._strategy = strategy - self._conn, conn = context.Pipe() - self._process = context.Process(target=self._worker, args=(conn,)) - atexit.register(self.close) - self._process.start() - self._receive() # Ready. - self._obs_space = None - self._act_space = None - - def access(self, name): - self._conn.send((self._ACCESS, name)) - return self._receive - - def call(self, name, *args, **kwargs): - payload = name, args, kwargs - self._conn.send((self._CALL, payload)) - return self._receive - - def close(self): - try: - self._conn.send((self._CLOSE, None)) - self._conn.close() - except IOError: - pass # The connection was already closed. - self._process.join(5) - - @property - def obs_space(self): - if not self._obs_space: - self._obs_space = self.access('obs_space')() - return self._obs_space - - @property - def act_space(self): - if not self._act_space: - self._act_space = self.access('act_space')() - return self._act_space - - def step(self, action, blocking=False): - promise = self.call('step', action) - if blocking: - return promise() - else: - return promise - - def reset(self, blocking=False): - promise = self.call('reset') - if blocking: - return promise() - else: - return promise - - def _receive(self): - try: - message, payload = self._conn.recv() - except (OSError, EOFError): - raise RuntimeError('Lost connection to environment worker.') - # Re-raise exceptions in the main process. - if message == self._EXCEPTION: - stacktrace = payload - raise Exception(stacktrace) - if message == self._RESULT: - return payload - raise KeyError('Received message of unexpected type {}'.format(message)) - - def _worker(self, conn): - try: - ctor = cloudpickle.loads(self._pickled_ctor) - env = ctor() - conn.send((self._RESULT, None)) # Ready. - while True: - try: - # Only block for short times to have keyboard exceptions be raised. 
- if not conn.poll(0.1): - continue - message, payload = conn.recv() - except (EOFError, KeyboardInterrupt): - break - if message == self._ACCESS: - name = payload - result = getattr(env, name) - conn.send((self._RESULT, result)) - continue - if message == self._CALL: - name, args, kwargs = payload - result = getattr(env, name)(*args, **kwargs) - conn.send((self._RESULT, result)) - continue - if message == self._CLOSE: - break - raise KeyError('Received message of unknown type {}'.format(message)) - except Exception: - stacktrace = ''.join(traceback.format_exception(*sys.exc_info())) - print('Error in environment process: {}'.format(stacktrace)) - conn.send((self._EXCEPTION, stacktrace)) - finally: - try: - conn.close() - except IOError: - pass # The connection was already closed. diff --git a/envs/wrappers.py b/envs/wrappers.py index eacb173..03ff649 100644 --- a/envs/wrappers.py +++ b/envs/wrappers.py @@ -96,7 +96,6 @@ class TimeLimit: def step(self, action): assert self._step is not None, "Must reset environment." obs, reward, done, info = self._env.step(action) - # teets = self._env.step(action) self._step += 1 if self._step >= self._duration: done = True @@ -151,6 +150,41 @@ class OneHotAction: space.discrete = True return space + def step(self, action): + index = np.argmax(action).astype(int) + reference = np.zeros_like(action) + reference[index] = 1 + if not np.allclose(reference, action): + raise ValueError(f"Invalid one-hot action:\n{action}") + return self._env.step(index) + + def reset(self): + return self._env.reset() + + def _sample_action(self): + actions = self._env.action_space.n + index = self._random.randint(0, actions) + reference = np.zeros(actions, dtype=np.float32) + reference[index] = 1.0 + return reference + +class OneHotAction2: + def __init__(self, env): + assert isinstance(env.action_space, gym.spaces.Discrete) + self._env = env + self._random = np.random.RandomState() + + def __getattr__(self, name): + return getattr(self._env, name) + + @property + def action_space(self): + shape = (self._env.action_space.n,) + space = gym.spaces.Box(low=0, high=1, shape=shape, dtype=np.float32) + space.sample = self._sample_action + space.discrete = True + return space + def step(self, action): index = np.argmax(action).astype(int) reference = np.zeros_like(action) @@ -169,7 +203,6 @@ class OneHotAction: reference[index] = 1.0 return reference - class RewardObs: def __init__(self, env): self._env = env From 152415f32e640ce8ab99d7ea4f8dba03a7fcaadf Mon Sep 17 00:00:00 2001 From: zdx <179363811@qq.com> Date: Sat, 17 Jun 2023 23:59:05 +0800 Subject: [PATCH 6/7] mem maze env ok 1.1 --- configs.yaml | 2 +- dreamer.py | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/configs.yaml b/configs.yaml index 4db75dc..4750ee3 100644 --- a/configs.yaml +++ b/configs.yaml @@ -157,7 +157,7 @@ debug: batch_length: 20 mazegym: - task: 9 + task: '9' steps: 1e6 action_repeat: 2 diff --git a/dreamer.py b/dreamer.py index c20e8ca..3e90050 100644 --- a/dreamer.py +++ b/dreamer.py @@ -212,17 +212,15 @@ def make_env(config, logger, mode, train_eps, eval_eps): env = wrappers.OneHotAction(env) elif suite == "mazegym": import gym - if task == 9: + if task == '9': env = gym.make('memory_maze:MemoryMaze-9x9-v0') - elif task == 15: + elif task == '15': env = gym.make('memory_maze:MemoryMaze-15x15-v0') else: raise NotImplementedError(suite) from envs.memmazeEnv import MZGymWrapper env = MZGymWrapper(env) - #from envs.memmazeEnv import OneHotAction as OneHotAction2 env = 
wrappers.OneHotAction2(env) - else: raise NotImplementedError(suite) env = wrappers.TimeLimit(env, config.time_limit) From 8e005afde55da3c2cdec456218c055501392d067 Mon Sep 17 00:00:00 2001 From: zdx <179363811@qq.com> Date: Sun, 18 Jun 2023 09:16:32 +0800 Subject: [PATCH 7/7] mem maze env ok 1.2 --- configs.yaml | 6 +++-- dreamer.py | 12 ++++----- envs/{memmazeEnv.py => memorymaze.py} | 6 ++--- envs/wrappers.py | 38 ++------------------------- tools.py | 2 +- 5 files changed, 15 insertions(+), 49 deletions(-) rename envs/{memmazeEnv.py => memorymaze.py} (95%) diff --git a/configs.yaml b/configs.yaml index 4750ee3..18e96a1 100644 --- a/configs.yaml +++ b/configs.yaml @@ -156,8 +156,10 @@ debug: batch_size: 10 batch_length: 20 -mazegym: - task: '9' +MemoryMaze: + actor_dist: 'onehot' + imag_gradient: 'reinforce' + task: '9x9' steps: 1e6 action_repeat: 2 diff --git a/dreamer.py b/dreamer.py index 3e90050..24750d1 100644 --- a/dreamer.py +++ b/dreamer.py @@ -210,17 +210,17 @@ def make_env(config, logger, mode, train_eps, eval_eps): task, mode if "train" in mode else "test", config.action_repeat ) env = wrappers.OneHotAction(env) - elif suite == "mazegym": + elif suite == "MemoryMaze": import gym - if task == '9': + if task == '9x9': env = gym.make('memory_maze:MemoryMaze-9x9-v0') - elif task == '15': + elif task == '15x15': env = gym.make('memory_maze:MemoryMaze-15x15-v0') else: raise NotImplementedError(suite) - from envs.memmazeEnv import MZGymWrapper - env = MZGymWrapper(env) - env = wrappers.OneHotAction2(env) + from envs.memorymaze import MemoryMaze + env = MemoryMaze(env) + env = wrappers.OneHotAction(env) else: raise NotImplementedError(suite) env = wrappers.TimeLimit(env, config.time_limit) diff --git a/envs/memmazeEnv.py b/envs/memorymaze.py similarity index 95% rename from envs/memmazeEnv.py rename to envs/memorymaze.py index 980f805..a194368 100644 --- a/envs/memmazeEnv.py +++ b/envs/memorymaze.py @@ -1,8 +1,6 @@ import atexit import os import sys -import threading -import traceback import cloudpickle import gym @@ -10,7 +8,7 @@ import numpy as np ###from tf dreamerv2 code -class MZGymWrapper: +class MemoryMaze: def __init__(self, env, obs_key='image', act_key='action', size=(64, 64)): self._env = env @@ -74,7 +72,7 @@ class MZGymWrapper: # obs['reward'] = float(reward) obs['is_first'] = False obs['is_last'] = done - obs['is_terminal'] = info.get('is_terminal', done) + obs['is_terminal'] = info.get('is_terminal', False) return obs, reward, done, info def reset(self): diff --git a/envs/wrappers.py b/envs/wrappers.py index 03ff649..1a4a58b 100644 --- a/envs/wrappers.py +++ b/envs/wrappers.py @@ -77,8 +77,8 @@ class CollectDataset: dtype = {16: np.int16, 32: np.int32, 64: np.int64}[self._precision] elif np.issubdtype(value.dtype, np.uint8): dtype = np.uint8 - elif np.issubdtype(value.dtype, np.bool_): - dtype = np.bool_ + elif np.issubdtype(value.dtype, np.bool): + dtype = np.bool else: raise NotImplementedError(value.dtype) return value.astype(dtype) @@ -168,40 +168,6 @@ class OneHotAction: reference[index] = 1.0 return reference -class OneHotAction2: - def __init__(self, env): - assert isinstance(env.action_space, gym.spaces.Discrete) - self._env = env - self._random = np.random.RandomState() - - def __getattr__(self, name): - return getattr(self._env, name) - - @property - def action_space(self): - shape = (self._env.action_space.n,) - space = gym.spaces.Box(low=0, high=1, shape=shape, dtype=np.float32) - space.sample = self._sample_action - space.discrete = True - return space - - 
def step(self, action): - index = np.argmax(action).astype(int) - reference = np.zeros_like(action) - reference[index] = 1 - # if not np.allclose(reference, action): - # raise ValueError(f"Invalid one-hot action:\n{action}") - return self._env.step(index) - - def reset(self): - return self._env.reset() - - def _sample_action(self): - actions = self._env.action_space.n - index = self._random.randint(0, actions) - reference = np.zeros(actions, dtype=np.float32) - reference[index] = 1.0 - return reference class RewardObs: def __init__(self, env): diff --git a/tools.py b/tools.py index 752b786..bc46903 100644 --- a/tools.py +++ b/tools.py @@ -127,7 +127,7 @@ def simulate(agent, envs, steps=0, episodes=0, state=None): # Initialize or unpack simulation state. if state is None: step, episode = 0, 0 - done = np.ones(len(envs), np.bool_) + done = np.ones(len(envs), np.bool) length = np.zeros(len(envs), np.int32) obs = [None] * len(envs) agent_state = None
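
A minimal usage sketch, not part of the patches above: it wires up the same environment stack that the MemoryMaze branch of make_env in dreamer.py builds after patch 7. Assumptions: the memory_maze package is installed, the repo's envs/memorymaze.py and envs/wrappers.py are importable, and the environment follows the old (pre-0.26) gym step/reset API that the surrounding code already relies on.

import gym

from envs.memorymaze import MemoryMaze
from envs import wrappers

# Same wrapper order as make_env for suite == "MemoryMaze" (task '9x9').
env = gym.make('memory_maze:MemoryMaze-9x9-v0')   # image observations, discrete actions
env = MemoryMaze(env)                              # dict obs plus is_first / is_last / is_terminal flags
env = wrappers.OneHotAction(env)                   # float32 one-hot action space with .discrete = True
env = wrappers.TimeLimit(env, 1000)                # mirrors the default time_limit in configs.yaml

obs = env.reset()                                  # {'image': ..., 'reward': 0.0, 'is_first': True, ...}
action = env.action_space.sample()                 # random one-hot vector
obs, reward, done, info = env.step(action)         # MemoryMaze.step returns (obs, reward, done, info)
print(obs['image'].shape, float(reward), done, obs['is_terminal'])

This is the same stack that dreamer.py hands to the training loop, so a quick script like this is a cheap way to confirm the memory_maze dependency and the wrapper chain before launching a full run.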