applied formatter
This commit is contained in:
		
							parent
							
								
									775eb94e7f
								
							
						
					
					
						commit
						e3329b35e5
					
				
							
								
								
									
										10
									
								
								dreamer.py
									
									
									
									
									
								
							
							
						
						
									
										10
									
								
								dreamer.py
									
									
									
									
									
								
							| @ -213,13 +213,15 @@ def make_env(config, logger, mode, train_eps, eval_eps): | ||||
|         env = wrappers.OneHotAction(env) | ||||
|     elif suite == "MemoryMaze": | ||||
|         import gym | ||||
|         if task == '9x9': | ||||
|             env = gym.make('memory_maze:MemoryMaze-9x9-v0') | ||||
|         elif task == '15x15': | ||||
|             env = gym.make('memory_maze:MemoryMaze-15x15-v0') | ||||
| 
 | ||||
|         if task == "9x9": | ||||
|             env = gym.make("memory_maze:MemoryMaze-9x9-v0") | ||||
|         elif task == "15x15": | ||||
|             env = gym.make("memory_maze:MemoryMaze-15x15-v0") | ||||
|         else: | ||||
|             raise NotImplementedError(suite) | ||||
|         from envs.memorymaze import MemoryMaze | ||||
| 
 | ||||
|         env = MemoryMaze(env) | ||||
|         env = wrappers.OneHotAction(env) | ||||
|     elif suite == "crafter": | ||||
|  | ||||
| @ -8,80 +8,79 @@ import numpy as np | ||||
| 
 | ||||
| ###from tf dreamerv2 code | ||||
| 
 | ||||
| 
 | ||||
| class MemoryMaze: | ||||
|     def __init__(self, env, obs_key="image", act_key="action", size=(64, 64)): | ||||
|         self._env = env | ||||
|         self._obs_is_dict = hasattr(self._env.observation_space, "spaces") | ||||
|         self._act_is_dict = hasattr(self._env.action_space, "spaces") | ||||
|         self._obs_key = obs_key | ||||
|         self._act_key = act_key | ||||
|         self._size = size | ||||
|         self._gray = False | ||||
| 
 | ||||
|   def __init__(self, env, obs_key='image', act_key='action', size=(64, 64)): | ||||
|     self._env = env | ||||
|     self._obs_is_dict = hasattr(self._env.observation_space, 'spaces') | ||||
|     self._act_is_dict = hasattr(self._env.action_space, 'spaces') | ||||
|     self._obs_key = obs_key | ||||
|     self._act_key = act_key | ||||
|     self._size = size | ||||
|     self._gray = False | ||||
|     def __getattr__(self, name): | ||||
|         if name.startswith("__"): | ||||
|             raise AttributeError(name) | ||||
|         try: | ||||
|             return getattr(self._env, name) | ||||
|         except AttributeError: | ||||
|             raise ValueError(name) | ||||
| 
 | ||||
|   def __getattr__(self, name): | ||||
|     if name.startswith('__'): | ||||
|       raise AttributeError(name) | ||||
|     try: | ||||
|       return getattr(self._env, name) | ||||
|     except AttributeError: | ||||
|       raise ValueError(name) | ||||
|     @property | ||||
|     def obs_space(self): | ||||
|         if self._obs_is_dict: | ||||
|             spaces = self._env.observation_space.spaces.copy() | ||||
|         else: | ||||
|             spaces = {self._obs_key: self._env.observation_space} | ||||
|         return { | ||||
|             **spaces, | ||||
|             "reward": gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32), | ||||
|             "is_first": gym.spaces.Box(0, 1, (), dtype=np.bool), | ||||
|             "is_last": gym.spaces.Box(0, 1, (), dtype=np.bool), | ||||
|             "is_terminal": gym.spaces.Box(0, 1, (), dtype=np.bool), | ||||
|         } | ||||
| 
 | ||||
