env v0.13

This commit is contained in:
张德祥 2023-06-14 20:22:17 +08:00
parent b9120a7440
commit 1cf0149c10
4 changed files with 10 additions and 8 deletions

View File

@ -25,7 +25,7 @@ defaults:
action_repeat: 2
time_limit: 1000
grayscale: False
prefill: 250 #0
prefill: 3500
eval_noise: 0.0
reward_EMA: True
@ -150,7 +150,7 @@ atari100k:
mazed:
task: "memmaze_9_9"
steps: 5e4
steps: 1e6
action_repeat: 2
debug:
@ -162,12 +162,12 @@ debug:
mazegym:
#task: "memory_maze:MemoryMaze-9x9-v0"
steps: 5e4
steps: 1e6
action_repeat: 2
mazedeepm:
task: "memmaze_9_9"
steps: 5e4
steps: 1e6
action_repeat: 2

View File

@ -11,12 +11,14 @@ import numpy as np
class MZGymWrapper:
def __init__(self, env, obs_key='image', act_key='action'):
def __init__(self, env, obs_key='image', act_key='action', size=(64, 64)):
self._env = env
self._obs_is_dict = hasattr(self._env.observation_space, 'spaces')
self._act_is_dict = hasattr(self._env.action_space, 'spaces')
self._obs_key = obs_key
self._act_key = act_key
self._size = size
self._gray = False
def __getattr__(self, name):
if name.startswith('__'):

View File

@ -155,8 +155,8 @@ class OneHotAction:
index = np.argmax(action).astype(int)
reference = np.zeros_like(action)
reference[index] = 1
if not np.allclose(reference, action):
raise ValueError(f"Invalid one-hot action:\n{action}")
# if not np.allclose(reference, action):
# raise ValueError(f"Invalid one-hot action:\n{action}")
return self._env.step(index)
def reset(self):

View File

@ -127,7 +127,7 @@ def simulate(agent, envs, steps=0, episodes=0, state=None):
# Initialize or unpack simulation state.
if state is None:
step, episode = 0, 0
done = np.ones(len(envs), np.bool)
done = np.ones(len(envs), np.bool_)
length = np.zeros(len(envs), np.int32)
obs = [None] * len(envs)
agent_state = None