env v0.12
This commit is contained in:
parent
5038a91aad
commit
b9120a7440
@ -25,7 +25,7 @@ defaults:
|
|||||||
action_repeat: 2
|
action_repeat: 2
|
||||||
time_limit: 1000
|
time_limit: 1000
|
||||||
grayscale: False
|
grayscale: False
|
||||||
prefill: 2500
|
prefill: 250 #0
|
||||||
eval_noise: 0.0
|
eval_noise: 0.0
|
||||||
reward_EMA: True
|
reward_EMA: True
|
||||||
|
|
||||||
|
|||||||
@ -215,7 +215,7 @@ def make_env(config, logger, mode, train_eps, eval_eps):
|
|||||||
env = gym.make('memory_maze:MemoryMaze-9x9-v0')
|
env = gym.make('memory_maze:MemoryMaze-9x9-v0')
|
||||||
from envs.memmazeEnv import MZGymWrapper
|
from envs.memmazeEnv import MZGymWrapper
|
||||||
env = MZGymWrapper(env)
|
env = MZGymWrapper(env)
|
||||||
|
#from envs.memmazeEnv import OneHotAction as OneHotAction2
|
||||||
env = wrappers.OneHotAction(env)
|
env = wrappers.OneHotAction(env)
|
||||||
elif suite == "---------mazed":
|
elif suite == "---------mazed":
|
||||||
from memory_maze import tasks
|
from memory_maze import tasks
|
||||||
|
|||||||
@ -47,17 +47,32 @@ class MZGymWrapper:
|
|||||||
else:
|
else:
|
||||||
return {self._act_key: self._env.action_space}
|
return {self._act_key: self._env.action_space}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def observation_space(self):
|
||||||
|
img_shape = self._size + ((1,) if self._gray else (3,))
|
||||||
|
return gym.spaces.Dict(
|
||||||
|
{
|
||||||
|
"image": gym.spaces.Box(0, 255, img_shape, np.uint8),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def action_space(self):
|
||||||
|
space = self._env.action_space
|
||||||
|
space.discrete = True
|
||||||
|
return space
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
if not self._act_is_dict:
|
# if not self._act_is_dict:
|
||||||
action = action[self._act_key]
|
# action = action[self._act_key]
|
||||||
obs, reward, done, info = self._env.step(action)
|
obs, reward, done, info = self._env.step(action)
|
||||||
if not self._obs_is_dict:
|
if not self._obs_is_dict:
|
||||||
obs = {self._obs_key: obs}
|
obs = {self._obs_key: obs}
|
||||||
obs['reward'] = float(reward)
|
# obs['reward'] = float(reward)
|
||||||
obs['is_first'] = False
|
obs['is_first'] = False
|
||||||
obs['is_last'] = done
|
obs['is_last'] = done
|
||||||
obs['is_terminal'] = info.get('is_terminal', done)
|
obs['is_terminal'] = info.get('is_terminal', done)
|
||||||
return obs
|
return obs, reward, done, info
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
obs = self._env.reset()
|
obs = self._env.reset()
|
||||||
|
|||||||
@ -77,8 +77,8 @@ class CollectDataset:
|
|||||||
dtype = {16: np.int16, 32: np.int32, 64: np.int64}[self._precision]
|
dtype = {16: np.int16, 32: np.int32, 64: np.int64}[self._precision]
|
||||||
elif np.issubdtype(value.dtype, np.uint8):
|
elif np.issubdtype(value.dtype, np.uint8):
|
||||||
dtype = np.uint8
|
dtype = np.uint8
|
||||||
elif np.issubdtype(value.dtype, np.bool):
|
elif np.issubdtype(value.dtype, np.bool_):
|
||||||
dtype = np.bool
|
dtype = np.bool_
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(value.dtype)
|
raise NotImplementedError(value.dtype)
|
||||||
return value.astype(dtype)
|
return value.astype(dtype)
|
||||||
@ -96,6 +96,7 @@ class TimeLimit:
|
|||||||
def step(self, action):
|
def step(self, action):
|
||||||
assert self._step is not None, "Must reset environment."
|
assert self._step is not None, "Must reset environment."
|
||||||
obs, reward, done, info = self._env.step(action)
|
obs, reward, done, info = self._env.step(action)
|
||||||
|
# teets = self._env.step(action)
|
||||||
self._step += 1
|
self._step += 1
|
||||||
if self._step >= self._duration:
|
if self._step >= self._duration:
|
||||||
done = True
|
done = True
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user