|   @property | ||||
|   def obs_space(self): | ||||
|     if self._obs_is_dict: | ||||
|       spaces = self._env.observation_space.spaces.copy() | ||||
|     else: | ||||
|       spaces = {self._obs_key: self._env.observation_space} | ||||
|     return { | ||||
|         **spaces, | ||||
|         'reward': gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32), | ||||
|         'is_first': gym.spaces.Box(0, 1, (), dtype=np.bool), | ||||
|         'is_last': gym.spaces.Box(0, 1, (), dtype=np.bool), | ||||
|         'is_terminal': gym.spaces.Box(0, 1, (), dtype=np.bool), | ||||
|     } | ||||
|     @property | ||||
|     def act_space(self): | ||||
|         if self._act_is_dict: | ||||
|             return self._env.action_space.spaces.copy() | ||||
|         else: | ||||
|             return {self._act_key: self._env.action_space} | ||||
| 
 | ||||
|   @property | ||||
|   def act_space(self): | ||||
|     if self._act_is_dict: | ||||
|       return self._env.action_space.spaces.copy() | ||||
|     else: | ||||
|       return {self._act_key: self._env.action_space} | ||||
|     @property | ||||
|     def observation_space(self): | ||||
|         img_shape = self._size + ((1,) if self._gray else (3,)) | ||||
|         return gym.spaces.Dict( | ||||
|             { | ||||
|                 "image": gym.spaces.Box(0, 255, img_shape, np.uint8), | ||||
|             } | ||||
|         ) | ||||
| 
 | ||||
|   @property | ||||
|   def observation_space(self): | ||||
|     img_shape = self._size + ((1,) if self._gray else (3,)) | ||||
|     return gym.spaces.Dict( | ||||
|       { | ||||
|         "image": gym.spaces.Box(0, 255, img_shape, np.uint8), | ||||
|       } | ||||
|     ) | ||||
|     @property | ||||
|     def action_space(self): | ||||
|         space = self._env.action_space | ||||
|         space.discrete = True | ||||
|         return space | ||||
| 
 | ||||
|   @property | ||||
|   def action_space(self): | ||||
|     space = self._env.action_space | ||||
|     space.discrete = True | ||||
|     return space | ||||
| 
 | ||||
|   def step(self, action): | ||||
|     # if not self._act_is_dict: | ||||
|     #   action = action[self._act_key] | ||||
|     obs, reward, done, info = self._env.step(action) | ||||
|     if not self._obs_is_dict: | ||||
|       obs = {self._obs_key: obs} | ||||
|     # obs['reward'] = float(reward) | ||||
|     obs['is_first'] = False | ||||
|     obs['is_last'] = done | ||||
|     obs['is_terminal'] = info.get('is_terminal', False) | ||||
|     return obs, reward, done, info | ||||
| 
 | ||||
|   def reset(self): | ||||
|     obs = self._env.reset() | ||||
|     if not self._obs_is_dict: | ||||
|       obs = {self._obs_key: obs} | ||||
|     obs['reward'] = 0.0 | ||||
|     obs['is_first'] = True | ||||
|     obs['is_last'] = False | ||||
|     obs['is_terminal'] = False | ||||
|     return obs | ||||
|     def step(self, action): | ||||
|         # if not self._act_is_dict: | ||||
|         #   action = action[self._act_key] | ||||
|         obs, reward, done, info = self._env.step(action) | ||||
|         if not self._obs_is_dict: | ||||
|             obs = {self._obs_key: obs} | ||||
|         # obs['reward'] = float(reward) | ||||
|         obs["is_first"] = False | ||||
|         obs["is_last"] = done | ||||
|         obs["is_terminal"] = info.get("is_terminal", False) | ||||
|         return obs, reward, done, info | ||||
| 
 | ||||
|     def reset(self): | ||||
|         obs = self._env.reset() | ||||
|         if not self._obs_is_dict: | ||||
|             obs = {self._obs_key: obs} | ||||
|         obs["reward"] = 0.0 | ||||
|         obs["is_first"] = True | ||||
|         obs["is_last"] = False | ||||
|         obs["is_terminal"] = False | ||||
|         return obs | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user