modified wrappers
This commit is contained in:
parent
3f6659d365
commit
a6ad132198
@ -34,11 +34,9 @@ class NormalizeActions(gym.Wrapper):
|
||||
)
|
||||
self._low = np.where(self._mask, env.action_space.low, -1)
|
||||
self._high = np.where(self._mask, env.action_space.high, 1)
|
||||
|
||||
def action_space(self):
|
||||
low = np.where(self._mask, -np.ones_like(self._low), self._low)
|
||||
high = np.where(self._mask, np.ones_like(self._low), self._high)
|
||||
return gym.spaces.Box(low, high, dtype=np.float32)
|
||||
self.action_space = gym.spaces.Box(low, high, dtype=np.float32)
|
||||
|
||||
def step(self, action):
|
||||
original = (action + 1) / 2 * (self._high - self._low) + self._low
|
||||
@ -51,13 +49,10 @@ class OneHotAction(gym.Wrapper):
|
||||
assert isinstance(env.action_space, gym.spaces.Discrete)
|
||||
super().__init__(env)
|
||||
self._random = np.random.RandomState()
|
||||
|
||||
def action_space(self):
|
||||
shape = (self.env.action_space.n,)
|
||||
space = gym.spaces.Box(low=0, high=1, shape=shape, dtype=np.float32)
|
||||
space.sample = self._sample_action
|
||||
space.discrete = True
|
||||
return space
|
||||
self.action_space = space
|
||||
|
||||
def step(self, action):
|
||||
index = np.argmax(action).astype(int)
|
||||
@ -81,25 +76,23 @@ class OneHotAction(gym.Wrapper):
|
||||
class RewardObs(gym.Wrapper):
|
||||
def __init__(self, env):
|
||||
super().__init__(env)
|
||||
|
||||
def observation_space(self):
|
||||
spaces = self.env.observation_space.spaces
|
||||
if "reward" not in spaces:
|
||||
spaces["reward"] = gym.spaces.Box(
|
||||
if "obs_reward" not in spaces:
|
||||
spaces["obs_reward"] = gym.spaces.Box(
|
||||
-np.inf, np.inf, shape=(1,), dtype=np.float32
|
||||
)
|
||||
return gym.spaces.Dict(spaces)
|
||||
self.observation_space = gym.spaces.Dict(spaces)
|
||||
|
||||
def step(self, action):
|
||||
obs, reward, done, info = self.env.step(action)
|
||||
if "reward" not in obs:
|
||||
obs["reward"] = reward
|
||||
if "obs_reward" not in obs:
|
||||
obs["obs_reward"] = np.array([reward], dtype=np.float32)
|
||||
return obs, reward, done, info
|
||||
|
||||
def reset(self):
|
||||
obs = self.env.reset()
|
||||
if "reward" not in obs:
|
||||
obs["reward"] = 0.0
|
||||
if "obs_reward" not in obs:
|
||||
obs["obs_reward"] = np.array([0.0], dtype=np.float32)
|
||||
return obs
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user