modified wrappers
This commit is contained in:
		
							parent
							
								
									3f6659d365
								
							
						
					
					
						commit
						a6ad132198
					
				| @ -34,11 +34,9 @@ class NormalizeActions(gym.Wrapper): | ||||
|         ) | ||||
|         self._low = np.where(self._mask, env.action_space.low, -1) | ||||
|         self._high = np.where(self._mask, env.action_space.high, 1) | ||||
| 
 | ||||
|     def action_space(self): | ||||
|         low = np.where(self._mask, -np.ones_like(self._low), self._low) | ||||
|         high = np.where(self._mask, np.ones_like(self._low), self._high) | ||||
|         return gym.spaces.Box(low, high, dtype=np.float32) | ||||
|         self.action_space = gym.spaces.Box(low, high, dtype=np.float32) | ||||
| 
 | ||||
|     def step(self, action): | ||||
|         original = (action + 1) / 2 * (self._high - self._low) + self._low | ||||
| @ -51,13 +49,10 @@ class OneHotAction(gym.Wrapper): | ||||
|         assert isinstance(env.action_space, gym.spaces.Discrete) | ||||
|         super().__init__(env) | ||||
|         self._random = np.random.RandomState() | ||||
| 
 | ||||
|     def action_space(self): | ||||
|         shape = (self.env.action_space.n,) | ||||
|         space = gym.spaces.Box(low=0, high=1, shape=shape, dtype=np.float32) | ||||
|         space.sample = self._sample_action | ||||
|         space.discrete = True | ||||
|         return space | ||||
|         self.action_space = space | ||||
| 
 | ||||
|     def step(self, action): | ||||
|         index = np.argmax(action).astype(int) | ||||
| @ -81,25 +76,23 @@ class OneHotAction(gym.Wrapper): | ||||
| class RewardObs(gym.Wrapper): | ||||
|     def __init__(self, env): | ||||
|         super().__init__(env) | ||||
| 
 | ||||
|     def observation_space(self): | ||||
|         spaces = self.env.observation_space.spaces | ||||
|         if "reward" not in spaces: | ||||
|             spaces["reward"] = gym.spaces.Box( | ||||
|         if "obs_reward" not in spaces: | ||||
|             spaces["obs_reward"] = gym.spaces.Box( | ||||
|                 -np.inf, np.inf, shape=(1,), dtype=np.float32 | ||||
|             ) | ||||
|         return gym.spaces.Dict(spaces) | ||||
|         self.observation_space = gym.spaces.Dict(spaces) | ||||
| 
 | ||||
|     def step(self, action): | ||||
|         obs, reward, done, info = self.env.step(action) | ||||
|         if "reward" not in obs: | ||||
|             obs["reward"] = reward | ||||
|         if "obs_reward" not in obs: | ||||
|             obs["obs_reward"] = np.array([reward], dtype=np.float32) | ||||
|         return obs, reward, done, info | ||||
| 
 | ||||
|     def reset(self): | ||||
|         obs = self.env.reset() | ||||
|         if "reward" not in obs: | ||||
|             obs["reward"] = 0.0 | ||||
|         if "obs_reward" not in obs: | ||||
|             obs["obs_reward"] = np.array([0.0], dtype=np.float32) | ||||
|         return obs | ||||
| 
 | ||||
| 
 | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user