commit
						775eb94e7f
					
				
							
								
								
									
										12
									
								
								configs.yaml
									
									
									
									
									
								
							
							
						
						
									
										12
									
								
								configs.yaml
									
									
									
									
									
								
							| @ -172,9 +172,21 @@ atari100k: | |||||||
|   imag_gradient: 'reinforce' |   imag_gradient: 'reinforce' | ||||||
|   time_limit: 108000 |   time_limit: 108000 | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| debug: | debug: | ||||||
|   debug: True |   debug: True | ||||||
|   pretrain: 1 |   pretrain: 1 | ||||||
|   prefill: 1 |   prefill: 1 | ||||||
|   batch_size: 10 |   batch_size: 10 | ||||||
|   batch_length: 20 |   batch_length: 20 | ||||||
|  | 
 | ||||||
|  | MemoryMaze: | ||||||
|  |   actor_dist: 'onehot' | ||||||
|  |   imag_gradient: 'reinforce' | ||||||
|  |   task: '9x9' | ||||||
|  |   steps: 1e6 | ||||||
|  |   action_repeat: 2 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | |||||||
							
								
								
									
										11
									
								
								dreamer.py
									
									
									
									
									
								
							
							
						
						
									
										11
									
								
								dreamer.py
									
									
									
									
									
								
							| @ -211,6 +211,17 @@ def make_env(config, logger, mode, train_eps, eval_eps): | |||||||
|             task, mode if "train" in mode else "test", config.action_repeat |             task, mode if "train" in mode else "test", config.action_repeat | ||||||
|         ) |         ) | ||||||
|         env = wrappers.OneHotAction(env) |         env = wrappers.OneHotAction(env) | ||||||
|  |     elif suite == "MemoryMaze": | ||||||
|  |         import gym | ||||||
|  |         if task == '9x9': | ||||||
|  |             env = gym.make('memory_maze:MemoryMaze-9x9-v0') | ||||||
|  |         elif task == '15x15': | ||||||
|  |             env = gym.make('memory_maze:MemoryMaze-15x15-v0') | ||||||
|  |         else: | ||||||
|  |             raise NotImplementedError(suite) | ||||||
|  |         from envs.memorymaze import MemoryMaze | ||||||
|  |         env = MemoryMaze(env) | ||||||
|  |         env = wrappers.OneHotAction(env) | ||||||
|     elif suite == "crafter": |     elif suite == "crafter": | ||||||
|         import envs.crafter as crafter |         import envs.crafter as crafter | ||||||
| 
 | 
 | ||||||
|  | |||||||
							
								
								
									
										87
									
								
								envs/memorymaze.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										87
									
								
								envs/memorymaze.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,87 @@ | |||||||
|  | import atexit | ||||||
|  | import os | ||||||
|  | import sys | ||||||
|  | 
 | ||||||
|  | import cloudpickle | ||||||
|  | import gym | ||||||
|  | import numpy as np | ||||||
|  | 
 | ||||||
|  | ###from tf dreamerv2 code | ||||||
|  | 
 | ||||||
|  | class MemoryMaze: | ||||||
|  | 
 | ||||||
|  |   def __init__(self, env, obs_key='image', act_key='action', size=(64, 64)): | ||||||
|  |     self._env = env | ||||||
|  |     self._obs_is_dict = hasattr(self._env.observation_space, 'spaces') | ||||||
|  |     self._act_is_dict = hasattr(self._env.action_space, 'spaces') | ||||||
|  |     self._obs_key = obs_key | ||||||
|  |     self._act_key = act_key | ||||||
|  |     self._size = size | ||||||
|  |     self._gray = False | ||||||
|  | 
 | ||||||
|  |   def __getattr__(self, name): | ||||||
|  |     if name.startswith('__'): | ||||||
|  |       raise AttributeError(name) | ||||||
|  |     try: | ||||||
|  |       return getattr(self._env, name) | ||||||
|  |     except AttributeError: | ||||||
|  |       raise ValueError(name) | ||||||
|  | 
 | ||||||
|  |   @property | ||||||
|  |   def obs_space(self): | ||||||
|  |     if self._obs_is_dict: | ||||||
|  |       spaces = self._env.observation_space.spaces.copy() | ||||||
|  |     else: | ||||||
|  |       spaces = {self._obs_key: self._env.observation_space} | ||||||
|  |     return { | ||||||
|  |         **spaces, | ||||||
|  |         'reward': gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32), | ||||||
|  |         'is_first': gym.spaces.Box(0, 1, (), dtype=np.bool), | ||||||
|  |         'is_last': gym.spaces.Box(0, 1, (), dtype=np.bool), | ||||||
|  |         'is_terminal': gym.spaces.Box(0, 1, (), dtype=np.bool), | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |   @property | ||||||
|  |   def act_space(self): | ||||||
|  |     if self._act_is_dict: | ||||||
|  |       return self._env.action_space.spaces.copy() | ||||||
|  |     else: | ||||||
|  |       return {self._act_key: self._env.action_space} | ||||||
|  | 
 | ||||||
|  |   @property | ||||||
|  |   def observation_space(self): | ||||||
|  |     img_shape = self._size + ((1,) if self._gray else (3,)) | ||||||
|  |     return gym.spaces.Dict( | ||||||
|  |       { | ||||||
|  |         "image": gym.spaces.Box(0, 255, img_shape, np.uint8), | ||||||
|  |       } | ||||||
|  |     ) | ||||||
|  | 
 | ||||||
|  |   @property | ||||||
|  |   def action_space(self): | ||||||
|  |     space = self._env.action_space | ||||||
|  |     space.discrete = True | ||||||
|  |     return space | ||||||
|  | 
 | ||||||
|  |   def step(self, action): | ||||||
|  |     # if not self._act_is_dict: | ||||||
|  |     #   action = action[self._act_key] | ||||||
|  |     obs, reward, done, info = self._env.step(action) | ||||||
|  |     if not self._obs_is_dict: | ||||||
|  |       obs = {self._obs_key: obs} | ||||||
|  |     # obs['reward'] = float(reward) | ||||||
|  |     obs['is_first'] = False | ||||||
|  |     obs['is_last'] = done | ||||||
|  |     obs['is_terminal'] = info.get('is_terminal', False) | ||||||
|  |     return obs, reward, done, info | ||||||
|  | 
 | ||||||
|  |   def reset(self): | ||||||
|  |     obs = self._env.reset() | ||||||
|  |     if not self._obs_is_dict: | ||||||
|  |       obs = {self._obs_key: obs} | ||||||
|  |     obs['reward'] = 0.0 | ||||||
|  |     obs['is_first'] = True | ||||||
|  |     obs['is_last'] = False | ||||||
|  |     obs['is_terminal'] = False | ||||||
|  |     return obs | ||||||
|  | 
 | ||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user