# dreamerv3-torch/envs/wrappers.py

import gym
import numpy as np


class CollectDataset:
    """Accumulates transitions into an episode dict and invokes the given
    callbacks with the completed episode once `done` is reached."""

    def __init__(self, env, callbacks=None, precision=32):
        self._env = env
        self._callbacks = callbacks or ()
        self._precision = precision
        self._episode = None

    def __getattr__(self, name):
        return getattr(self._env, name)

    def step(self, action):
        obs, reward, done, info = self._env.step(action)
        obs = {k: self._convert(v) for k, v in obs.items()}
        transition = obs.copy()
        if isinstance(action, dict):
            transition.update(action)
        else:
            transition["action"] = action
        transition["reward"] = reward
        transition["discount"] = info.get("discount", np.array(1 - float(done)))
        self._episode.append(transition)
        if done:
            for key, value in self._episode[1].items():
                if key not in self._episode[0]:
                    self._episode[0][key] = 0 * value
            episode = {k: [t[k] for t in self._episode] for k in self._episode[0]}
            episode = {k: self._convert(v) for k, v in episode.items()}
            info["episode"] = episode
            for callback in self._callbacks:
                callback(episode)
        return obs, reward, done, info

    def reset(self):
        obs = self._env.reset()
        transition = obs.copy()
        # Missing keys will be filled with a zeroed-out version of the first
        # transition, because we do not know what action information the
        # agent will pass yet.
        transition["reward"] = 0.0
        transition["discount"] = 1.0
        self._episode = [transition]
        return obs

    def _convert(self, value):
        value = np.array(value)
        if np.issubdtype(value.dtype, np.floating):
            dtype = {16: np.float16, 32: np.float32, 64: np.float64}[self._precision]
        elif np.issubdtype(value.dtype, np.signedinteger):
            dtype = {16: np.int16, 32: np.int32, 64: np.int64}[self._precision]
        elif np.issubdtype(value.dtype, np.uint8):
            dtype = np.uint8
        elif np.issubdtype(value.dtype, np.bool_):
            dtype = np.bool_
        else:
            raise NotImplementedError(value.dtype)
        return value.astype(dtype)
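
# Usage sketch (illustrative only; `make_env` and the callback are hypothetical
# stand-ins for whatever the training script wires up):
#
#     def on_episode(episode):
#         print("episode length:", len(episode["reward"]))
#
#     env = CollectDataset(make_env(), callbacks=[on_episode])
#     obs = env.reset()
#     done = False
#     while not done:
#         obs, reward, done, info = env.step(env.action_space.sample())
#     # info["episode"] now holds one stacked array per key, converted to the
#     # requested precision, covering the reset transition plus every step.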


class TimeLimit:
    """Forces `done` after a fixed number of steps per episode."""

    def __init__(self, env, duration):
        self._env = env
        self._duration = duration
        self._step = None

    def __getattr__(self, name):
        return getattr(self._env, name)

    def step(self, action):
        assert self._step is not None, "Must reset environment."
        obs, reward, done, info = self._env.step(action)
        self._step += 1
        if self._step >= self._duration:
            done = True
            if "discount" not in info:
                info["discount"] = np.array(1.0).astype(np.float32)
            self._step = None
        return obs, reward, done, info

    def reset(self):
        self._step = 0
        return self._env.reset()
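
# Usage sketch (assumes a gym-style env returning (obs, reward, done, info);
# `make_env` is a hypothetical constructor):
#
#     env = TimeLimit(make_env(), duration=1000)
#     obs = env.reset()
#     # After 1000 steps, done is forced to True and info["discount"] is set
#     # to 1.0 so the timeout is not mistaken for a true terminal state.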


class NormalizeActions:
    """Rescales the bounded dimensions of a continuous action space to [-1, 1]."""

    def __init__(self, env):
        self._env = env
        self._mask = np.logical_and(
            np.isfinite(env.action_space.low), np.isfinite(env.action_space.high)
        )
        self._low = np.where(self._mask, env.action_space.low, -1)
        self._high = np.where(self._mask, env.action_space.high, 1)

    def __getattr__(self, name):
        return getattr(self._env, name)

    @property
    def action_space(self):
        low = np.where(self._mask, -np.ones_like(self._low), self._low)
        high = np.where(self._mask, np.ones_like(self._low), self._high)
        return gym.spaces.Box(low, high, dtype=np.float32)

    def step(self, action):
        original = (action + 1) / 2 * (self._high - self._low) + self._low
        original = np.where(self._mask, original, action)
        return self._env.step(original)
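
# Worked example of the rescaling above (illustrative numbers): for a bounded
# dimension with low=0.0 and high=4.0, agent actions of -1, 0, and +1 map to
# 0.0, 2.0, and 4.0 via (action + 1) / 2 * (high - low) + low. Dimensions with
# infinite bounds are excluded by the mask and passed through unchanged.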


class OneHotAction:
    """Exposes a discrete action space as one-hot vectors and converts
    incoming one-hot actions back to integer indices."""

    def __init__(self, env):
        assert isinstance(env.action_space, gym.spaces.Discrete)
        self._env = env
        self._random = np.random.RandomState()

    def __getattr__(self, name):
        return getattr(self._env, name)

    @property
    def action_space(self):
        shape = (self._env.action_space.n,)
        space = gym.spaces.Box(low=0, high=1, shape=shape, dtype=np.float32)
        space.sample = self._sample_action
        space.discrete = True
        return space

    def step(self, action):
        index = np.argmax(action).astype(int)
        reference = np.zeros_like(action)
        reference[index] = 1
        if not np.allclose(reference, action):
            raise ValueError(f"Invalid one-hot action:\n{action}")
        return self._env.step(index)

    def reset(self):
        return self._env.reset()

    def _sample_action(self):
        actions = self._env.action_space.n
        index = self._random.randint(0, actions)
        reference = np.zeros(actions, dtype=np.float32)
        reference[index] = 1.0
        return reference
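
# Usage sketch (the env name is illustrative; any gym env with a Discrete
# action space works):
#
#     env = OneHotAction(gym.make("CartPole-v1"))
#     action = env.action_space.sample()   # e.g. array([0., 1.], dtype=float32)
#     obs, reward, done, info = env.step(action)  # forwards the argmax index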


class RewardObs:
    """Adds the most recent reward to the observation dict under "reward"."""

    def __init__(self, env):
        self._env = env

    def __getattr__(self, name):
        return getattr(self._env, name)

    @property
    def observation_space(self):
        spaces = self._env.observation_space.spaces
        assert "reward" not in spaces
        spaces["reward"] = gym.spaces.Box(-np.inf, np.inf, dtype=np.float32)
        return gym.spaces.Dict(spaces)

    def step(self, action):
        obs, reward, done, info = self._env.step(action)
        obs["reward"] = reward
        return obs, reward, done, info

    def reset(self):
        obs = self._env.reset()
        obs["reward"] = 0.0
        return obs


class SelectAction:
    """Forwards a single key of a dict-valued action to the wrapped env."""

    def __init__(self, env, key):
        self._env = env
        self._key = key

    def __getattr__(self, name):
        return getattr(self._env, name)

    def step(self, action):
        return self._env.step(action[self._key])
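
# Composition sketch (illustrative; the actual wrapper order and parameters
# are chosen by the training script, not by this module):
#
#     env = make_raw_env()                    # hypothetical env constructor
#     env = NormalizeActions(env)
#     env = TimeLimit(env, duration=1000)
#     env = RewardObs(env)
#     env = SelectAction(env, key="action")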