env v0.12

This commit is contained in:
张德祥 2023-06-13 21:39:04 +08:00
parent 5038a91aad
commit b9120a7440
4 changed files with 24 additions and 8 deletions

View File

@@ -25,7 +25,7 @@ defaults:
   action_repeat: 2
   time_limit: 1000
   grayscale: False
-  prefill: 2500
+  prefill: 250 #0
   eval_noise: 0.0
   reward_EMA: True

View File

@@ -215,7 +215,7 @@ def make_env(config, logger, mode, train_eps, eval_eps):
         env = gym.make('memory_maze:MemoryMaze-9x9-v0')
         from envs.memmazeEnv import MZGymWrapper
         env = MZGymWrapper(env)
+        #from envs.memmazeEnv import OneHotAction as OneHotAction2
         env = wrappers.OneHotAction(env)
     elif suite == "---------mazed":
         from memory_maze import tasks

View File

@@ -47,17 +47,32 @@ class MZGymWrapper:
         else:
             return {self._act_key: self._env.action_space}

+    @property
+    def observation_space(self):
+        img_shape = self._size + ((1,) if self._gray else (3,))
+        return gym.spaces.Dict(
+            {
+                "image": gym.spaces.Box(0, 255, img_shape, np.uint8),
+            }
+        )
+
+    @property
+    def action_space(self):
+        space = self._env.action_space
+        space.discrete = True
+        return space
+
     def step(self, action):
-        if not self._act_is_dict:
-            action = action[self._act_key]
+        # if not self._act_is_dict:
+        #     action = action[self._act_key]
         obs, reward, done, info = self._env.step(action)
         if not self._obs_is_dict:
             obs = {self._obs_key: obs}
-        obs['reward'] = float(reward)
+        # obs['reward'] = float(reward)
         obs['is_first'] = False
         obs['is_last'] = done
         obs['is_terminal'] = info.get('is_terminal', done)
-        return obs
+        return obs, reward, done, info

     def reset(self):
         obs = self._env.reset()

View File

@@ -77,8 +77,8 @@ class CollectDataset:
             dtype = {16: np.int16, 32: np.int32, 64: np.int64}[self._precision]
         elif np.issubdtype(value.dtype, np.uint8):
             dtype = np.uint8
-        elif np.issubdtype(value.dtype, np.bool):
-            dtype = np.bool
+        elif np.issubdtype(value.dtype, np.bool_):
+            dtype = np.bool_
         else:
             raise NotImplementedError(value.dtype)
         return value.astype(dtype)
@@ -96,6 +96,7 @@ class TimeLimit:
     def step(self, action):
         assert self._step is not None, "Must reset environment."
         obs, reward, done, info = self._env.step(action)
+        # teets = self._env.step(action)
         self._step += 1
         if self._step >= self._duration:
             done = True