dreamerv3-torch/envs/wrappers.py
2023-08-05 21:08:01 +09:00

118 lines
3.7 KiB
Python

import datetime
import gym
import numpy as np
import uuid
class TimeLimit(gym.Wrapper):
    """End an episode after a fixed number of steps.

    The wrapper counts calls to ``step`` since the last ``reset``. When the
    count reaches ``duration`` the step is reported with ``done=True`` and the
    counter is cleared, so ``reset`` must be called before stepping again.
    """

    def __init__(self, env, duration):
        """Wrap ``env`` and limit its episodes to ``duration`` steps.

        Args:
            env: Environment to wrap.
            duration: Step count at which ``done`` is forced to True.
        """
        super().__init__(env)
        self._duration = duration
        # None marks "no active episode": either reset() has not been called
        # yet or the time limit ended the previous episode.
        self._step = None

    def step(self, action):
        """Step the wrapped environment and enforce the time limit.

        Args:
            action: Action forwarded unchanged to the wrapped environment.

        Returns:
            The ``(obs, reward, done, info)`` tuple of the wrapped
            environment. On the step that reaches the limit, ``done`` is
            forced to True and ``info["discount"]`` is set to a float32 1.0
            unless the environment already supplied a "discount" entry.

        Raises:
            AssertionError: If there is no active episode (``reset`` was not
                called, or the time limit already ended the episode).
        """
        if self._step is None:
            # Raised explicitly instead of via `assert` so the guard is not
            # stripped under `python -O`; without it the wrapped env would be
            # stepped before failing on `None += 1`. AssertionError is kept
            # so callers observe the same exception type as before.
            raise AssertionError("Must reset environment.")
        obs, reward, done, info = self.env.step(action)
        self._step += 1
        if self._step >= self._duration:
            done = True
            if "discount" not in info:
                info["discount"] = np.array(1.0).astype(np.float32)
            # Require a reset() before the next step().
            self._step = None
        return obs, reward, done, info

    def reset(self):
        """Restart the step counter and reset the wrapped environment.

        Returns:
            Whatever the wrapped environment's ``reset`` returns.
        """
        self._step = 0
        return self.env.reset()
class NormalizeActions(gym.Wrapper):
    """Rescale actions from a [-1, 1] range to the wrapped env's bounds.

    Dimensions whose lower and upper bounds are both finite are linearly
    mapped from [-1, 1] to [low, high]. All other dimensions are forwarded to
    the wrapped environment unchanged. The advertised action space is a
    float32 ``Box``.
    """

    def __init__(self, env):
        """Record the wrapped env's bounds and build the normalized space."""
        super().__init__(env)
        finite_low = np.isfinite(env.action_space.low)
        finite_high = np.isfinite(env.action_space.high)
        # True where the dimension can be rescaled (both bounds finite).
        self._mask = np.logical_and(finite_low, finite_high)
        # Non-finite dimensions get placeholder bounds of -1 / 1.
        self._low = np.where(self._mask, env.action_space.low, -1)
        self._high = np.where(self._mask, env.action_space.high, 1)
        unit = np.ones_like(self._low)
        self.action_space = gym.spaces.Box(
            np.where(self._mask, -unit, self._low),
            np.where(self._mask, unit, self._high),
            dtype=np.float32,
        )

    def step(self, action):
        """Map ``action`` to the wrapped env's scale and step with it."""
        span = self._high - self._low
        rescaled = (action + 1) / 2 * span + self._low
        # Masked-out dimensions keep the caller's raw value.
        return self.env.step(np.where(self._mask, rescaled, action))
