env v0.13
commit 1cf0149c10
parent b9120a7440
@@ -25,7 +25,7 @@ defaults:
   action_repeat: 2
   time_limit: 1000
   grayscale: False
-  prefill: 250 #0
+  prefill: 3500
   eval_noise: 0.0
   reward_EMA: True
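For context, prefill is the number of environment steps collected before learning starts, so this change grows the initial replay data from 250 to 3500 steps. A rough sketch of how such a prefill phase is typically consumed, assuming a classic Gym-style loop and a hypothetical buffer.add API (none of this is this repo's code):

def prefill_buffer(env, buffer, prefill_steps=3500):
    # Collect uniformly random experience until the prefill budget is spent.
    obs = env.reset()
    for _ in range(prefill_steps):
        action = env.action_space.sample()
        next_obs, reward, done, info = env.step(action)
        buffer.add(obs, action, reward, done)  # assumed replay-buffer API
        obs = env.reset() if done else next_obs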
@@ -150,7 +150,7 @@ atari100k:

 mazed:
   task: "memmaze_9_9"
-  steps: 5e4
+  steps: 1e6
   action_repeat: 2

 debug:
@@ -162,12 +162,12 @@ debug:

 mazegym:
   #task: "memory_maze:MemoryMaze-9x9-v0"
-  steps: 5e4
+  steps: 1e6
   action_repeat: 2

 mazedeepm:
   task: "memmaze_9_9"
-  steps: 5e4
+  steps: 1e6
   action_repeat: 2
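The named sections (mazed, mazegym, mazedeepm) override the defaults block when one of them is selected, which is why only the steps budget changes here. A minimal sketch of how such a defaults-plus-overrides YAML is typically merged, assuming the file is called configs.yaml (the file name and the helper are assumptions, not this repo's loader):

import yaml

def load_config(path="configs.yaml", name="mazed"):
    # Assumed layout: a 'defaults' mapping plus named override sections.
    with open(path) as f:
        all_configs = yaml.safe_load(f)
    config = dict(all_configs["defaults"])
    config.update(all_configs.get(name, {}))  # e.g. steps: 1e6 replaces the default
    return config

config = load_config(name="mazed")
print(config["task"], config["steps"], config["action_repeat"])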
@@ -11,12 +11,14 @@ import numpy as np


 class MZGymWrapper:

-    def __init__(self, env, obs_key='image', act_key='action'):
+    def __init__(self, env, obs_key='image', act_key='action', size=(64, 64)):
         self._env = env
         self._obs_is_dict = hasattr(self._env.observation_space, 'spaces')
         self._act_is_dict = hasattr(self._env.action_space, 'spaces')
         self._obs_key = obs_key
         self._act_key = act_key
+        self._size = size
+        self._gray = False

     def __getattr__(self, name):
         if name.startswith('__'):
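The added size and _gray attributes are presumably used where observations are post-processed; that code is not part of this hunk, so the following is only a sketch of a typical resize/grayscale step (the _process_image name and the PIL-based resize are assumptions):

import numpy as np
from PIL import Image

def _process_image(self, image):
    # Resize the raw frame to self._size and optionally collapse it to a
    # single grayscale channel when self._gray is set.
    image = Image.fromarray(image).resize(self._size, Image.BILINEAR)
    image = np.asarray(image)
    if self._gray:
        image = image.mean(-1, keepdims=True).astype(np.uint8)
    return image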
@@ -155,8 +155,8 @@ class OneHotAction:
         index = np.argmax(action).astype(int)
         reference = np.zeros_like(action)
         reference[index] = 1
-        if not np.allclose(reference, action):
-            raise ValueError(f"Invalid one-hot action:\n{action}")
+        # if not np.allclose(reference, action):
+        #     raise ValueError(f"Invalid one-hot action:\n{action}")
         return self._env.step(index)

     def reset(self):
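Commenting out the allclose check means action vectors that are not exactly one-hot (e.g. raw probabilities or slightly noisy vectors) are no longer rejected; they are silently mapped to their argmax index. A small standalone illustration of the difference:

import numpy as np

action = np.array([0.2, 0.7, 0.1], dtype=np.float32)  # not a strict one-hot vector
index = np.argmax(action).astype(int)
reference = np.zeros_like(action)
reference[index] = 1
print(np.allclose(reference, action))  # False -> the old check raised ValueError
print(index)                           # 1     -> the new code forwards this index to env.step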
tools.py
@@ -127,7 +127,7 @@ def simulate(agent, envs, steps=0, episodes=0, state=None):
     # Initialize or unpack simulation state.
     if state is None:
         step, episode = 0, 0
-        done = np.ones(len(envs), np.bool)
+        done = np.ones(len(envs), np.bool_)
         length = np.zeros(len(envs), np.int32)
         obs = [None] * len(envs)
         agent_state = None
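np.bool was a deprecated alias for the built-in bool (deprecated in NumPy 1.20, removed in 1.24), so np.ones(len(envs), np.bool) fails on recent NumPy releases. np.bool_ is the actual boolean scalar type and produces the same dtype:

import numpy as np

done = np.ones(3, np.bool_)    # array([ True,  True,  True])
print(done.dtype)              # bool
done = np.ones(3, dtype=bool)  # equivalent and safe on every NumPy version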