cleaned up envs

This commit is contained in:
NM512 2023-04-15 23:16:43 +09:00
parent fba87a33e0
commit 1e070a3daf
8 changed files with 507 additions and 431 deletions

View File

@ -122,16 +122,22 @@ defaults:
visual_dmc:
atari:
atari100k:
steps: 4e5
action_repeat: 4
eval_episode_num: 100
stickey: False
lives: unused
noops: 30
resize: opencv
actions: needed
actor_dist: 'onehot'
train_ratio: 1024
imag_gradient: 'reinforce'
time_limit: 108000
precision: 32
debug:
debug: True
pretrain: 1
prefill: 1

View File

@ -16,7 +16,7 @@ sys.path.append(str(pathlib.Path(__file__).parent))
import exploration as expl
import models
import tools
import wrappers
import envs.wrappers as wrappers
import torch
from torch import nn
@ -189,21 +189,29 @@ def make_dataset(episodes, config):
def make_env(config, logger, mode, train_eps, eval_eps):
suite, task = config.task.split("_", 1)
if suite == "dmc":
env = wrappers.DeepMindControl(task, config.action_repeat, config.size)
import envs.dmc as dmc
env = dmc.DeepMindControl(task, config.action_repeat, config.size)
env = wrappers.NormalizeActions(env)
elif suite == "atari":
env = wrappers.Atari(
import envs.atari as atari
env = atari.Atari(
task,
config.action_repeat,
config.size,
grayscale=config.grayscale,
life_done=False and ("train" in mode),
sticky_actions=False,
all_actions=False,
gray=config.grayscale,
noops=config.noops,
lives=config.lives,
sticky=config.stickey,
actions=config.actions,
resize=config.resize,
)
env = wrappers.OneHotAction(env)
elif suite == "dmlab":
env = wrappers.DeepMindLabyrinth(
import envs.dmlab as dmlab
env = dmlab.DeepMindLabyrinth(
task, mode if "train" in mode else "test", config.action_repeat
)
env = wrappers.OneHotAction(env)
@ -326,7 +334,7 @@ def main(config):
print(f"Prefill dataset ({prefill} steps).")
if hasattr(acts, "discrete"):
random_actor = tools.OneHotDist(
torch.zeros_like(torch.Tensor(acts.low)).repeat(config.envs, 1)
torch.zeros(config.num_actions).repeat(config.envs, 1)
)
else:
random_actor = torchd.independent.Independent(

128
envs/atari.py Normal file
View File

@ -0,0 +1,128 @@
import numpy as np
class Atari:
LOCK = None
def __init__(
self, name, action_repeat=4, size=(84, 84), gray=True, noops=0, lives='unused',
sticky=True, actions='all', length=108000, resize='opencv', seed=None):
assert size[0] == size[1]
assert lives in ('unused', 'discount', 'reset'), lives
assert actions in ('all', 'needed'), actions
assert resize in ('opencv', 'pillow'), resize
if self.LOCK is None:
import multiprocessing as mp
mp = mp.get_context('spawn')
self.LOCK = mp.Lock()
self._resize = resize
if self._resize == 'opencv':
import cv2
self._cv2 = cv2
if self._resize == 'pillow':
from PIL import Image
self._image = Image
import gym.envs.atari
if name == 'james_bond':
name = 'jamesbond'
self._repeat = action_repeat
self._size = size
self._gray = gray
self._noops = noops
self._lives = lives
self._sticky = sticky
self._length = length
self._random = np.random.RandomState(seed)
with self.LOCK:
self._env = gym.envs.atari.AtariEnv(
game=name,
obs_type='image',
frameskip=1, repeat_action_probability=0.25 if sticky else 0.0,
full_action_space=(actions == 'all'))
assert self._env.unwrapped.get_action_meanings()[0] == 'NOOP'
shape = self._env.observation_space.shape
self._buffer = [np.zeros(shape, np.uint8) for _ in range(2)]
self._ale = self._env.unwrapped.ale
self._last_lives = None
self._done = True
self._step = 0
@property
def action_space(self):
space = self._env.action_space
space.discrete = True
return space
def step(self, action):
# if action['reset'] or self._done:
# with self.LOCK:
# self._reset()
# self._done = False
# self._step = 0
# return self._obs(0.0, is_first=True)
total = 0.0
dead = False
if len(action.shape) >= 1:
action = np.argmax(action)
for repeat in range(self._repeat):
_, reward, over, info = self._env.step(action)
self._step += 1
total += reward
if repeat == self._repeat - 2:
self._screen(self._buffer[1])
if over:
break
if self._lives != 'unused':
current = self._ale.lives()
if current < self._last_lives:
dead = True
self._last_lives = current
break
if not self._repeat:
self._buffer[1][:] = self._buffer[0][:]
self._screen(self._buffer[0])
self._done = over or (self._length and self._step >= self._length) or dead
return self._obs(
total,
is_last=self._done or (dead and self._lives == 'reset'),
is_terminal=dead or over)
def reset(self):
self._env.reset()
if self._noops:
for _ in range(self._random.randint(self._noops)):
_, _, dead, _ = self._env.step(0)
if dead:
self._env.reset()
self._last_lives = self._ale.lives()
self._screen(self._buffer[0])
self._buffer[1].fill(0)
self._done = False
self._step = 0
obs, reward, is_terminal, _ = self._obs(0.0, is_first=True)
return obs
def _obs(self, reward, is_first=False, is_last=False, is_terminal=False):
np.maximum(self._buffer[0], self._buffer[1], out=self._buffer[0])
image = self._buffer[0]
if image.shape[:2] != self._size:
if self._resize == 'opencv':
image = self._cv2.resize(
image, self._size, interpolation=self._cv2.INTER_AREA)
if self._resize == 'pillow':
image = self._image.fromarray(image)
image = image.resize(self._size, self._image.NEAREST)
image = np.array(image)
if self._gray:
weights = [0.299, 0.587, 1 - (0.299 + 0.587)]
image = np.tensordot(image, weights, (-1, 0)).astype(image.dtype)
image = image[:, :, None]
return {'image':image, 'is_terminal':is_terminal}, reward, is_last, {}
def _screen(self, array):
self._ale.getScreenRGB2(array)
def close(self):
return self._env.close()

64
envs/dmc.py Normal file
View File

@ -0,0 +1,64 @@
import gym
import numpy as np
class DeepMindControl:
def __init__(self, name, action_repeat=1, size=(64, 64), camera=None):
domain, task = name.split('_', 1)
if domain == 'cup': # Only domain with multiple words.
domain = 'ball_in_cup'
if isinstance(domain, str):
from dm_control import suite
self._env = suite.load(domain, task)
else:
assert task is None
self._env = domain()
self._action_repeat = action_repeat
self._size = size
if camera is None:
camera = dict(quadruped=2).get(domain, 0)
self._camera = camera
@property
def observation_space(self):
spaces = {}
for key, value in self._env.observation_spec().items():
spaces[key] = gym.spaces.Box(
-np.inf, np.inf, value.shape, dtype=np.float32)
spaces['image'] = gym.spaces.Box(
0, 255, self._size + (3,), dtype=np.uint8)
return gym.spaces.Dict(spaces)
@property
def action_space(self):
spec = self._env.action_spec()
return gym.spaces.Box(spec.minimum, spec.maximum, dtype=np.float32)
def step(self, action):
assert np.isfinite(action).all(), action
reward = 0
for _ in range(self._action_repeat):
time_step = self._env.step(action)
reward += time_step.reward or 0
if time_step.last():
break
obs = dict(time_step.observation)
obs['image'] = self.render()
# There is no terminal state in DMC
obs['is_terminal'] = False
done = time_step.last()
info = {'discount': np.array(time_step.discount, np.float32)}
return obs, reward, done, info
def reset(self):
time_step = self._env.reset()
obs = dict(time_step.observation)
obs['image'] = self.render()
obs['is_terminal'] = False
return obs
def render(self, *args, **kwargs):
if kwargs.get('mode', 'rgb_array') != 'rgb_array':
raise ValueError("Only render mode 'rgb_array' is supported.")
return self._env.physics.render(*self._size, camera_id=self._camera)

101
envs/dmlab.py Normal file
View File

@ -0,0 +1,101 @@
import gym
import numpy as np
import deepmind_lab
class DeepMindLabyrinth(object):
ACTION_SET_DEFAULT = (
(0, 0, 0, 1, 0, 0, 0), # Forward
(0, 0, 0, -1, 0, 0, 0), # Backward
(0, 0, -1, 0, 0, 0, 0), # Strafe Left
(0, 0, 1, 0, 0, 0, 0), # Strafe Right
(-20, 0, 0, 0, 0, 0, 0), # Look Left
(20, 0, 0, 0, 0, 0, 0), # Look Right
(-20, 0, 0, 1, 0, 0, 0), # Look Left + Forward
(20, 0, 0, 1, 0, 0, 0), # Look Right + Forward
(0, 0, 0, 0, 1, 0, 0), # Fire
)
ACTION_SET_MEDIUM = (
(0, 0, 0, 1, 0, 0, 0), # Forward
(0, 0, 0, -1, 0, 0, 0), # Backward
(0, 0, -1, 0, 0, 0, 0), # Strafe Left
(0, 0, 1, 0, 0, 0, 0), # Strafe Right
(-20, 0, 0, 0, 0, 0, 0), # Look Left
(20, 0, 0, 0, 0, 0, 0), # Look Right
(0, 0, 0, 0, 0, 0, 0), # Idle.
)
ACTION_SET_SMALL = (
(0, 0, 0, 1, 0, 0, 0), # Forward
(-20, 0, 0, 0, 0, 0, 0), # Look Left
(20, 0, 0, 0, 0, 0, 0), # Look Right
)
def __init__(
self, level, mode, action_repeat=4, render_size=(64, 64),
action_set=ACTION_SET_DEFAULT, level_cache=None, seed=None,
runfiles_path=None):
assert mode in ('train', 'test')
if runfiles_path:
print('Setting DMLab runfiles path:', runfiles_path)
deepmind_lab.set_runfiles_path(runfiles_path)
self._config = {}
self._config['width'] = render_size[0]
self._config['height'] = render_size[1]
self._config['logLevel'] = 'WARN'
if mode == 'test':
self._config['allowHoldOutLevels'] = 'true'
self._config['mixerSeed'] = 0x600D5EED
self._action_repeat = action_repeat
self._random = np.random.RandomState(seed)
self._env = deepmind_lab.Lab(
level='contributed/dmlab30/'+level,
observations=['RGB_INTERLEAVED'],
config={k: str(v) for k, v in self._config.items()},
level_cache=level_cache)
self._action_set = action_set
self._last_image = None
self._done = True
@property
def observation_space(self):
shape = (self._config['height'], self._config['width'], 3)
space = gym.spaces.Box(low=0, high=255, shape=shape, dtype=np.uint8)
return gym.spaces.Dict({'image': space})
@property
def action_space(self):
return gym.spaces.Discrete(len(self._action_set))
def reset(self):
self._done = False
self._env.reset(seed=self._random.randint(0, 2 ** 31 - 1))
obs = self._get_obs()
return obs
def step(self, action):
raw_action = np.array(self._action_set[action], np.intc)
reward = self._env.step(raw_action, num_steps=self._action_repeat)
self._done = not self._env.is_running()
obs = self._get_obs()
return obs, reward, self._done, {}
def render(self, *args, **kwargs):
if kwargs.get('mode', 'rgb_array') != 'rgb_array':
raise ValueError("Only render mode 'rgb_array' is supported.")
del args # Unused
del kwargs # Unused
return self._last_image
def close(self):
self._env.close()
def _get_obs(self):
if self._done:
image = 0 * self._last_image
else:
image = self._env.observations()['RGB_INTERLEAVED']
self._last_image = image
return {'image': image}

188
envs/wrappers.py Normal file
View File

@ -0,0 +1,188 @@
import gym
import numpy as np
class CollectDataset:
def __init__(self, env, callbacks=None, precision=32):
self._env = env
self._callbacks = callbacks or ()
self._precision = precision
self._episode = None
def __getattr__(self, name):
return getattr(self._env, name)
def step(self, action):
obs, reward, done, info = self._env.step(action)
obs = {k: self._convert(v) for k, v in obs.items()}
transition = obs.copy()
if isinstance(action, dict):
transition.update(action)
else:
transition['action'] = action
transition['reward'] = reward
transition['discount'] = info.get('discount', np.array(1 - float(done)))
self._episode.append(transition)
if done:
for key, value in self._episode[1].items():
if key not in self._episode[0]:
self._episode[0][key] = 0 * value
episode = {k: [t[k] for t in self._episode] for k in self._episode[0]}
episode = {k: self._convert(v) for k, v in episode.items()}
info['episode'] = episode
for callback in self._callbacks:
callback(episode)
return obs, reward, done, info
def reset(self):
obs = self._env.reset()
transition = obs.copy()
# Missing keys will be filled with a zeroed out version of the first
# transition, because we do not know what action information the agent will
# pass yet.
transition['reward'] = 0.0
transition['discount'] = 1.0
self._episode = [transition]
return obs
def _convert(self, value):
value = np.array(value)
if np.issubdtype(value.dtype, np.floating):
dtype = {16: np.float16, 32: np.float32, 64: np.float64}[self._precision]
elif np.issubdtype(value.dtype, np.signedinteger):
dtype = {16: np.int16, 32: np.int32, 64: np.int64}[self._precision]
elif np.issubdtype(value.dtype, np.uint8):
dtype = np.uint8
elif np.issubdtype(value.dtype, np.bool):
dtype = np.bool
else:
raise NotImplementedError(value.dtype)
return value.astype(dtype)
class TimeLimit:
def __init__(self, env, duration):
self._env = env
self._duration = duration
self._step = None
def __getattr__(self, name):
return getattr(self._env, name)
def step(self, action):
assert self._step is not None, 'Must reset environment.'
obs, reward, done, info = self._env.step(action)
self._step += 1
if self._step >= self._duration:
done = True
if 'discount' not in info:
info['discount'] = np.array(1.0).astype(np.float32)
self._step = None
return obs, reward, done, info
def reset(self):
self._step = 0
return self._env.reset()
class NormalizeActions:
def __init__(self, env):
self._env = env
self._mask = np.logical_and(
np.isfinite(env.action_space.low),
np.isfinite(env.action_space.high))
self._low = np.where(self._mask, env.action_space.low, -1)
self._high = np.where(self._mask, env.action_space.high, 1)
def __getattr__(self, name):
return getattr(self._env, name)
@property
def action_space(self):
low = np.where(self._mask, -np.ones_like(self._low), self._low)
high = np.where(self._mask, np.ones_like(self._low), self._high)
return gym.spaces.Box(low, high, dtype=np.float32)
def step(self, action):
original = (action + 1) / 2 * (self._high - self._low) + self._low
original = np.where(self._mask, original, action)
return self._env.step(original)
class OneHotAction:
def __init__(self, env):
assert isinstance(env.action_space, gym.spaces.Discrete)
self._env = env
self._random = np.random.RandomState()
def __getattr__(self, name):
return getattr(self._env, name)
@property
def action_space(self):
shape = (self._env.action_space.n,)
space = gym.spaces.Box(low=0, high=1, shape=shape, dtype=np.float32)
space.sample = self._sample_action
space.discrete = True
return space
def step(self, action):
index = np.argmax(action).astype(int)
reference = np.zeros_like(action)
reference[index] = 1
if not np.allclose(reference, action):
raise ValueError(f'Invalid one-hot action:\n{action}')
return self._env.step(index)
def reset(self):
return self._env.reset()
def _sample_action(self):
actions = self._env.action_space.n
index = self._random.randint(0, actions)
reference = np.zeros(actions, dtype=np.float32)
reference[index] = 1.0
return reference
class RewardObs:
def __init__(self, env):
self._env = env
def __getattr__(self, name):
return getattr(self._env, name)
@property
def observation_space(self):
spaces = self._env.observation_space.spaces
assert 'reward' not in spaces
spaces['reward'] = gym.spaces.Box(-np.inf, np.inf, dtype=np.float32)
return gym.spaces.Dict(spaces)
def step(self, action):
obs, reward, done, info = self._env.step(action)
obs['reward'] = reward
return obs, reward, done, info
def reset(self):
obs = self._env.reset()
obs['reward'] = 0.0
return obs
class SelectAction:
def __init__(self, env, key):
self._env = env
self._key = key
def __getattr__(self, name):
return getattr(self._env, name)
def step(self, action):
return self._env.step(action[self._key])

View File

@ -5,8 +5,8 @@ tensorboard==2.5.0
pandas==1.2.4
matplotlib==3.4.1
ruamel.yaml==0.17.4
gym[atari]==0.17.0
moviepy==1.0.3
einops==0.3.0
protobuf==3.20.0
gym==0.19.0
dm_control==1.0.9

View File

@ -1,419 +0,0 @@
import threading
import gym
import numpy as np
class DeepMindLabyrinth(object):
ACTION_SET_DEFAULT = (
(0, 0, 0, 1, 0, 0, 0), # Forward
(0, 0, 0, -1, 0, 0, 0), # Backward
(0, 0, -1, 0, 0, 0, 0), # Strafe Left
(0, 0, 1, 0, 0, 0, 0), # Strafe Right
(-20, 0, 0, 0, 0, 0, 0), # Look Left
(20, 0, 0, 0, 0, 0, 0), # Look Right
(-20, 0, 0, 1, 0, 0, 0), # Look Left + Forward
(20, 0, 0, 1, 0, 0, 0), # Look Right + Forward
(0, 0, 0, 0, 1, 0, 0), # Fire
)
ACTION_SET_MEDIUM = (
(0, 0, 0, 1, 0, 0, 0), # Forward
(0, 0, 0, -1, 0, 0, 0), # Backward
(0, 0, -1, 0, 0, 0, 0), # Strafe Left
(0, 0, 1, 0, 0, 0, 0), # Strafe Right
(-20, 0, 0, 0, 0, 0, 0), # Look Left
(20, 0, 0, 0, 0, 0, 0), # Look Right
(0, 0, 0, 0, 0, 0, 0), # Idle.
)
ACTION_SET_SMALL = (
(0, 0, 0, 1, 0, 0, 0), # Forward
(-20, 0, 0, 0, 0, 0, 0), # Look Left
(20, 0, 0, 0, 0, 0, 0), # Look Right
)
def __init__(
self,
level,
mode,
action_repeat=4,
render_size=(64, 64),
action_set=ACTION_SET_DEFAULT,
level_cache=None,
seed=None,
runfiles_path=None,
):
assert mode in ("train", "test")
import deepmind_lab
if runfiles_path:
print("Setting DMLab runfiles path:", runfiles_path)
deepmind_lab.set_runfiles_path(runfiles_path)
self._config = {}
self._config["width"] = render_size[0]
self._config["height"] = render_size[1]
self._config["logLevel"] = "WARN"
if mode == "test":
self._config["allowHoldOutLevels"] = "true"
self._config["mixerSeed"] = 0x600D5EED
self._action_repeat = action_repeat
self._random = np.random.RandomState(seed)
self._env = deepmind_lab.Lab(
level="contributed/dmlab30/" + level,
observations=["RGB_INTERLEAVED"],
config={k: str(v) for k, v in self._config.items()},
level_cache=level_cache,
)
self._action_set = action_set
self._last_image = None
self._done = True
@property
def observation_space(self):
shape = (self._config["height"], self._config["width"], 3)
space = gym.spaces.Box(low=0, high=255, shape=shape, dtype=np.uint8)
return gym.spaces.Dict({"image": space})
@property
def action_space(self):
return gym.spaces.Discrete(len(self._action_set))
def reset(self):
self._done = False
self._env.reset(seed=self._random.randint(0, 2**31 - 1))
obs = self._get_obs()
return obs
def step(self, action):
raw_action = np.array(self._action_set[action], np.intc)
reward = self._env.step(raw_action, num_steps=self._action_repeat)
self._done = not self._env.is_running()
obs = self._get_obs()
return obs, reward, self._done, {}
def render(self, *args, **kwargs):
if kwargs.get("mode", "rgb_array") != "rgb_array":
raise ValueError("Only render mode 'rgb_array' is supported.")
del args # Unused
del kwargs # Unused
return self._last_image
def close(self):
self._env.close()
def _get_obs(self):
if self._done:
image = 0 * self._last_image
else:
image = self._env.observations()["RGB_INTERLEAVED"]
self._last_image = image
return {"image": image}
class DeepMindControl:
def __init__(self, name, action_repeat=1, size=(64, 64), camera=None):
domain, task = name.split("_", 1)
if domain == "cup": # Only domain with multiple words.
domain = "ball_in_cup"
if isinstance(domain, str):
from dm_control import suite
self._env = suite.load(domain, task)
else:
assert task is None
self._env = domain()
self._action_repeat = action_repeat
self._size = size
if camera is None:
camera = dict(quadruped=2).get(domain, 0)
self._camera = camera
@property
def observation_space(self):
spaces = {}
for key, value in self._env.observation_spec().items():
spaces[key] = gym.spaces.Box(-np.inf, np.inf, value.shape, dtype=np.float32)
spaces["image"] = gym.spaces.Box(0, 255, self._size + (3,), dtype=np.uint8)
return gym.spaces.Dict(spaces)
@property
def action_space(self):
spec = self._env.action_spec()
return gym.spaces.Box(spec.minimum, spec.maximum, dtype=np.float32)
def step(self, action):
assert np.isfinite(action).all(), action
reward = 0
for _ in range(self._action_repeat):
time_step = self._env.step(action)
reward += time_step.reward or 0
if time_step.last():
break
obs = dict(time_step.observation)
obs["image"] = self.render()
done = time_step.last()
info = {"discount": np.array(time_step.discount, np.float32)}
return obs, reward, done, info
def reset(self):
time_step = self._env.reset()
obs = dict(time_step.observation)
obs["image"] = self.render()
return obs
def render(self, *args, **kwargs):
if kwargs.get("mode", "rgb_array") != "rgb_array":
raise ValueError("Only render mode 'rgb_array' is supported.")
return self._env.physics.render(*self._size, camera_id=self._camera)
class Atari:
LOCK = threading.Lock()
def __init__(
self,
name,
action_repeat=4,
size=(84, 84),
grayscale=True,
noops=30,
life_done=False,
sticky_actions=True,
all_actions=False,
):
assert size[0] == size[1]
import gym.wrappers
import gym.envs.atari
if name == "james_bond":
name = "jamesbond"
with self.LOCK:
env = gym.envs.atari.AtariEnv(
game=name,
obs_type="image",
frameskip=1,
repeat_action_probability=0.25 if sticky_actions else 0.0,
full_action_space=all_actions,
)
# Avoid unnecessary rendering in inner env.
env._get_obs = lambda: None
# Tell wrapper that the inner env has no action repeat.
env.spec = gym.envs.registration.EnvSpec("NoFrameskip-v0")
env = gym.wrappers.AtariPreprocessing(
env, noops, action_repeat, size[0], life_done, grayscale
)
self._env = env
self._grayscale = grayscale
@property
def observation_space(self):
return gym.spaces.Dict(
{
"image": self._env.observation_space,
"ram": gym.spaces.Box(0, 255, (128,), np.uint8),
}
)
@property
def action_space(self):
return self._env.action_space
def close(self):
return self._env.close()
def reset(self):
with self.LOCK:
image = self._env.reset()
if self._grayscale:
image = image[..., None]
obs = {"image": image, "ram": self._env.env._get_ram()}
return obs
def step(self, action):
image, reward, done, info = self._env.step(action)
if self._grayscale:
image = image[..., None]
obs = {"image": image, "ram": self._env.env._get_ram()}
return obs, reward, done, info
def render(self, mode):
return self._env.render(mode)
class CollectDataset:
def __init__(self, env, callbacks=None, precision=32):
self._env = env
self._callbacks = callbacks or ()
self._precision = precision
self._episode = None
def __getattr__(self, name):
return getattr(self._env, name)
def step(self, action):
obs, reward, done, info = self._env.step(action)
obs = {k: self._convert(v) for k, v in obs.items()}
transition = obs.copy()
if isinstance(action, dict):
transition.update(action)
else:
transition["action"] = action
transition["reward"] = reward
transition["discount"] = info.get("discount", np.array(1 - float(done)))
self._episode.append(transition)
if done:
for key, value in self._episode[1].items():
if key not in self._episode[0]:
self._episode[0][key] = 0 * value
episode = {k: [t[k] for t in self._episode] for k in self._episode[0]}
episode = {k: self._convert(v) for k, v in episode.items()}
info["episode"] = episode
for callback in self._callbacks:
callback(episode)
return obs, reward, done, info
def reset(self):
obs = self._env.reset()
transition = obs.copy()
# Missing keys will be filled with a zeroed out version of the first
# transition, because we do not know what action information the agent will
# pass yet.
transition["reward"] = 0.0
transition["discount"] = 1.0
self._episode = [transition]
return obs
def _convert(self, value):
value = np.array(value)
if np.issubdtype(value.dtype, np.floating):
dtype = {16: np.float16, 32: np.float32, 64: np.float64}[self._precision]
elif np.issubdtype(value.dtype, np.signedinteger):
dtype = {16: np.int16, 32: np.int32, 64: np.int64}[self._precision]
elif np.issubdtype(value.dtype, np.uint8):
dtype = np.uint8
else:
raise NotImplementedError(value.dtype)
return value.astype(dtype)
class TimeLimit:
def __init__(self, env, duration):
self._env = env
self._duration = duration
self._step = None
def __getattr__(self, name):
return getattr(self._env, name)
def step(self, action):
assert self._step is not None, "Must reset environment."
obs, reward, done, info = self._env.step(action)
self._step += 1
if self._step >= self._duration:
done = True
if "discount" not in info:
info["discount"] = np.array(1.0).astype(np.float32)
self._step = None
return obs, reward, done, info
def reset(self):
self._step = 0
return self._env.reset()
class NormalizeActions:
def __init__(self, env):
self._env = env
self._mask = np.logical_and(
np.isfinite(env.action_space.low), np.isfinite(env.action_space.high)
)
self._low = np.where(self._mask, env.action_space.low, -1)
self._high = np.where(self._mask, env.action_space.high, 1)
def __getattr__(self, name):
return getattr(self._env, name)
@property
def action_space(self):
low = np.where(self._mask, -np.ones_like(self._low), self._low)
high = np.where(self._mask, np.ones_like(self._low), self._high)
return gym.spaces.Box(low, high, dtype=np.float32)
def step(self, action):
original = (action + 1) / 2 * (self._high - self._low) + self._low
original = np.where(self._mask, original, action)
return self._env.step(original)
class OneHotAction:
def __init__(self, env):
assert isinstance(env.action_space, gym.spaces.Discrete)
self._env = env
self._random = np.random.RandomState()
def __getattr__(self, name):
return getattr(self._env, name)
@property
def action_space(self):
shape = (self._env.action_space.n,)
space = gym.spaces.Box(low=0, high=1, shape=shape, dtype=np.float32)
space.sample = self._sample_action
space.discrete = True
return space
def step(self, action):
index = np.argmax(action).astype(int)
reference = np.zeros_like(action)
reference[index] = 1
if not np.allclose(reference, action):
raise ValueError(f"Invalid one-hot action:\n{action}")
return self._env.step(index)
def reset(self):
return self._env.reset()
def _sample_action(self):
actions = self._env.action_space.n
index = self._random.randint(0, actions)
reference = np.zeros(actions, dtype=np.float32)
reference[index] = 1.0
return reference
class RewardObs:
def __init__(self, env):
self._env = env
def __getattr__(self, name):
return getattr(self._env, name)
@property
def observation_space(self):
spaces = self._env.observation_space.spaces
assert "reward" not in spaces
spaces["reward"] = gym.spaces.Box(-np.inf, np.inf, dtype=np.float32)
return gym.spaces.Dict(spaces)
def step(self, action):
obs, reward, done, info = self._env.step(action)
obs["reward"] = reward
return obs, reward, done, info
def reset(self):
obs = self._env.reset()
obs["reward"] = 0.0
return obs
class SelectAction:
def __init__(self, env, key):
self._env = env
self._key = key
def __getattr__(self, name):
return getattr(self._env, name)
def step(self, action):
return self._env.step(action[self._key])