diff --git a/envs/wrappers.py b/envs/wrappers.py
index b73e156..6a538fe 100644
--- a/envs/wrappers.py
+++ b/envs/wrappers.py
@@ -34,11 +34,9 @@ class NormalizeActions(gym.Wrapper):
         )
         self._low = np.where(self._mask, env.action_space.low, -1)
         self._high = np.where(self._mask, env.action_space.high, 1)
-
-    def action_space(self):
         low = np.where(self._mask, -np.ones_like(self._low), self._low)
         high = np.where(self._mask, np.ones_like(self._low), self._high)
-        return gym.spaces.Box(low, high, dtype=np.float32)
+        self.action_space = gym.spaces.Box(low, high, dtype=np.float32)
 
     def step(self, action):
         original = (action + 1) / 2 * (self._high - self._low) + self._low
@@ -51,13 +49,10 @@ class OneHotAction(gym.Wrapper):
         assert isinstance(env.action_space, gym.spaces.Discrete)
         super().__init__(env)
         self._random = np.random.RandomState()
-
-    def action_space(self):
         shape = (self.env.action_space.n,)
         space = gym.spaces.Box(low=0, high=1, shape=shape, dtype=np.float32)
-        space.sample = self._sample_action
         space.discrete = True
-        return space
+        self.action_space = space
 
     def step(self, action):
         index = np.argmax(action).astype(int)
@@ -81,23 +76,23 @@ class OneHotAction(gym.Wrapper):
 class RewardObs(gym.Wrapper):
     def __init__(self, env):
         super().__init__(env)
-
-    def observation_space(self):
         spaces = self.env.observation_space.spaces
-        if "reward" not in spaces:
-            spaces["reward"] = gym.spaces.Box(
+        if "obs_reward" not in spaces:
+            spaces["obs_reward"] = gym.spaces.Box(
                 -np.inf, np.inf, shape=(1,), dtype=np.float32
             )
-        return gym.spaces.Dict(spaces)
+        self.observation_space = gym.spaces.Dict(spaces)
 
     def step(self, action):
         obs, reward, done, info = self.env.step(action)
-        if "reward" not in obs:
-            obs["reward"] = reward
+        if "obs_reward" not in obs:
+            obs["obs_reward"] = np.array([reward], dtype=np.float32)
         return obs, reward, done, info
 
     def reset(self):
         obs = self.env.reset()
-        if "reward" not in obs:
-            obs["reward"] = 0.0
+        if "obs_reward" not in obs:
+            obs["obs_reward"] = np.array([0.0], dtype=np.float32)
         return obs
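
For reviewers, a minimal usage sketch of the wrappers after this change; it is not part of the diff. `make_env` is a hypothetical factory standing in for whatever env constructor the repo uses, assumed to return an env with a `gym.spaces.Dict` observation space and a `gym.spaces.Discrete` action space (`RewardObs` reads `observation_space.spaces`, so a Dict space is required). It also assumes the old gym 4-tuple step API, matching the diff:

```python
# Sketch only: `make_env` is a hypothetical factory, assumed to return an env
# with a Dict observation space and a Discrete action space.
import numpy as np

env = RewardObs(OneHotAction(make_env()))

# The spaces are now plain attributes assigned in __init__, not methods,
# so they are read without calling:
assert env.action_space.discrete
assert "obs_reward" in env.observation_space.spaces

obs = env.reset()
# reset() seeds the key with a zero reward of shape (1,), dtype float32.
assert obs["obs_reward"].shape == (1,)
assert obs["obs_reward"].dtype == np.float32

# OneHotAction expects a one-hot float vector and forwards the argmax index.
action = np.zeros(env.action_space.shape, dtype=np.float32)
action[0] = 1.0
obs, reward, done, info = env.step(action)
# step() mirrors the scalar reward into the observation as a (1,) array.
assert np.isclose(obs["obs_reward"][0], reward)
```

Note that with `space.sample = self._sample_action` dropped in this diff, `env.action_space.sample()` falls back to the default `Box.sample`, which does not return a one-hot vector; the sketch therefore builds the action by hand.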