applied formatter
This commit is contained in:
		
							parent
							
								
									775eb94e7f
								
							
						
					
					
						commit
						e3329b35e5
					
				
							
								
								
									
										10
									
								
								dreamer.py
									
									
									
									
									
								
							
							
						
						
									
										10
									
								
								dreamer.py
									
									
									
									
									
								
							| @ -213,13 +213,15 @@ def make_env(config, logger, mode, train_eps, eval_eps): | |||||||
|         env = wrappers.OneHotAction(env) |         env = wrappers.OneHotAction(env) | ||||||
|     elif suite == "MemoryMaze": |     elif suite == "MemoryMaze": | ||||||
|         import gym |         import gym | ||||||
|         if task == '9x9': | 
 | ||||||
|             env = gym.make('memory_maze:MemoryMaze-9x9-v0') |         if task == "9x9": | ||||||
|         elif task == '15x15': |             env = gym.make("memory_maze:MemoryMaze-9x9-v0") | ||||||
|             env = gym.make('memory_maze:MemoryMaze-15x15-v0') |         elif task == "15x15": | ||||||
|  |             env = gym.make("memory_maze:MemoryMaze-15x15-v0") | ||||||
|         else: |         else: | ||||||
|             raise NotImplementedError(suite) |             raise NotImplementedError(suite) | ||||||
|         from envs.memorymaze import MemoryMaze |         from envs.memorymaze import MemoryMaze | ||||||
|  | 
 | ||||||
|         env = MemoryMaze(env) |         env = MemoryMaze(env) | ||||||
|         env = wrappers.OneHotAction(env) |         env = wrappers.OneHotAction(env) | ||||||
|     elif suite == "crafter": |     elif suite == "crafter": | ||||||
|  | |||||||
| @ -8,80 +8,79 @@ import numpy as np | |||||||
| 
 | 
 | ||||||
| ###from tf dreamerv2 code | ###from tf dreamerv2 code | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| class MemoryMaze: | class MemoryMaze: | ||||||
|  |     def __init__(self, env, obs_key="image", act_key="action", size=(64, 64)): | ||||||
|  |         self._env = env | ||||||
|  |         self._obs_is_dict = hasattr(self._env.observation_space, "spaces") | ||||||
|  |         self._act_is_dict = hasattr(self._env.action_space, "spaces") | ||||||
|  |         self._obs_key = obs_key | ||||||
|  |         self._act_key = act_key | ||||||
|  |         self._size = size | ||||||
|  |         self._gray = False | ||||||
| 
 | 
 | ||||||
|   def __init__(self, env, obs_key='image', act_key='action', size=(64, 64)): |     def __getattr__(self, name): | ||||||
|     self._env = env |         if name.startswith("__"): | ||||||
|     self._obs_is_dict = hasattr(self._env.observation_space, 'spaces') |             raise AttributeError(name) | ||||||
|     self._act_is_dict = hasattr(self._env.action_space, 'spaces') |         try: | ||||||
|     self._obs_key = obs_key |             return getattr(self._env, name) | ||||||
|     self._act_key = act_key |         except AttributeError: | ||||||
|     self._size = size |             raise ValueError(name) | ||||||
|     self._gray = False |  | ||||||
| 
 | 
 | ||||||
|   def __getattr__(self, name): |     @property | ||||||
|     if name.startswith('__'): |     def obs_space(self): | ||||||
|       raise AttributeError(name) |         if self._obs_is_dict: | ||||||
|     try: |             spaces = self._env.observation_space.spaces.copy() | ||||||
|       return getattr(self._env, name) |         else: | ||||||
|     except AttributeError: |             spaces = {self._obs_key: self._env.observation_space} | ||||||
|       raise ValueError(name) |         return { | ||||||
|  |             **spaces, | ||||||
|  |             "reward": gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32), | ||||||
|  |             "is_first": gym.spaces.Box(0, 1, (), dtype=np.bool), | ||||||
|  |             "is_last": gym.spaces.Box(0, 1, (), dtype=np.bool), | ||||||
|  |             "is_terminal": gym.spaces.Box(0, 1, (), dtype=np.bool), | ||||||
|  |         } | ||||||
| 
 | 
 | ||||||
