mem maze env ok 1.2

zdx 2023-06-18 09:16:32 +08:00
parent 152415f32e
commit 8e005afde5
5 changed files with 15 additions and 49 deletions


@@ -156,8 +156,10 @@ debug:
   batch_size: 10
   batch_length: 20
-mazegym:
-  task: '9'
+MemoryMaze:
+  actor_dist: 'onehot'
+  imag_gradient: 'reinforce'
+  task: '9x9'
   steps: 1e6
   action_repeat: 2


@@ -210,17 +210,17 @@ def make_env(config, logger, mode, train_eps, eval_eps):
             task, mode if "train" in mode else "test", config.action_repeat
         )
         env = wrappers.OneHotAction(env)
-    elif suite == "mazegym":
+    elif suite == "MemoryMaze":
         import gym
-        if task == '9':
+        if task == '9x9':
             env = gym.make('memory_maze:MemoryMaze-9x9-v0')
-        elif task == '15':
+        elif task == '15x15':
             env = gym.make('memory_maze:MemoryMaze-15x15-v0')
         else:
             raise NotImplementedError(suite)
-        from envs.memmazeEnv import MZGymWrapper
-        env = MZGymWrapper(env)
-        env = wrappers.OneHotAction2(env)
+        from envs.memorymaze import MemoryMaze
+        env = MemoryMaze(env)
+        env = wrappers.OneHotAction(env)
     else:
         raise NotImplementedError(suite)
     env = wrappers.TimeLimit(env, config.time_limit)
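
A quick way to sanity-check the new branch outside of training (a minimal sketch, assuming the memory-maze package is installed so the memory_maze:MemoryMaze-*-v0 Gym ids used above are registered; the printed spaces are indicative only):

import gym

# Standalone check of the environment id used in the new MemoryMaze branch.
env = gym.make('memory_maze:MemoryMaze-9x9-v0')
obs = env.reset()
print(env.observation_space)   # image observations, e.g. Box(0, 255, (64, 64, 3), uint8)
print(env.action_space)        # a small Discrete space of movement/turn actions
obs, reward, done, info = env.step(env.action_space.sample())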


@@ -1,8 +1,6 @@
import atexit
import os
import sys
import threading
import traceback
import cloudpickle
import gym
@@ -10,7 +8,7 @@ import numpy as np
 ###from tf dreamerv2 code
-class MZGymWrapper:
+class MemoryMaze:
     def __init__(self, env, obs_key='image', act_key='action', size=(64, 64)):
         self._env = env
@@ -74,7 +72,7 @@ class MZGymWrapper:
         # obs['reward'] = float(reward)
         obs['is_first'] = False
         obs['is_last'] = done
-        obs['is_terminal'] = info.get('is_terminal', done)
+        obs['is_terminal'] = info.get('is_terminal', False)
         return obs, reward, done, info
     def reset(self):
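
For readers unfamiliar with the Dreamer-style episode flags, the change above stops treating time-limit truncation as termination: is_last marks any episode end, while is_terminal should only be true for a real environment termination, which is what value bootstrapping relies on. An illustrative sketch of that convention (hypothetical helper, not this repository's exact code):

def step_with_flags(env, action):
    # Assumes the wrapped env already returns a dict observation, as in the wrapper above.
    obs, reward, done, info = env.step(action)
    obs['is_first'] = False
    obs['is_last'] = done                                 # episode ended (terminated OR truncated)
    obs['is_terminal'] = info.get('is_terminal', False)   # true termination only
    return obs, reward, done, info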


@@ -77,8 +77,8 @@ class CollectDataset:
             dtype = {16: np.int16, 32: np.int32, 64: np.int64}[self._precision]
         elif np.issubdtype(value.dtype, np.uint8):
             dtype = np.uint8
-        elif np.issubdtype(value.dtype, np.bool_):
-            dtype = np.bool_
+        elif np.issubdtype(value.dtype, np.bool):
+            dtype = np.bool
         else:
             raise NotImplementedError(value.dtype)
         return value.astype(dtype)
@@ -168,40 +168,6 @@ class OneHotAction:
         reference[index] = 1.0
         return reference
-class OneHotAction2:
-    def __init__(self, env):
-        assert isinstance(env.action_space, gym.spaces.Discrete)
-        self._env = env
-        self._random = np.random.RandomState()
-    def __getattr__(self, name):
-        return getattr(self._env, name)
-    @property
-    def action_space(self):
-        shape = (self._env.action_space.n,)
-        space = gym.spaces.Box(low=0, high=1, shape=shape, dtype=np.float32)
-        space.sample = self._sample_action
-        space.discrete = True
-        return space
-    def step(self, action):
-        index = np.argmax(action).astype(int)
-        reference = np.zeros_like(action)
-        reference[index] = 1
-        # if not np.allclose(reference, action):
-        #     raise ValueError(f"Invalid one-hot action:\n{action}")
-        return self._env.step(index)
-    def reset(self):
-        return self._env.reset()
-    def _sample_action(self):
-        actions = self._env.action_space.n
-        index = self._random.randint(0, actions)
-        reference = np.zeros(actions, dtype=np.float32)
-        reference[index] = 1.0
-        return reference
 class RewardObs:
     def __init__(self, env):
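
With make_env switched back to the existing OneHotAction wrapper, the duplicate OneHotAction2 class becomes dead code; the visible behavioural difference was that OneHotAction2 left the one-hot validity check commented out. A sketch of that check in the style of the standard Dreamer wrapper (illustrative, not the exact code of the retained class):

import numpy as np

def decode_one_hot(action):
    # Reject vectors that are not exactly one-hot before converting to a discrete index.
    index = int(np.argmax(action))
    reference = np.zeros_like(action)
    reference[index] = 1
    if not np.allclose(reference, action):
        raise ValueError(f"Invalid one-hot action:\n{action}")
    return index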


@@ -127,7 +127,7 @@ def simulate(agent, envs, steps=0, episodes=0, state=None):
     # Initialize or unpack simulation state.
     if state is None:
         step, episode = 0, 0
-        done = np.ones(len(envs), np.bool_)
+        done = np.ones(len(envs), np.bool)
         length = np.zeros(len(envs), np.int32)
         obs = [None] * len(envs)
         agent_state = None
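
A version-compatibility observation on the two hunks that switch np.bool_ to np.bool (a note about NumPy itself, not something this commit changes): np.bool is merely an alias for Python's builtin bool, deprecated in NumPy 1.20 and removed in 1.24, so the lines above assume an older NumPy, whereas np.bool_ works on every release. A minimal illustration:

import numpy as np

# np.bool_ is NumPy's boolean scalar type and is available in every NumPy release.
done = np.ones(4, np.bool_)

# np.bool was only an alias for the builtin bool and was removed in NumPy 1.24.
if hasattr(np, "bool"):
    done = np.ones(4, np.bool)  # guarded, since np.bool does not exist on every version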