applied formatter to envs
This commit is contained in:
parent 628b856c63
commit 6f0e6c6963
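The changes below normalize string literals to double quotes, split long signatures and calls onto one argument per line, and tighten spacing across the env modules, which matches the default behavior of the black formatter. A minimal sketch of reproducing one such change with black's Python API, assuming black is the tool used here (the commit message does not name it):

# Hedged sketch: reproduce the quote normalization with black, assuming black
# is the formatter applied in this commit (not stated in the message).
import black

src = "assert lives in ('unused', 'discount', 'reset'), lives\n"
print(black.format_str(src, mode=black.FileMode()))
# -> assert lives in ("unused", "discount", "reset"), lives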
@@ -2,30 +2,44 @@ import numpy as np
class Atari:
LOCK = None
def __init__(
self, name, action_repeat=4, size=(84, 84), gray=True, noops=0, lives='unused',
sticky=True, actions='all', length=108000, resize='opencv', seed=None):
self,
name,
action_repeat=4,
size=(84, 84),
gray=True,
noops=0,
lives="unused",
sticky=True,
actions="all",
length=108000,
resize="opencv",
seed=None,
):
assert size[0] == size[1]
assert lives in ('unused', 'discount', 'reset'), lives
assert actions in ('all', 'needed'), actions
assert resize in ('opencv', 'pillow'), resize
assert lives in ("unused", "discount", "reset"), lives
assert actions in ("all", "needed"), actions
assert resize in ("opencv", "pillow"), resize
if self.LOCK is None:
import multiprocessing as mp
mp = mp.get_context('spawn')
mp = mp.get_context("spawn")
self.LOCK = mp.Lock()
self._resize = resize
if self._resize == 'opencv':
if self._resize == "opencv":
import cv2
self._cv2 = cv2
if self._resize == 'pillow':
if self._resize == "pillow":
from PIL import Image
self._image = Image
import gym.envs.atari
if name == 'james_bond':
name = 'jamesbond'
if name == "james_bond":
name = "jamesbond"
self._repeat = action_repeat
self._size = size
self._gray = gray
@@ -37,10 +51,12 @@ class Atari:
with self.LOCK:
self._env = gym.envs.atari.AtariEnv(
game=name,
obs_type='image',
frameskip=1, repeat_action_probability=0.25 if sticky else 0.0,
full_action_space=(actions == 'all'))
assert self._env.unwrapped.get_action_meanings()[0] == 'NOOP'
obs_type="image",
frameskip=1,
repeat_action_probability=0.25 if sticky else 0.0,
full_action_space=(actions == "all"),
)
assert self._env.unwrapped.get_action_meanings()[0] == "NOOP"
shape = self._env.observation_space.shape
self._buffer = [np.zeros(shape, np.uint8) for _ in range(2)]
self._ale = self._env.unwrapped.ale
@@ -73,7 +89,7 @@ class Atari:
self._screen(self._buffer[1])
if over:
break
if self._lives != 'unused':
if self._lives != "unused":
current = self._ale.lives()
if current < self._last_lives:
dead = True
@@ -85,8 +101,9 @@ class Atari:
self._done = over or (self._length and self._step >= self._length) or dead
return self._obs(
total,
is_last=self._done or (dead and self._lives == 'reset'),
is_terminal=dead or over)
is_last=self._done or (dead and self._lives == "reset"),
is_terminal=dead or over,
)
def reset(self):
self._env.reset()
@@ -108,10 +125,11 @@ class Atari:
np.maximum(self._buffer[0], self._buffer[1], out=self._buffer[0])
image = self._buffer[0]
if image.shape[:2] != self._size:
if self._resize == 'opencv':
if self._resize == "opencv":
image = self._cv2.resize(
image, self._size, interpolation=self._cv2.INTER_AREA)
if self._resize == 'pillow':
image, self._size, interpolation=self._cv2.INTER_AREA
)
if self._resize == "pillow":
image = self._image.fromarray(image)
image = image.resize(self._size, self._image.NEAREST)
image = np.array(image)
@@ -119,7 +137,7 @@ class Atari:
weights = [0.299, 0.587, 1 - (0.299 + 0.587)]
image = np.tensordot(image, weights, (-1, 0)).astype(image.dtype)
image = image[:, :, None]
return {'image':image, 'is_terminal':is_terminal}, reward, is_last, {}
return {"image": image, "is_terminal": is_terminal}, reward, is_last, {}
def _screen(self, array):
self._ale.getScreenRGB2(array)
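For context, a minimal usage sketch of the Atari wrapper shown above; the constructor defaults mirror the formatted signature, while the game name, the fixed action, and the import path are illustrative assumptions, not part of the commit:

# Hedged sketch; assumes the class above lives in envs/atari.py and that
# action 0 is NOOP, as asserted in __init__. "pong" is a placeholder game.
from envs.atari import Atari

env = Atari("pong", action_repeat=4, size=(84, 84), gray=True, lives="unused")
env.reset()
obs, reward, is_last, info = env.step(0)
image = obs["image"]            # (84, 84, 1) uint8 frame when gray=True
terminal = obs["is_terminal"]   # termination flag produced by the wrapper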
envs/dmc.py
@@ -3,13 +3,13 @@ import numpy as np
class DeepMindControl:
def __init__(self, name, action_repeat=1, size=(64, 64), camera=None):
domain, task = name.split('_', 1)
if domain == 'cup': # Only domain with multiple words.
domain = 'ball_in_cup'
domain, task = name.split("_", 1)
if domain == "cup": # Only domain with multiple words.
domain = "ball_in_cup"
if isinstance(domain, str):
from dm_control import suite
self._env = suite.load(domain, task)
else:
assert task is None
@@ -24,10 +24,8 @@ class DeepMindControl:
def observation_space(self):
spaces = {}
for key, value in self._env.observation_spec().items():
spaces[key] = gym.spaces.Box(
-np.inf, np.inf, value.shape, dtype=np.float32)
spaces['image'] = gym.spaces.Box(
0, 255, self._size + (3,), dtype=np.uint8)
spaces[key] = gym.spaces.Box(-np.inf, np.inf, value.shape, dtype=np.float32)
spaces["image"] = gym.spaces.Box(0, 255, self._size + (3,), dtype=np.uint8)
return gym.spaces.Dict(spaces)
@property
@@ -44,21 +42,21 @@ class DeepMindControl:
if time_step.last():
break
obs = dict(time_step.observation)
obs['image'] = self.render()
obs["image"] = self.render()
# There is no terminal state in DMC
obs['is_terminal'] = False
obs["is_terminal"] = False
done = time_step.last()
info = {'discount': np.array(time_step.discount, np.float32)}
info = {"discount": np.array(time_step.discount, np.float32)}
return obs, reward, done, info
def reset(self):
time_step = self._env.reset()
obs = dict(time_step.observation)
obs['image'] = self.render()
obs['is_terminal'] = False
obs["image"] = self.render()
obs["is_terminal"] = False
return obs
def render(self, *args, **kwargs):
if kwargs.get('mode', 'rgb_array') != 'rgb_array':
if kwargs.get("mode", "rgb_array") != "rgb_array":
raise ValueError("Only render mode 'rgb_array' is supported.")
return self._env.physics.render(*self._size, camera_id=self._camera)
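A similar hedged sketch for the DeepMindControl wrapper above; the task name is an example, and the import path is assumed from the envs/dmc.py header:

# Hedged sketch; "walker_walk" is an example dm_control task, not taken from
# the commit, and action_space is assumed to be exposed by the wrapper.
from envs.dmc import DeepMindControl

env = DeepMindControl("walker_walk", action_repeat=2, size=(64, 64))
obs = env.reset()                            # proprioceptive keys plus "image"
action = env.action_space.sample()
obs, reward, done, info = env.step(action)
print(obs["image"].shape, info["discount"])  # (64, 64, 3), dm_control discount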
@@ -4,7 +4,6 @@ import deepmind_lab
class DeepMindLabyrinth(object):
ACTION_SET_DEFAULT = (
(0, 0, 0, 1, 0, 0, 0), # Forward
(0, 0, 0, -1, 0, 0, 0), # Backward
@@ -34,36 +33,44 @@ class DeepMindLabyrinth(object):
)
def __init__(
self, level, mode, action_repeat=4, render_size=(64, 64),
action_set=ACTION_SET_DEFAULT, level_cache=None, seed=None,
runfiles_path=None):
assert mode in ('train', 'test')
self,
level,
mode,
action_repeat=4,
render_size=(64, 64),
action_set=ACTION_SET_DEFAULT,
level_cache=None,
seed=None,
runfiles_path=None,
):
assert mode in ("train", "test")
if runfiles_path:
print('Setting DMLab runfiles path:', runfiles_path)
print("Setting DMLab runfiles path:", runfiles_path)
deepmind_lab.set_runfiles_path(runfiles_path)
self._config = {}
self._config['width'] = render_size[0]
self._config['height'] = render_size[1]
self._config['logLevel'] = 'WARN'
if mode == 'test':
self._config['allowHoldOutLevels'] = 'true'
self._config['mixerSeed'] = 0x600D5EED
self._config["width"] = render_size[0]
self._config["height"] = render_size[1]
self._config["logLevel"] = "WARN"
if mode == "test":
self._config["allowHoldOutLevels"] = "true"
self._config["mixerSeed"] = 0x600D5EED
self._action_repeat = action_repeat
self._random = np.random.RandomState(seed)
self._env = deepmind_lab.Lab(
level='contributed/dmlab30/'+level,
observations=['RGB_INTERLEAVED'],
level="contributed/dmlab30/" + level,
observations=["RGB_INTERLEAVED"],
config={k: str(v) for k, v in self._config.items()},
level_cache=level_cache)
level_cache=level_cache,
)
self._action_set = action_set
self._last_image = None
self._done = True
@property
def observation_space(self):
shape = (self._config['height'], self._config['width'], 3)
shape = (self._config["height"], self._config["width"], 3)
space = gym.spaces.Box(low=0, high=255, shape=shape, dtype=np.uint8)
return gym.spaces.Dict({'image': space})
return gym.spaces.Dict({"image": space})
@property
def action_space(self):
@@ -71,7 +78,7 @@ class DeepMindLabyrinth(object):
def reset(self):
self._done = False
self._env.reset(seed=self._random.randint(0, 2 ** 31 - 1))
self._env.reset(seed=self._random.randint(0, 2**31 - 1))
obs = self._get_obs()
return obs
@@ -83,7 +90,7 @@ class DeepMindLabyrinth(object):
return obs, reward, self._done, {}
def render(self, *args, **kwargs):
if kwargs.get('mode', 'rgb_array') != 'rgb_array':
if kwargs.get("mode", "rgb_array") != "rgb_array":
raise ValueError("Only render mode 'rgb_array' is supported.")
del args # Unused
del kwargs # Unused
@@ -96,6 +103,6 @@ class DeepMindLabyrinth(object):
if self._done:
image = 0 * self._last_image
else:
image = self._env.observations()['RGB_INTERLEAVED']
image = self._env.observations()["RGB_INTERLEAVED"]
self._last_image = image
return {'image': image}
return {"image": image}
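Likewise for the DeepMindLabyrinth wrapper above, a hedged construction sketch; the level name is an example DMLab-30 level and is not part of the commit:

# Hedged sketch; assumes deepmind_lab is installed and the class above is
# importable. The level name is illustrative only.
env = DeepMindLabyrinth(
    "rooms_collect_good_objects_train",
    mode="train",
    action_repeat=4,
    render_size=(64, 64),
)
obs = env.reset()   # {"image": (64, 64, 3) uint8 array}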
@@ -3,7 +3,6 @@ import numpy as np
class CollectDataset:
def __init__(self, env, callbacks=None, precision=32):
self._env = env
self._callbacks = callbacks or ()
@@ -20,9 +19,9 @@ class CollectDataset:
if isinstance(action, dict):
transition.update(action)
else:
transition['action'] = action
transition['reward'] = reward
transition['discount'] = info.get('discount', np.array(1 - float(done)))
transition["action"] = action
transition["reward"] = reward
transition["discount"] = info.get("discount", np.array(1 - float(done)))
self._episode.append(transition)
if done:
for key, value in self._episode[1].items():
@@ -30,7 +29,7 @@ class CollectDataset:
self._episode[0][key] = 0 * value
episode = {k: [t[k] for t in self._episode] for k in self._episode[0]}
episode = {k: self._convert(v) for k, v in episode.items()}
info['episode'] = episode
info["episode"] = episode
for callback in self._callbacks:
callback(episode)
return obs, reward, done, info
@@ -41,8 +40,8 @@ class CollectDataset:
# Missing keys will be filled with a zeroed out version of the first
# transition, because we do not know what action information the agent will
# pass yet.
transition['reward'] = 0.0
transition['discount'] = 1.0
transition["reward"] = 0.0
transition["discount"] = 1.0
self._episode = [transition]
return obs
@@ -62,7 +61,6 @@ class CollectDataset:
class TimeLimit:
def __init__(self, env, duration):
self._env = env
self._duration = duration
@@ -72,13 +70,13 @@ class TimeLimit:
return getattr(self._env, name)
def step(self, action):
assert self._step is not None, 'Must reset environment.'
assert self._step is not None, "Must reset environment."
obs, reward, done, info = self._env.step(action)
self._step += 1
if self._step >= self._duration:
done = True
if 'discount' not in info:
info['discount'] = np.array(1.0).astype(np.float32)
if "discount" not in info:
info["discount"] = np.array(1.0).astype(np.float32)
self._step = None
return obs, reward, done, info
@@ -88,12 +86,11 @@ class TimeLimit:
class NormalizeActions:
def __init__(self, env):
self._env = env
self._mask = np.logical_and(
np.isfinite(env.action_space.low),
np.isfinite(env.action_space.high))
np.isfinite(env.action_space.low), np.isfinite(env.action_space.high)
)
self._low = np.where(self._mask, env.action_space.low, -1)
self._high = np.where(self._mask, env.action_space.high, 1)
@@ -113,7 +110,6 @@ class NormalizeActions:
class OneHotAction:
def __init__(self, env):
assert isinstance(env.action_space, gym.spaces.Discrete)
self._env = env
@@ -135,7 +131,7 @@ class OneHotAction:
reference = np.zeros_like(action)
reference[index] = 1
if not np.allclose(reference, action):
raise ValueError(f'Invalid one-hot action:\n{action}')
raise ValueError(f"Invalid one-hot action:\n{action}")
return self._env.step(index)
def reset(self):
@@ -150,7 +146,6 @@ class OneHotAction:
class RewardObs:
def __init__(self, env):
self._env = env
@@ -160,23 +155,22 @@ class RewardObs:
@property
def observation_space(self):
spaces = self._env.observation_space.spaces
assert 'reward' not in spaces
spaces['reward'] = gym.spaces.Box(-np.inf, np.inf, dtype=np.float32)
assert "reward" not in spaces
spaces["reward"] = gym.spaces.Box(-np.inf, np.inf, dtype=np.float32)
return gym.spaces.Dict(spaces)
def step(self, action):
obs, reward, done, info = self._env.step(action)
obs['reward'] = reward
obs["reward"] = reward
return obs, reward, done, info
def reset(self):
obs = self._env.reset()
obs['reward'] = 0.0
obs["reward"] = 0.0
return obs
class SelectAction:
def __init__(self, env, key):
self._env = env
self._key = key
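The wrappers in this last file compose around a base environment; a hedged sketch of one possible chain (the ordering, duration, and callback are illustrative, not from the commit):

# Hedged sketch; Atari refers to the class from the first file above, and the
# Atari wrapper is assumed to expose a gym Discrete action_space.
def report(episode):
    print("episode length:", len(episode["reward"]))  # callback run at episode end

env = Atari("pong", action_repeat=4, size=(84, 84))
env = OneHotAction(env)                        # discrete actions as one-hot vectors
env = TimeLimit(env, duration=27000)           # hard cap on env steps per episode
env = CollectDataset(env, callbacks=[report])  # gather transitions into episodes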