modified wrappers

This commit is contained in:
NM512 2023-08-05 21:08:01 +09:00
parent 3f6659d365
commit a6ad132198

View File

@ -34,11 +34,9 @@ class NormalizeActions(gym.Wrapper):
) )
self._low = np.where(self._mask, env.action_space.low, -1) self._low = np.where(self._mask, env.action_space.low, -1)
self._high = np.where(self._mask, env.action_space.high, 1) self._high = np.where(self._mask, env.action_space.high, 1)
def action_space(self):
low = np.where(self._mask, -np.ones_like(self._low), self._low) low = np.where(self._mask, -np.ones_like(self._low), self._low)
high = np.where(self._mask, np.ones_like(self._low), self._high) high = np.where(self._mask, np.ones_like(self._low), self._high)
return gym.spaces.Box(low, high, dtype=np.float32) self.action_space = gym.spaces.Box(low, high, dtype=np.float32)
def step(self, action): def step(self, action):
original = (action + 1) / 2 * (self._high - self._low) + self._low original = (action + 1) / 2 * (self._high - self._low) + self._low
@ -51,13 +49,10 @@ class OneHotAction(gym.Wrapper):
assert isinstance(env.action_space, gym.spaces.Discrete) assert isinstance(env.action_space, gym.spaces.Discrete)
super().__init__(env) super().__init__(env)
self._random = np.random.RandomState() self._random = np.random.RandomState()
def action_space(self):
shape = (self.env.action_space.n,) shape = (self.env.action_space.n,)
space = gym.spaces.Box(low=0, high=1, shape=shape, dtype=np.float32) space = gym.spaces.Box(low=0, high=1, shape=shape, dtype=np.float32)
space.sample = self._sample_action
space.discrete = True space.discrete = True
return space self.action_space = space
def step(self, action): def step(self, action):
index = np.argmax(action).astype(int) index = np.argmax(action).astype(int)
@ -81,25 +76,23 @@ class OneHotAction(gym.Wrapper):
class RewardObs(gym.Wrapper): class RewardObs(gym.Wrapper):
def __init__(self, env): def __init__(self, env):
super().__init__(env) super().__init__(env)
def observation_space(self):
spaces = self.env.observation_space.spaces spaces = self.env.observation_space.spaces
if "reward" not in spaces: if "obs_reward" not in spaces:
spaces["reward"] = gym.spaces.Box( spaces["obs_reward"] = gym.spaces.Box(
-np.inf, np.inf, shape=(1,), dtype=np.float32 -np.inf, np.inf, shape=(1,), dtype=np.float32
) )
return gym.spaces.Dict(spaces) self.observation_space = gym.spaces.Dict(spaces)
def step(self, action): def step(self, action):
obs, reward, done, info = self.env.step(action) obs, reward, done, info = self.env.step(action)
if "reward" not in obs: if "obs_reward" not in obs:
obs["reward"] = reward obs["obs_reward"] = np.array([reward], dtype=np.float32)
return obs, reward, done, info return obs, reward, done, info
def reset(self): def reset(self):
obs = self.env.reset() obs = self.env.reset()
if "reward" not in obs: if "obs_reward" not in obs:
obs["reward"] = 0.0 obs["obs_reward"] = np.array([0.0], dtype=np.float32)
return obs return obs