class OneHotAction(gym.Wrapper):
    """Accept one-hot vectors as actions for a discrete-action environment.

    The advertised action space is a float32 ``Box`` in [0, 1] with one entry
    per discrete action, tagged with ``discrete = True``.
    """

    def __init__(self, env):
        """Wrap ``env``, whose action space must be ``gym.spaces.Discrete``."""
        assert isinstance(env.action_space, gym.spaces.Discrete)
        super().__init__(env)
        self._random = np.random.RandomState()
        one_hot_space = gym.spaces.Box(
            low=0, high=1, shape=(self.env.action_space.n,), dtype=np.float32
        )
        one_hot_space.discrete = True
        self.action_space = one_hot_space

    def step(self, action):
        """Convert a one-hot ``action`` to its index and step the env.

        Raises:
            ValueError: If ``action`` is not (numerically close to) a one-hot
                vector.
        """
        chosen = np.argmax(action).astype(int)
        # Rebuild the exact one-hot vector for the arg-max entry and compare.
        expected = np.zeros_like(action)
        expected[chosen] = 1
        if np.allclose(expected, action):
            return self.env.step(chosen)
        raise ValueError(f"Invalid one-hot action:\n{action}")

    def reset(self):
        """Reset the wrapped environment and return its result."""
        return self.env.reset()

    def _sample_action(self):
        """Return a random one-hot float32 action from the wrapper's own RNG."""
        num_actions = self.env.action_space.n
        one_hot = np.zeros(num_actions, dtype=np.float32)
        one_hot[self._random.randint(0, num_actions)] = 1.0
        return one_hot
class RewardObs(gym.Wrapper):
    """Expose the last reward as an ``"obs_reward"`` observation entry.

    Observations are expected to be dict-like. If the wrapped environment
    already provides an ``"obs_reward"`` key, it is left untouched.
    """

    def __init__(self, env):
        """Wrap ``env`` and extend its observation space with "obs_reward"."""
        super().__init__(env)
        # Copy the mapping before adding the key: inserting directly into
        # ``observation_space.spaces`` would silently modify the wrapped
        # environment's own observation space. ``.copy()`` keeps the mapping
        # type (dict / OrderedDict) and therefore the key order.
        spaces = self.env.observation_space.spaces.copy()
        if "obs_reward" not in spaces:
            spaces["obs_reward"] = gym.spaces.Box(
                -np.inf, np.inf, shape=(1,), dtype=np.float32
            )
        self.observation_space = gym.spaces.Dict(spaces)

    def step(self, action):
        """Step the env and attach the reward to the observation.

        Returns:
            ``(obs, reward, done, info)`` from the wrapped environment, where
            ``obs["obs_reward"]`` is a float32 array of shape (1,) holding
            ``reward`` unless the key was already present.
        """
        obs, reward, done, info = self.env.step(action)
        if "obs_reward" not in obs:
            obs["obs_reward"] = np.array([reward], dtype=np.float32)
        return obs, reward, done, info

    def reset(self):
        """Reset the env; the initial observation reports a reward of 0.0."""
        obs = self.env.reset()
        if "obs_reward" not in obs:
            obs["obs_reward"] = np.array([0.0], dtype=np.float32)
        return obs
class SelectAction(gym.Wrapper):
    """Forward a single entry of a keyed action to the wrapped environment."""

    def __init__(self, env, key):
        """Wrap ``env``; ``key`` indexes the action passed to ``step``."""
        super().__init__(env)
        self._key = key

    def step(self, action):
        """Step the wrapped environment with ``action[key]``."""
        selected = action[self._key]
        return self.env.step(selected)
class UUID(gym.Wrapper):
    """Tag the environment with an ``id`` that is regenerated on every reset.

    The id has the form ``<YYYYmmddTHHMMSS>-<32 hex chars>``, built from the
    current local time and a random UUID4.
    """

    def __init__(self, env):
        """Wrap ``env`` and assign an initial id."""
        super().__init__(env)
        self.id = self._fresh_id()

    def reset(self):
        """Assign a new id, then reset the wrapped environment."""
        self.id = self._fresh_id()
        return self.env.reset()

    @staticmethod
    def _fresh_id():
        """Build a new timestamp-prefixed unique identifier string."""
        stamp = datetime.datetime.now().strftime("%Y%m%dT%H%M%S")
        return f"{stamp}-{uuid.uuid4().hex}"