env v0.12

张德祥 2023-06-13 21:39:04 +08:00
parent 5038a91aad
commit b9120a7440
4 changed files with 24 additions and 8 deletions

View File

@@ -25,7 +25,7 @@ defaults:
  action_repeat: 2
  time_limit: 1000
  grayscale: False
  prefill: 2500
  prefill: 250 #0
  eval_noise: 0.0
  reward_EMA: True

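Note on the config change above: `prefill` is the budget of environment steps collected with a random policy before training starts, and the commit shrinks it from 2500 to 250 by commenting out the trailing zero. Below is a minimal sketch of how such a prefill budget is typically consumed, assuming the pre-0.26 Gym step API used throughout this repo; the function name `prefill_with_random_policy` is hypothetical, not part of the codebase.

# Sketch only: gather `prefill_steps` random-policy transitions before learning.
# Assumes env.step returns (obs, reward, done, info) as in the wrappers below.
def prefill_with_random_policy(env, prefill_steps):
    episodes, steps = [], 0
    while steps < prefill_steps:
        env.reset()
        done, episode = False, []
        while not done and steps < prefill_steps:
            action = env.action_space.sample()
            obs, reward, done, info = env.step(action)
            episode.append((obs, action, reward, done))
            steps += 1
        episodes.append(episode)
    return episodes
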
View File

@@ -215,7 +215,7 @@ def make_env(config, logger, mode, train_eps, eval_eps):
        env = gym.make('memory_maze:MemoryMaze-9x9-v0')
        from envs.memmazeEnv import MZGymWrapper
        env = MZGymWrapper(env)
        #from envs.memmazeEnv import OneHotAction as OneHotAction2
        env = wrappers.OneHotAction(env)
    elif suite == "---------mazed":
        from memory_maze import tasks

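For reference, here is a minimal sketch of what a one-hot action wrapper such as `wrappers.OneHotAction` does around the discrete MemoryMaze action space. It is an illustration under the pre-0.26 Gym API, not the repo's exact implementation, and the class name `OneHotActionSketch` is made up.

import gym
import numpy as np

class OneHotActionSketch:
    """Expose a Discrete action space as a one-hot Box, as the agent expects."""

    def __init__(self, env):
        assert isinstance(env.action_space, gym.spaces.Discrete)
        self._env = env

    def __getattr__(self, name):
        return getattr(self._env, name)

    @property
    def action_space(self):
        n = self._env.action_space.n
        space = gym.spaces.Box(low=0, high=1, shape=(n,), dtype=np.float32)
        space.discrete = True  # same flag MZGymWrapper.action_space sets below
        return space

    def step(self, action):
        index = int(np.argmax(action))  # one-hot vector -> discrete index
        return self._env.step(index)

    def reset(self):
        return self._env.reset()
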
View File

@@ -47,17 +47,32 @@ class MZGymWrapper:
        else:
            return {self._act_key: self._env.action_space}
    @property
    def observation_space(self):
        img_shape = self._size + ((1,) if self._gray else (3,))
        return gym.spaces.Dict(
            {
                "image": gym.spaces.Box(0, 255, img_shape, np.uint8),
            }
        )
    @property
    def action_space(self):
        space = self._env.action_space
        space.discrete = True
        return space
    def step(self, action):
        if not self._act_is_dict:
            action = action[self._act_key]
        # if not self._act_is_dict:
        # action = action[self._act_key]
        obs, reward, done, info = self._env.step(action)
        if not self._obs_is_dict:
            obs = {self._obs_key: obs}
        obs['reward'] = float(reward)
        # obs['reward'] = float(reward)
        obs['is_first'] = False
        obs['is_last'] = done
        obs['is_terminal'] = info.get('is_terminal', done)
        return obs
        return obs, reward, done, info
    def reset(self):
        obs = self._env.reset()

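After this change the wrapper's `step` returns the standard `(obs, reward, done, info)` tuple instead of folding the reward into the observation dict. A usage sketch, assuming `memory_maze` is installed, the pre-0.26 Gym API is in use, and the default observation key is `"image"`:

import gym
from envs.memmazeEnv import MZGymWrapper

env = MZGymWrapper(gym.make("memory_maze:MemoryMaze-9x9-v0"))
obs = env.reset()
# tuple form after this commit; obs is a dict carrying image and episode flags
obs, reward, done, info = env.step(env.action_space.sample())
print(obs["image"].shape, reward, done, obs["is_terminal"])
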
View File

@@ -77,8 +77,8 @@ class CollectDataset:
            dtype = {16: np.int16, 32: np.int32, 64: np.int64}[self._precision]
        elif np.issubdtype(value.dtype, np.uint8):
            dtype = np.uint8
        elif np.issubdtype(value.dtype, np.bool):
            dtype = np.bool
        elif np.issubdtype(value.dtype, np.bool_):
            dtype = np.bool_
        else:
            raise NotImplementedError(value.dtype)
        return value.astype(dtype)
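
The hunk above swaps `np.bool` for `np.bool_`: the `np.bool` alias was deprecated in NumPy 1.20 and removed in 1.24, so the old code raises an AttributeError on current NumPy. A standalone sketch of the same precision-casting logic; the float branch is assumed from the surrounding class and is not shown in this hunk.

import numpy as np

def convert(value, precision=32):
    value = np.asarray(value)
    if np.issubdtype(value.dtype, np.floating):
        dtype = {16: np.float16, 32: np.float32, 64: np.float64}[precision]
    elif np.issubdtype(value.dtype, np.signedinteger):
        dtype = {16: np.int16, 32: np.int32, 64: np.int64}[precision]
    elif np.issubdtype(value.dtype, np.uint8):
        dtype = np.uint8
    elif np.issubdtype(value.dtype, np.bool_):
        dtype = np.bool_
    else:
        raise NotImplementedError(value.dtype)
    return value.astype(dtype)

assert convert(np.array([True, False])).dtype == np.bool_
assert convert(np.array([1.0, 2.0])).dtype == np.float32
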
@@ -96,6 +96,7 @@ class TimeLimit:
    def step(self, action):
        assert self._step is not None, "Must reset environment."
        obs, reward, done, info = self._env.step(action)
        # teets = self._env.step(action)
        self._step += 1
        if self._step >= self._duration:
            done = True
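
For context, a minimal sketch of a TimeLimit wrapper built around the `step` shown above, assuming the pre-0.26 Gym API; the repo's real class may carry extra bookkeeping (e.g. a discount entry in `info`) that is omitted here, and the name `TimeLimitSketch` is hypothetical.

class TimeLimitSketch:
    """Truncate episodes after a fixed number of steps."""

    def __init__(self, env, duration):
        self._env = env
        self._duration = duration
        self._step = None

    def __getattr__(self, name):
        return getattr(self._env, name)

    def reset(self):
        self._step = 0
        return self._env.reset()

    def step(self, action):
        assert self._step is not None, "Must reset environment."
        obs, reward, done, info = self._env.step(action)
        self._step += 1
        if self._step >= self._duration:
            done = True        # truncate at the step budget
            self._step = None  # force a reset before stepping again
        return obs, reward, done, info
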