mem maze env ok 1.2
This commit is contained in:
parent
152415f32e
commit
8e005afde5
@ -156,8 +156,10 @@ debug:
|
|||||||
batch_size: 10
|
batch_size: 10
|
||||||
batch_length: 20
|
batch_length: 20
|
||||||
|
|
||||||
mazegym:
|
MemoryMaze:
|
||||||
task: '9'
|
actor_dist: 'onehot'
|
||||||
|
imag_gradient: 'reinforce'
|
||||||
|
task: '9x9'
|
||||||
steps: 1e6
|
steps: 1e6
|
||||||
action_repeat: 2
|
action_repeat: 2
|
||||||
|
|
||||||
|
12
dreamer.py
12
dreamer.py
@ -210,17 +210,17 @@ def make_env(config, logger, mode, train_eps, eval_eps):
|
|||||||
task, mode if "train" in mode else "test", config.action_repeat
|
task, mode if "train" in mode else "test", config.action_repeat
|
||||||
)
|
)
|
||||||
env = wrappers.OneHotAction(env)
|
env = wrappers.OneHotAction(env)
|
||||||
elif suite == "mazegym":
|
elif suite == "MemoryMaze":
|
||||||
import gym
|
import gym
|
||||||
if task == '9':
|
if task == '9x9':
|
||||||
env = gym.make('memory_maze:MemoryMaze-9x9-v0')
|
env = gym.make('memory_maze:MemoryMaze-9x9-v0')
|
||||||
elif task == '15':
|
elif task == '15x15':
|
||||||
env = gym.make('memory_maze:MemoryMaze-15x15-v0')
|
env = gym.make('memory_maze:MemoryMaze-15x15-v0')
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(suite)
|
raise NotImplementedError(suite)
|
||||||
from envs.memmazeEnv import MZGymWrapper
|
from envs.memorymaze import MemoryMaze
|
||||||
env = MZGymWrapper(env)
|
env = MemoryMaze(env)
|
||||||
env = wrappers.OneHotAction2(env)
|
env = wrappers.OneHotAction(env)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(suite)
|
raise NotImplementedError(suite)
|
||||||
env = wrappers.TimeLimit(env, config.time_limit)
|
env = wrappers.TimeLimit(env, config.time_limit)
|
||||||
|
@ -1,8 +1,6 @@
|
|||||||
import atexit
|
import atexit
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import threading
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
import cloudpickle
|
import cloudpickle
|
||||||
import gym
|
import gym
|
||||||
@ -10,7 +8,7 @@ import numpy as np
|
|||||||
|
|
||||||
###from tf dreamerv2 code
|
###from tf dreamerv2 code
|
||||||
|
|
||||||
class MZGymWrapper:
|
class MemoryMaze:
|
||||||
|
|
||||||
def __init__(self, env, obs_key='image', act_key='action', size=(64, 64)):
|
def __init__(self, env, obs_key='image', act_key='action', size=(64, 64)):
|
||||||
self._env = env
|
self._env = env
|
||||||
@ -74,7 +72,7 @@ class MZGymWrapper:
|
|||||||
# obs['reward'] = float(reward)
|
# obs['reward'] = float(reward)
|
||||||
obs['is_first'] = False
|
obs['is_first'] = False
|
||||||
obs['is_last'] = done
|
obs['is_last'] = done
|
||||||
obs['is_terminal'] = info.get('is_terminal', done)
|
obs['is_terminal'] = info.get('is_terminal', False)
|
||||||
return obs, reward, done, info
|
return obs, reward, done, info
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
@ -77,8 +77,8 @@ class CollectDataset:
|
|||||||
dtype = {16: np.int16, 32: np.int32, 64: np.int64}[self._precision]
|
dtype = {16: np.int16, 32: np.int32, 64: np.int64}[self._precision]
|
||||||
elif np.issubdtype(value.dtype, np.uint8):
|
elif np.issubdtype(value.dtype, np.uint8):
|
||||||
dtype = np.uint8
|
dtype = np.uint8
|
||||||
elif np.issubdtype(value.dtype, np.bool_):
|
elif np.issubdtype(value.dtype, np.bool):
|
||||||
dtype = np.bool_
|
dtype = np.bool
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(value.dtype)
|
raise NotImplementedError(value.dtype)
|
||||||
return value.astype(dtype)
|
return value.astype(dtype)
|
||||||
@ -168,40 +168,6 @@ class OneHotAction:
|
|||||||
reference[index] = 1.0
|
reference[index] = 1.0
|
||||||
return reference
|
return reference
|
||||||
|
|
||||||
class OneHotAction2:
|
|
||||||
def __init__(self, env):
|
|
||||||
assert isinstance(env.action_space, gym.spaces.Discrete)
|
|
||||||
self._env = env
|
|
||||||
self._random = np.random.RandomState()
|
|
||||||
|
|
||||||
def __getattr__(self, name):
|
|
||||||
return getattr(self._env, name)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def action_space(self):
|
|
||||||
shape = (self._env.action_space.n,)
|
|
||||||
space = gym.spaces.Box(low=0, high=1, shape=shape, dtype=np.float32)
|
|
||||||
space.sample = self._sample_action
|
|
||||||
space.discrete = True
|
|
||||||
return space
|
|
||||||
|
|
||||||
def step(self, action):
|
|
||||||
index = np.argmax(action).astype(int)
|
|
||||||
reference = np.zeros_like(action)
|
|
||||||
reference[index] = 1
|
|
||||||
# if not np.allclose(reference, action):
|
|
||||||
# raise ValueError(f"Invalid one-hot action:\n{action}")
|
|
||||||
return self._env.step(index)
|
|
||||||
|
|
||||||
def reset(self):
|
|
||||||
return self._env.reset()
|
|
||||||
|
|
||||||
def _sample_action(self):
|
|
||||||
actions = self._env.action_space.n
|
|
||||||
index = self._random.randint(0, actions)
|
|
||||||
reference = np.zeros(actions, dtype=np.float32)
|
|
||||||
reference[index] = 1.0
|
|
||||||
return reference
|
|
||||||
|
|
||||||
class RewardObs:
|
class RewardObs:
|
||||||
def __init__(self, env):
|
def __init__(self, env):
|
||||||
|
2
tools.py
2
tools.py
@ -127,7 +127,7 @@ def simulate(agent, envs, steps=0, episodes=0, state=None):
|
|||||||
# Initialize or unpack simulation state.
|
# Initialize or unpack simulation state.
|
||||||
if state is None:
|
if state is None:
|
||||||
step, episode = 0, 0
|
step, episode = 0, 0
|
||||||
done = np.ones(len(envs), np.bool_)
|
done = np.ones(len(envs), np.bool)
|
||||||
length = np.zeros(len(envs), np.int32)
|
length = np.zeros(len(envs), np.int32)
|
||||||
obs = [None] * len(envs)
|
obs = [None] * len(envs)
|
||||||
agent_state = None
|
agent_state = None
|
||||||
|
Loading…
x
Reference in New Issue
Block a user