diff --git a/configs.yaml b/configs.yaml
index 65cf26e..4db75dc 100644
--- a/configs.yaml
+++ b/configs.yaml
@@ -25,7 +25,7 @@ defaults:
   action_repeat: 2
   time_limit: 1000
   grayscale: False
-  prefill: 3500
+  prefill: 2500
   eval_noise: 0.0
   reward_EMA: True
 
@@ -148,10 +148,6 @@ atari100k:
   imag_gradient: 'reinforce'
   time_limit: 108000
 
-mazed:
-  task: "memmaze_9_9"
-  steps: 1e6
-  action_repeat: 2
 
 debug:
   debug: True
@@ -161,13 +157,11 @@ debug:
   batch_length: 20
 
 mazegym:
-  #task: "memory_maze:MemoryMaze-9x9-v0"
-  steps: 1e6
-  action_repeat: 2
-
-mazedeepm:
-  task: "memmaze_9_9"
+  # The task suffix selects the maze size; see make_env in dreamer.py.
+  task: "mazegym_9"
   steps: 1e6
   action_repeat: 2
 
 
+
+
diff --git a/dreamer.py b/dreamer.py
index 04b46ab..c20e8ca 100644
--- a/dreamer.py
+++ b/dreamer.py
@@ -212,16 +212,18 @@ def make_env(config, logger, mode, train_eps, eval_eps):
         env = wrappers.OneHotAction(env)
     elif suite == "mazegym":
         import gym
-        env = gym.make('memory_maze:MemoryMaze-9x9-v0')
+        # The task suffix selects the maze size (assumes config.task has the
+        # form "mazegym_<size>", so task is the string "9" or "15").
+        if task == "9":
+            env = gym.make('memory_maze:MemoryMaze-9x9-v0')
+        elif task == "15":
+            env = gym.make('memory_maze:MemoryMaze-15x15-v0')
+        else:
+            raise NotImplementedError(task)
         from envs.memmazeEnv import MZGymWrapper
         env = MZGymWrapper(env)
         #from envs.memmazeEnv import OneHotAction as OneHotAction2
-        env = wrappers.OneHotAction(env)
-    elif suite == "---------mazed":
-        from memory_maze import tasks
-        ## !!!!!!!!!!!!!!!!!!!!!!!!
-        env = tasks.memory_maze_9x9()
-        env = wrappers.OneHotAction(env)
+        env = wrappers.OneHotAction2(env)
     else:
         raise NotImplementedError(suite)
     env = wrappers.TimeLimit(env, config.time_limit)
diff --git a/envs/memmazeEnv.py b/envs/memmazeEnv.py
index 81914fb..980f805 100644
--- a/envs/memmazeEnv.py
+++ b/envs/memmazeEnv.py
@@ -8,6 +8,7 @@
 import cloudpickle
 import gym
 import numpy as np
 
+# The wrappers below are adapted from the TF DreamerV2 envs code.
 
 class MZGymWrapper:
@@ -86,571 +87,3 @@ class MZGymWrapper:
         obs['is_terminal'] = False
         return obs
 
-
-class DMC:
-
-  def __init__(self, name, action_repeat=1, size=(64, 64), camera=None):
-    os.environ['MUJOCO_GL'] = 'egl'
-    domain, task = name.split('_', 1)
-    if domain == 'cup':  # Only domain with multiple words.
-      domain = 'ball_in_cup'
-    if domain == 'manip':
-      from dm_control import manipulation
-      self._env = manipulation.load(task + '_vision')
-    elif domain == 'locom':
-      from dm_control.locomotion.examples import basic_rodent_2020
-      self._env = getattr(basic_rodent_2020, task)()
-    else:
-      from dm_control import suite
-      self._env = suite.load(domain, task)
-    self._action_repeat = action_repeat
-    self._size = size
-    if camera in (-1, None):
-      camera = dict(
-          quadruped_walk=2, quadruped_run=2, quadruped_escape=2,
-          quadruped_fetch=2, locom_rodent_maze_forage=1,
-          locom_rodent_two_touch=1,
-      ).get(name, 0)
-    self._camera = camera
-    self._ignored_keys = []
-    for key, value in self._env.observation_spec().items():
-      if value.shape == (0,):
-        print(f"Ignoring empty observation key '{key}'.")
-        self._ignored_keys.append(key)
-
-  @property
-  def obs_space(self):
-    spaces = {
-        'image': gym.spaces.Box(0, 255, self._size + (3,), dtype=np.uint8),
-        'reward': gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32),
-        'is_first': gym.spaces.Box(0, 1, (), dtype=np.bool),
-        'is_last': gym.spaces.Box(0, 1, (), dtype=np.bool),
-        'is_terminal': gym.spaces.Box(0, 1, (), dtype=np.bool),
-    }
-    for key, value in self._env.observation_spec().items():
-      if key in self._ignored_keys:
-        continue
-      if value.dtype == np.float64:
-        spaces[key] = gym.spaces.Box(-np.inf, np.inf, value.shape, np.float32)
-      elif value.dtype == np.uint8:
-        spaces[key] = gym.spaces.Box(0, 255, value.shape, np.uint8)
-      else:
-        raise NotImplementedError(value.dtype)
-    return spaces
-
-  @property
-  def act_space(self):
-    spec = self._env.action_spec()
-    action = gym.spaces.Box(spec.minimum, spec.maximum, dtype=np.float32)
-    return {'action': action}
-
-  def step(self, action):
-    assert np.isfinite(action['action']).all(), action['action']
-    reward = 0.0
-    for _ in range(self._action_repeat):
-      time_step = self._env.step(action['action'])
-      reward += time_step.reward or 0.0
-      if time_step.last():
-        break
-    assert time_step.discount in (0, 1)
-    obs = {
-        'reward': reward,
-        'is_first': False,
-        'is_last': time_step.last(),
-        'is_terminal': time_step.discount == 0,
-        'image': self._env.physics.render(*self._size, camera_id=self._camera),
-    }
-    obs.update({
-        k: v for k, v in dict(time_step.observation).items()
-        if k not in self._ignored_keys})
-    return obs
-
-  def reset(self):
-    time_step = self._env.reset()
-    obs = {
-        'reward': 0.0,
-        'is_first': True,
-        'is_last': False,
-        'is_terminal': False,
-        'image': self._env.physics.render(*self._size, camera_id=self._camera),
-    }
-    obs.update({
-        k: v for k, v in dict(time_step.observation).items()
-        if k not in self._ignored_keys})
-    return obs
-
-
-class Atari:
-
-  LOCK = threading.Lock()
-
-  def __init__(
-      self, name, action_repeat=4, size=(84, 84), grayscale=True, noops=30,
-      life_done=False, sticky=True, all_actions=False):
-    assert size[0] == size[1]
-    import gym.wrappers
-    import gym.envs.atari
-    if name == 'james_bond':
-      name = 'jamesbond'
-    with self.LOCK:
-      env = gym.envs.atari.AtariEnv(
-          game=name, obs_type='image', frameskip=1,
-          repeat_action_probability=0.25 if sticky else 0.0,
-          full_action_space=all_actions)
-    # Avoid unnecessary rendering in inner env.
-    env._get_obs = lambda: None
-    # Tell wrapper that the inner env has no action repeat.
-    env.spec = gym.envs.registration.EnvSpec('NoFrameskip-v0')
-    self._env = gym.wrappers.AtariPreprocessing(
-        env, noops, action_repeat, size[0], life_done, grayscale)
-    self._size = size
-    self._grayscale = grayscale
-
-  @property
-  def obs_space(self):
-    shape = self._size + (1 if self._grayscale else 3,)
-    return {
-        'image': gym.spaces.Box(0, 255, shape, np.uint8),
-        'ram': gym.spaces.Box(0, 255, (128,), np.uint8),
-        'reward': gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32),
-        'is_first': gym.spaces.Box(0, 1, (), dtype=np.bool),
-        'is_last': gym.spaces.Box(0, 1, (), dtype=np.bool),
-        'is_terminal': gym.spaces.Box(0, 1, (), dtype=np.bool),
-    }
-
-  @property
-  def act_space(self):
-    return {'action': self._env.action_space}
-
-  def step(self, action):
-    image, reward, done, info = self._env.step(action['action'])
-    if self._grayscale:
-      image = image[..., None]
-    return {
-        'image': image,
-        'ram': self._env.env._get_ram(),
-        'reward': reward,
-        'is_first': False,
-        'is_last': done,
-        'is_terminal': done,
-    }
-
-  def reset(self):
-    with self.LOCK:
-      image = self._env.reset()
-    if self._grayscale:
-      image = image[..., None]
-    return {
-        'image': image,
-        'ram': self._env.env._get_ram(),
-        'reward': 0.0,
-        'is_first': True,
-        'is_last': False,
-        'is_terminal': False,
-    }
-
-  def close(self):
-    return self._env.close()
-
-
-class Crafter:
-
-  def __init__(self, outdir=None, reward=True, seed=None):
-    import crafter
-    self._env = crafter.Env(reward=reward, seed=seed)
-    self._env = crafter.Recorder(
-        self._env, outdir,
-        save_stats=True,
-        save_video=False,
-        save_episode=False,
-    )
-    self._achievements = crafter.constants.achievements.copy()
-
-  @property
-  def obs_space(self):
-    spaces = {
-        'image': self._env.observation_space,
-        'reward': gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32),
-        'is_first': gym.spaces.Box(0, 1, (), dtype=np.bool),
-        'is_last': gym.spaces.Box(0, 1, (), dtype=np.bool),
-        'is_terminal': gym.spaces.Box(0, 1, (), dtype=np.bool),
-        'log_reward': gym.spaces.Box(-np.inf, np.inf, (), np.float32),
-    }
-    spaces.update({
-        f'log_achievement_{k}': gym.spaces.Box(0, 2 ** 31 - 1, (), np.int32)
-        for k in self._achievements})
-    return spaces
-
-  @property
-  def act_space(self):
-    return {'action': self._env.action_space}
-
-  def step(self, action):
-    image, reward, done, info = self._env.step(action['action'])
-    obs = {
-        'image': image,
-        'reward': reward,
-        'is_first': False,
-        'is_last': done,
-        'is_terminal': info['discount'] == 0,
-        'log_reward': info['reward'],
-    }
-    obs.update({
-        f'log_achievement_{k}': v
-        for k, v in info['achievements'].items()})
-    return obs
-
-  def reset(self):
-    obs = {
-        'image': self._env.reset(),
-        'reward': 0.0,
-        'is_first': True,
-        'is_last': False,
-        'is_terminal': False,
-        'log_reward': 0.0,
-    }
-    obs.update({
-        f'log_achievement_{k}': 0
-        for k in self._achievements})
-    return obs
-
-
-class Dummy:
-
-  def __init__(self):
-    pass
-
-  @property
-  def obs_space(self):
-    return {
-        'image': gym.spaces.Box(0, 255, (64, 64, 3), dtype=np.uint8),
-        'reward': gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32),
-        'is_first': gym.spaces.Box(0, 1, (), dtype=np.bool),
-        'is_last': gym.spaces.Box(0, 1, (), dtype=np.bool),
-        'is_terminal': gym.spaces.Box(0, 1, (), dtype=np.bool),
-    }
-
-  @property
-  def act_space(self):
-    return {'action': gym.spaces.Box(-1, 1, (6,), dtype=np.float32)}
-
-  def step(self, action):
-    return {
-        'image': np.zeros((64, 64, 3)),
-        'reward': 0.0,
-        'is_first': False,
-        'is_last': False,
-        'is_terminal': False,
-    }
-
-  def reset(self):
-    return {
-        'image': np.zeros((64, 64, 3)),
-        'reward': 0.0,
-        'is_first': True,
-        'is_last': False,
-        'is_terminal': False,
-    }
-
-
-class TimeLimit:
-
-  def __init__(self, env, duration):
-    self._env = env
-    self._duration = duration
-    self._step = None
-
-  def __getattr__(self, name):
-    if name.startswith('__'):
-      raise AttributeError(name)
-    try:
-      return getattr(self._env, name)
-    except AttributeError:
-      raise ValueError(name)
-
-  def step(self, action):
-    assert self._step is not None, 'Must reset environment.'
-    obs = self._env.step(action)
-    self._step += 1
-    if self._duration and self._step >= self._duration:
-      obs['is_last'] = True
-      self._step = None
-    return obs
-
-  def reset(self):
-    self._step = 0
-    return self._env.reset()
-
-
-class NormalizeAction:
-
-  def __init__(self, env, key='action'):
-    self._env = env
-    self._key = key
-    space = env.act_space[key]
-    self._mask = np.isfinite(space.low) & np.isfinite(space.high)
-    self._low = np.where(self._mask, space.low, -1)
-    self._high = np.where(self._mask, space.high, 1)
-
-  def __getattr__(self, name):
-    if name.startswith('__'):
-      raise AttributeError(name)
-    try:
-      return getattr(self._env, name)
-    except AttributeError:
-      raise ValueError(name)
-
-  @property
-  def act_space(self):
-    low = np.where(self._mask, -np.ones_like(self._low), self._low)
-    high = np.where(self._mask, np.ones_like(self._low), self._high)
-    space = gym.spaces.Box(low, high, dtype=np.float32)
-    return {**self._env.act_space, self._key: space}
-
-  def step(self, action):
-    orig = (action[self._key] + 1) / 2 * (self._high - self._low) + self._low
-    orig = np.where(self._mask, orig, action[self._key])
-    return self._env.step({**action, self._key: orig})
-
-
-class OneHotAction:
-
-  def __init__(self, env, key='action'):
-    assert hasattr(env.act_space[key], 'n')
-    self._env = env
-    self._key = key
-    self._random = np.random.RandomState()
-
-  def __getattr__(self, name):
-    if name.startswith('__'):
-      raise AttributeError(name)
-    try:
-      return getattr(self._env, name)
-    except AttributeError:
-      raise ValueError(name)
-
-  @property
-  def act_space(self):
-    shape = (self._env.act_space[self._key].n,)
-    space = gym.spaces.Box(low=0, high=1, shape=shape, dtype=np.float32)
-    space.sample = self._sample_action
-    space.n = shape[0]
-    return {**self._env.act_space, self._key: space}
-
-  def step(self, action):
-    index = np.argmax(action[self._key]).astype(int)
-    reference = np.zeros_like(action[self._key])
-    reference[index] = 1
-    if not np.allclose(reference, action[self._key]):
-      raise ValueError(f'Invalid one-hot action:\n{action}')
-    return self._env.step({**action, self._key: index})
-
-  def reset(self):
-    return self._env.reset()
-
-  def _sample_action(self):
-    actions = self._env.act_space.n
-    index = self._random.randint(0, actions)
-    reference = np.zeros(actions, dtype=np.float32)
-    reference[index] = 1.0
-    return reference
-
-
-class ResizeImage:
-
-  def __init__(self, env, size=(64, 64)):
-    self._env = env
-    self._size = size
-    self._keys = [
-        k for k, v in env.obs_space.items()
-        if len(v.shape) > 1 and v.shape[:2] != size]
-    print(f'Resizing keys {",".join(self._keys)} to {self._size}.')
-    if self._keys:
-      from PIL import Image
-      self._Image = Image
-
-  def __getattr__(self, name):
-    if name.startswith('__'):
-      raise AttributeError(name)
-    try:
-      return getattr(self._env, name)
-    except AttributeError:
-      raise ValueError(name)
-
-  @property
-  def obs_space(self):
-    spaces = self._env.obs_space
-    for key in self._keys:
-      shape = self._size + spaces[key].shape[2:]
-      spaces[key] = gym.spaces.Box(0, 255, shape, np.uint8)
-    return spaces
-
-  def step(self, action):
-    obs = self._env.step(action)
-    for key in self._keys:
-      obs[key] = self._resize(obs[key])
-    return obs
-
-  def reset(self):
-    obs = self._env.reset()
-    for key in self._keys:
-      obs[key] = self._resize(obs[key])
-    return obs
-
-  def _resize(self, image):
-    image = self._Image.fromarray(image)
-    image = image.resize(self._size, self._Image.NEAREST)
-    image = np.array(image)
-    return image
-
-
-class RenderImage:
-
-  def __init__(self, env, key='image'):
-    self._env = env
-    self._key = key
-    self._shape = self._env.render().shape
-
-  def __getattr__(self, name):
-    if name.startswith('__'):
-      raise AttributeError(name)
-    try:
-      return getattr(self._env, name)
-    except AttributeError:
-      raise ValueError(name)
-
-  @property
-  def obs_space(self):
-    spaces = self._env.obs_space
-    spaces[self._key] = gym.spaces.Box(0, 255, self._shape, np.uint8)
-    return spaces
-
-  def step(self, action):
-    obs = self._env.step(action)
-    obs[self._key] = self._env.render('rgb_array')
-    return obs
-
-  def reset(self):
-    obs = self._env.reset()
-    obs[self._key] = self._env.render('rgb_array')
-    return obs
-
-
-class Async:
-
-  # Message types for communication via the pipe.
-  _ACCESS = 1
-  _CALL = 2
-  _RESULT = 3
-  _CLOSE = 4
-  _EXCEPTION = 5
-
-  def __init__(self, constructor, strategy='thread'):
-    self._pickled_ctor = cloudpickle.dumps(constructor)
-    if strategy == 'process':
-      import multiprocessing as mp
-      context = mp.get_context('spawn')
-    elif strategy == 'thread':
-      import multiprocessing.dummy as context
-    else:
-      raise NotImplementedError(strategy)
-    self._strategy = strategy
-    self._conn, conn = context.Pipe()
-    self._process = context.Process(target=self._worker, args=(conn,))
-    atexit.register(self.close)
-    self._process.start()
-    self._receive()  # Ready.
-    self._obs_space = None
-    self._act_space = None
-
-  def access(self, name):
-    self._conn.send((self._ACCESS, name))
-    return self._receive
-
-  def call(self, name, *args, **kwargs):
-    payload = name, args, kwargs
-    self._conn.send((self._CALL, payload))
-    return self._receive
-
-  def close(self):
-    try:
-      self._conn.send((self._CLOSE, None))
-      self._conn.close()
-    except IOError:
-      pass  # The connection was already closed.
-    self._process.join(5)
-
-  @property
-  def obs_space(self):
-    if not self._obs_space:
-      self._obs_space = self.access('obs_space')()
-    return self._obs_space
-
-  @property
-  def act_space(self):
-    if not self._act_space:
-      self._act_space = self.access('act_space')()
-    return self._act_space
-
-  def step(self, action, blocking=False):
-    promise = self.call('step', action)
-    if blocking:
-      return promise()
-    else:
-      return promise
-
-  def reset(self, blocking=False):
-    promise = self.call('reset')
-    if blocking:
-      return promise()
-    else:
-      return promise
-
-  def _receive(self):
-    try:
-      message, payload = self._conn.recv()
-    except (OSError, EOFError):
-      raise RuntimeError('Lost connection to environment worker.')
-    # Re-raise exceptions in the main process.
-    if message == self._EXCEPTION:
-      stacktrace = payload
-      raise Exception(stacktrace)
-    if message == self._RESULT:
-      return payload
-    raise KeyError('Received message of unexpected type {}'.format(message))
-
-  def _worker(self, conn):
-    try:
-      ctor = cloudpickle.loads(self._pickled_ctor)
-      env = ctor()
-      conn.send((self._RESULT, None))  # Ready.
-      while True:
-        try:
-          # Only block for short times to have keyboard exceptions be raised.
-          if not conn.poll(0.1):
-            continue
-          message, payload = conn.recv()
-        except (EOFError, KeyboardInterrupt):
-          break
-        if message == self._ACCESS:
-          name = payload
-          result = getattr(env, name)
-          conn.send((self._RESULT, result))
-          continue
-        if message == self._CALL:
-          name, args, kwargs = payload
-          result = getattr(env, name)(*args, **kwargs)
-          conn.send((self._RESULT, result))
-          continue
-        if message == self._CLOSE:
-          break
-        raise KeyError('Received message of unknown type {}'.format(message))
-    except Exception:
-      stacktrace = ''.join(traceback.format_exception(*sys.exc_info()))
-      print('Error in environment process: {}'.format(stacktrace))
-      conn.send((self._EXCEPTION, stacktrace))
-    finally:
-      try:
-        conn.close()
-      except IOError:
-        pass  # The connection was already closed.
diff --git a/envs/wrappers.py b/envs/wrappers.py
index eacb173..03ff649 100644
--- a/envs/wrappers.py
+++ b/envs/wrappers.py
@@ -96,7 +96,6 @@ class TimeLimit:
     def step(self, action):
         assert self._step is not None, "Must reset environment."
         obs, reward, done, info = self._env.step(action)
-        # teets = self._env.step(action)
         self._step += 1
         if self._step >= self._duration:
             done = True
@@ -151,6 +150,44 @@ class OneHotAction:
         space.discrete = True
         return space
 
+    def step(self, action):
+        index = np.argmax(action).astype(int)
+        reference = np.zeros_like(action)
+        reference[index] = 1
+        if not np.allclose(reference, action):
+            raise ValueError(f"Invalid one-hot action:\n{action}")
+        return self._env.step(index)
+
+    def reset(self):
+        return self._env.reset()
+
+    def _sample_action(self):
+        actions = self._env.action_space.n
+        index = self._random.randint(0, actions)
+        reference = np.zeros(actions, dtype=np.float32)
+        reference[index] = 1.0
+        return reference
+
+class OneHotAction2:
+    # Like OneHotAction, but for envs exposing a plain gym.spaces.Discrete
+    # action_space (e.g. MZGymWrapper); steps the env with the integer index.
+    def __init__(self, env):
+        assert isinstance(env.action_space, gym.spaces.Discrete)
+        self._env = env
+        self._random = np.random.RandomState()
+
+    def __getattr__(self, name):
+        return getattr(self._env, name)
+
+    @property
+    def action_space(self):
+        # Expose the discrete space as a one-hot Box for the agent.
+        shape = (self._env.action_space.n,)
+        space = gym.spaces.Box(low=0, high=1, shape=shape, dtype=np.float32)
+        space.sample = self._sample_action
+        space.discrete = True
+        return space
+
     def step(self, action):
         index = np.argmax(action).astype(int)
         reference = np.zeros_like(action)
@@ -169,7 +206,6 @@ class OneHotAction:
         reference[index] = 1.0
         return reference
 
-
 class RewardObs:
     def __init__(self, env):
         self._env = env
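
Usage note (not part of the patch): a minimal sketch of launching a run with
the new preset, assuming the existing dreamer.py CLI where presets from
configs.yaml are selected via --configs and tasks are named "<suite>_<size>":

    pip install memory-maze   # provides the memory_maze gym environments
    python dreamer.py --configs mazegym --task mazegym_9 --logdir ./logdir/mazegym_9

The 15x15 maze would use --task mazegym_15, which make_env maps to
memory_maze:MemoryMaze-15x15-v0.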