modified wrappers

This commit is contained in:
NM512 2023-08-05 21:08:01 +09:00
parent 3f6659d365
commit a6ad132198

View File

@ -34,11 +34,9 @@ class NormalizeActions(gym.Wrapper):
)
self._low = np.where(self._mask, env.action_space.low, -1)
self._high = np.where(self._mask, env.action_space.high, 1)
def action_space(self):
low = np.where(self._mask, -np.ones_like(self._low), self._low)
high = np.where(self._mask, np.ones_like(self._low), self._high)
return gym.spaces.Box(low, high, dtype=np.float32)
self.action_space = gym.spaces.Box(low, high, dtype=np.float32)
def step(self, action):
original = (action + 1) / 2 * (self._high - self._low) + self._low
@ -51,13 +49,10 @@ class OneHotAction(gym.Wrapper):
assert isinstance(env.action_space, gym.spaces.Discrete)
super().__init__(env)
self._random = np.random.RandomState()
def action_space(self):
shape = (self.env.action_space.n,)
space = gym.spaces.Box(low=0, high=1, shape=shape, dtype=np.float32)
space.sample = self._sample_action
space.discrete = True
return space
self.action_space = space
def step(self, action):
index = np.argmax(action).astype(int)
@ -81,25 +76,23 @@ class OneHotAction(gym.Wrapper):
class RewardObs(gym.Wrapper):
def __init__(self, env):
super().__init__(env)
def observation_space(self):
spaces = self.env.observation_space.spaces
if "reward" not in spaces:
spaces["reward"] = gym.spaces.Box(
if "obs_reward" not in spaces:
spaces["obs_reward"] = gym.spaces.Box(
-np.inf, np.inf, shape=(1,), dtype=np.float32
)
return gym.spaces.Dict(spaces)
self.observation_space = gym.spaces.Dict(spaces)
def step(self, action):
obs, reward, done, info = self.env.step(action)
if "reward" not in obs:
obs["reward"] = reward
if "obs_reward" not in obs:
obs["obs_reward"] = np.array([reward], dtype=np.float32)
return obs, reward, done, info
def reset(self):
obs = self.env.reset()
if "reward" not in obs:
obs["reward"] = 0.0
if "obs_reward" not in obs:
obs["obs_reward"] = np.array([0.0], dtype=np.float32)
return obs