diff --git a/envs/wrappers.py b/envs/wrappers.py index 80f4f19..af52602 100644 --- a/envs/wrappers.py +++ b/envs/wrappers.py @@ -4,18 +4,15 @@ import numpy as np import uuid -class TimeLimit: +class TimeLimit(gym.Wrapper): def __init__(self, env, duration): - self._env = env + super().__init__(env) self._duration = duration self._step = None - def __getattr__(self, name): - return getattr(self._env, name) - def step(self, action): assert self._step is not None, "Must reset environment." - obs, reward, done, info = self._env.step(action) + obs, reward, done, info = self.env.step(action) self._step += 1 if self._step >= self._duration: done = True @@ -26,22 +23,18 @@ class TimeLimit: def reset(self): self._step = 0 - return self._env.reset() + return self.env.reset() -class NormalizeActions: +class NormalizeActions(gym.Wrapper): def __init__(self, env): - self._env = env + super().__init__(env) self._mask = np.logical_and( np.isfinite(env.action_space.low), np.isfinite(env.action_space.high) ) self._low = np.where(self._mask, env.action_space.low, -1) self._high = np.where(self._mask, env.action_space.high, 1) - def __getattr__(self, name): - return getattr(self._env, name) - - @property def action_space(self): low = np.where(self._mask, -np.ones_like(self._low), self._low) high = np.where(self._mask, np.ones_like(self._low), self._high) @@ -50,21 +43,18 @@ class NormalizeActions: def step(self, action): original = (action + 1) / 2 * (self._high - self._low) + self._low original = np.where(self._mask, original, action) - return self._env.step(original) + return self.env.step(original) -class OneHotAction: +class OneHotAction(gym.Wrapper): def __init__(self, env): assert isinstance(env.action_space, gym.spaces.Discrete) - self._env = env + super().__init__(env) self._random = np.random.RandomState() - def __getattr__(self, name): - return getattr(self._env, name) - @property def action_space(self): - shape = (self._env.action_space.n,) + shape = (self.env.action_space.n,) space = gym.spaces.Box(low=0, high=1, shape=shape, dtype=np.float32) space.sample = self._sample_action space.discrete = True @@ -76,29 +66,26 @@ class OneHotAction: reference[index] = 1 if not np.allclose(reference, action): raise ValueError(f"Invalid one-hot action:\n{action}") - return self._env.step(index) + return self.env.step(index) def reset(self): - return self._env.reset() + return self.env.reset() def _sample_action(self): - actions = self._env.action_space.n + actions = self.env.action_space.n index = self._random.randint(0, actions) reference = np.zeros(actions, dtype=np.float32) reference[index] = 1.0 return reference -class RewardObs: +class RewardObs(gym.Wrapper): def __init__(self, env): - self._env = env + super().__init__(env) - def __getattr__(self, name): - return getattr(self._env, name) - @property def observation_space(self): - spaces = self._env.observation_space.spaces + spaces = self.env.observation_space.spaces if "reward" not in spaces: spaces["reward"] = gym.spaces.Box( -np.inf, np.inf, shape=(1,), dtype=np.float32 @@ -106,39 +93,35 @@ class RewardObs: return gym.spaces.Dict(spaces) def step(self, action): - obs, reward, done, info = self._env.step(action) + obs, reward, done, info = self.env.step(action) if "reward" not in obs: obs["reward"] = reward return obs, reward, done, info def reset(self): - obs = self._env.reset() + obs = self.env.reset() if "reward" not in obs: obs["reward"] = 0.0 return obs -class SelectAction: +class SelectAction(gym.Wrapper): def __init__(self, env, key): - self._env = env + super().__init__(env) self._key = key - def __getattr__(self, name): - return getattr(self._env, name) def step(self, action): - return self._env.step(action[self._key]) + return self.env.step(action[self._key]) -class UUID: +class UUID(gym.Wrapper): def __init__(self, env): - self._env = env + super().__init__(env) timestamp = datetime.datetime.now().strftime("%Y%m%dT%H%M%S") self.id = f"{timestamp}-{str(uuid.uuid4().hex)}" - def __getattr__(self, name): - return getattr(self._env, name) def reset(self): timestamp = datetime.datetime.now().strftime("%Y%m%dT%H%M%S") self.id = f"{timestamp}-{str(uuid.uuid4().hex)}" - return self._env.reset() + return self.env.reset()