|   @property |     @property | ||||||
|   def obs_space(self): |     def act_space(self): | ||||||
|     if self._obs_is_dict: |         if self._act_is_dict: | ||||||
|       spaces = self._env.observation_space.spaces.copy() |             return self._env.action_space.spaces.copy() | ||||||
|     else: |         else: | ||||||
|       spaces = {self._obs_key: self._env.observation_space} |             return {self._act_key: self._env.action_space} | ||||||
|     return { |  | ||||||
|         **spaces, |  | ||||||
|         'reward': gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32), |  | ||||||
|         'is_first': gym.spaces.Box(0, 1, (), dtype=np.bool), |  | ||||||
|         'is_last': gym.spaces.Box(0, 1, (), dtype=np.bool), |  | ||||||
|         'is_terminal': gym.spaces.Box(0, 1, (), dtype=np.bool), |  | ||||||
|     } |  | ||||||
| 
 | 
 | ||||||
|   @property |     @property | ||||||
|   def act_space(self): |     def observation_space(self): | ||||||
|     if self._act_is_dict: |         img_shape = self._size + ((1,) if self._gray else (3,)) | ||||||
|       return self._env.action_space.spaces.copy() |         return gym.spaces.Dict( | ||||||
|     else: |             { | ||||||
|       return {self._act_key: self._env.action_space} |                 "image": gym.spaces.Box(0, 255, img_shape, np.uint8), | ||||||
|  |             } | ||||||
|  |         ) | ||||||
| 
 | 
 | ||||||
|   @property |     @property | ||||||
|   def observation_space(self): |     def action_space(self): | ||||||
|     img_shape = self._size + ((1,) if self._gray else (3,)) |         space = self._env.action_space | ||||||
|     return gym.spaces.Dict( |         space.discrete = True | ||||||
|       { |         return space | ||||||
|         "image": gym.spaces.Box(0, 255, img_shape, np.uint8), |  | ||||||
|       } |  | ||||||
|     ) |  | ||||||
| 
 | 
 | ||||||
|   @property |     def step(self, action): | ||||||
|   def action_space(self): |         # if not self._act_is_dict: | ||||||
|     space = self._env.action_space |         #   action = action[self._act_key] | ||||||
|     space.discrete = True |         obs, reward, done, info = self._env.step(action) | ||||||
|     return space |         if not self._obs_is_dict: | ||||||
| 
 |             obs = {self._obs_key: obs} | ||||||
|   def step(self, action): |         # obs['reward'] = float(reward) | ||||||
|     # if not self._act_is_dict: |         obs["is_first"] = False | ||||||
|     #   action = action[self._act_key] |         obs["is_last"] = done | ||||||
|     obs, reward, done, info = self._env.step(action) |         obs["is_terminal"] = info.get("is_terminal", False) | ||||||
|     if not self._obs_is_dict: |         return obs, reward, done, info | ||||||
|       obs = {self._obs_key: obs} |  | ||||||
|     # obs['reward'] = float(reward) |  | ||||||
|     obs['is_first'] = False |  | ||||||
|     obs['is_last'] = done |  | ||||||
|     obs['is_terminal'] = info.get('is_terminal', False) |  | ||||||
|     return obs, reward, done, info |  | ||||||
| 
 |  | ||||||
|   def reset(self): |  | ||||||
|     obs = self._env.reset() |  | ||||||
|     if not self._obs_is_dict: |  | ||||||
|       obs = {self._obs_key: obs} |  | ||||||
|     obs['reward'] = 0.0 |  | ||||||
|     obs['is_first'] = True |  | ||||||
|     obs['is_last'] = False |  | ||||||
|     obs['is_terminal'] = False |  | ||||||
|     return obs |  | ||||||
| 
 | 
 | ||||||
|  |     def reset(self): | ||||||
|  |         obs = self._env.reset() | ||||||
|  |         if not self._obs_is_dict: | ||||||
|  |             obs = {self._obs_key: obs} | ||||||
|  |         obs["reward"] = 0.0 | ||||||
|  |         obs["is_first"] = True | ||||||
|  |         obs["is_last"] = False | ||||||
|  |         obs["is_terminal"] = False | ||||||
|  |         return obs | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user