modified wrappers not to use __getattr__
This commit is contained in:
parent
2c933da684
commit
12e6c68f6b
@ -4,18 +4,15 @@ import numpy as np
|
|||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
|
||||||
class TimeLimit:
|
class TimeLimit(gym.Wrapper):
|
||||||
def __init__(self, env, duration):
|
def __init__(self, env, duration):
|
||||||
self._env = env
|
super().__init__(env)
|
||||||
self._duration = duration
|
self._duration = duration
|
||||||
self._step = None
|
self._step = None
|
||||||
|
|
||||||
def __getattr__(self, name):
|
|
||||||
return getattr(self._env, name)
|
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
assert self._step is not None, "Must reset environment."
|
assert self._step is not None, "Must reset environment."
|
||||||
obs, reward, done, info = self._env.step(action)
|
obs, reward, done, info = self.env.step(action)
|
||||||
self._step += 1
|
self._step += 1
|
||||||
if self._step >= self._duration:
|
if self._step >= self._duration:
|
||||||
done = True
|
done = True
|
||||||
@ -26,22 +23,18 @@ class TimeLimit:
|
|||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
self._step = 0
|
self._step = 0
|
||||||
return self._env.reset()
|
return self.env.reset()
|
||||||
|
|
||||||
|
|
||||||
class NormalizeActions:
|
class NormalizeActions(gym.Wrapper):
|
||||||
def __init__(self, env):
|
def __init__(self, env):
|
||||||
self._env = env
|
super().__init__(env)
|
||||||
self._mask = np.logical_and(
|
self._mask = np.logical_and(
|
||||||
np.isfinite(env.action_space.low), np.isfinite(env.action_space.high)
|
np.isfinite(env.action_space.low), np.isfinite(env.action_space.high)
|
||||||
)
|
)
|
||||||
self._low = np.where(self._mask, env.action_space.low, -1)
|
self._low = np.where(self._mask, env.action_space.low, -1)
|
||||||
self._high = np.where(self._mask, env.action_space.high, 1)
|
self._high = np.where(self._mask, env.action_space.high, 1)
|
||||||
|
|
||||||
def __getattr__(self, name):
|
|
||||||
return getattr(self._env, name)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def action_space(self):
|
def action_space(self):
|
||||||
low = np.where(self._mask, -np.ones_like(self._low), self._low)
|
low = np.where(self._mask, -np.ones_like(self._low), self._low)
|
||||||
high = np.where(self._mask, np.ones_like(self._low), self._high)
|
high = np.where(self._mask, np.ones_like(self._low), self._high)
|
||||||
@ -50,21 +43,18 @@ class NormalizeActions:
|
|||||||
def step(self, action):
|
def step(self, action):
|
||||||
original = (action + 1) / 2 * (self._high - self._low) + self._low
|
original = (action + 1) / 2 * (self._high - self._low) + self._low
|
||||||
original = np.where(self._mask, original, action)
|
original = np.where(self._mask, original, action)
|
||||||
return self._env.step(original)
|
return self.env.step(original)
|
||||||
|
|
||||||
|
|
||||||
class OneHotAction:
|
class OneHotAction(gym.Wrapper):
|
||||||
def __init__(self, env):
|
def __init__(self, env):
|
||||||
assert isinstance(env.action_space, gym.spaces.Discrete)
|
assert isinstance(env.action_space, gym.spaces.Discrete)
|
||||||
self._env = env
|
super().__init__(env)
|
||||||
self._random = np.random.RandomState()
|
self._random = np.random.RandomState()
|
||||||
|
|
||||||
def __getattr__(self, name):
|
|
||||||
return getattr(self._env, name)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def action_space(self):
|
def action_space(self):
|
||||||
shape = (self._env.action_space.n,)
|
shape = (self.env.action_space.n,)
|
||||||
space = gym.spaces.Box(low=0, high=1, shape=shape, dtype=np.float32)
|
space = gym.spaces.Box(low=0, high=1, shape=shape, dtype=np.float32)
|
||||||
space.sample = self._sample_action
|
space.sample = self._sample_action
|
||||||
space.discrete = True
|
space.discrete = True
|
||||||
@ -76,29 +66,26 @@ class OneHotAction:
|
|||||||
reference[index] = 1
|
reference[index] = 1
|
||||||
if not np.allclose(reference, action):
|
if not np.allclose(reference, action):
|
||||||
raise ValueError(f"Invalid one-hot action:\n{action}")
|
raise ValueError(f"Invalid one-hot action:\n{action}")
|
||||||
return self._env.step(index)
|
return self.env.step(index)
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
return self._env.reset()
|
return self.env.reset()
|
||||||
|
|
||||||
def _sample_action(self):
|
def _sample_action(self):
|
||||||
actions = self._env.action_space.n
|
actions = self.env.action_space.n
|
||||||
index = self._random.randint(0, actions)
|
index = self._random.randint(0, actions)
|
||||||
reference = np.zeros(actions, dtype=np.float32)
|
reference = np.zeros(actions, dtype=np.float32)
|
||||||
reference[index] = 1.0
|
reference[index] = 1.0
|
||||||
return reference
|
return reference
|
||||||
|
|
||||||
|
|
||||||
class RewardObs:
|
class RewardObs(gym.Wrapper):
|
||||||
def __init__(self, env):
|
def __init__(self, env):
|
||||||
self._env = env
|
super().__init__(env)
|
||||||
|
|
||||||
def __getattr__(self, name):
|
|
||||||
return getattr(self._env, name)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def observation_space(self):
|
def observation_space(self):
|
||||||
spaces = self._env.observation_space.spaces
|
spaces = self.env.observation_space.spaces
|
||||||
if "reward" not in spaces:
|
if "reward" not in spaces:
|
||||||
spaces["reward"] = gym.spaces.Box(
|
spaces["reward"] = gym.spaces.Box(
|
||||||
-np.inf, np.inf, shape=(1,), dtype=np.float32
|
-np.inf, np.inf, shape=(1,), dtype=np.float32
|
||||||
@ -106,39 +93,35 @@ class RewardObs:
|
|||||||
return gym.spaces.Dict(spaces)
|
return gym.spaces.Dict(spaces)
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
obs, reward, done, info = self._env.step(action)
|
obs, reward, done, info = self.env.step(action)
|
||||||
if "reward" not in obs:
|
if "reward" not in obs:
|
||||||
obs["reward"] = reward
|
obs["reward"] = reward
|
||||||
return obs, reward, done, info
|
return obs, reward, done, info
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
obs = self._env.reset()
|
obs = self.env.reset()
|
||||||
if "reward" not in obs:
|
if "reward" not in obs:
|
||||||
obs["reward"] = 0.0
|
obs["reward"] = 0.0
|
||||||
return obs
|
return obs
|
||||||
|
|
||||||
|
|
||||||
class SelectAction:
|
class SelectAction(gym.Wrapper):
|
||||||
def __init__(self, env, key):
|
def __init__(self, env, key):
|
||||||
self._env = env
|
super().__init__(env)
|
||||||
self._key = key
|
self._key = key
|
||||||
|
|
||||||
def __getattr__(self, name):
|
|
||||||
return getattr(self._env, name)
|
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
return self._env.step(action[self._key])
|
return self.env.step(action[self._key])
|
||||||
|
|
||||||
class UUID:
|
class UUID(gym.Wrapper):
|
||||||
def __init__(self, env):
|
def __init__(self, env):
|
||||||
self._env = env
|
super().__init__(env)
|
||||||
timestamp = datetime.datetime.now().strftime("%Y%m%dT%H%M%S")
|
timestamp = datetime.datetime.now().strftime("%Y%m%dT%H%M%S")
|
||||||
self.id = f"{timestamp}-{str(uuid.uuid4().hex)}"
|
self.id = f"{timestamp}-{str(uuid.uuid4().hex)}"
|
||||||
|
|
||||||
def __getattr__(self, name):
|
|
||||||
return getattr(self._env, name)
|
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
timestamp = datetime.datetime.now().strftime("%Y%m%dT%H%M%S")
|
timestamp = datetime.datetime.now().strftime("%Y%m%dT%H%M%S")
|
||||||
self.id = f"{timestamp}-{str(uuid.uuid4().hex)}"
|
self.id = f"{timestamp}-{str(uuid.uuid4().hex)}"
|
||||||
return self._env.reset()
|
return self.env.reset()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user