162 lines
		
	
	
		
			5.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			162 lines
		
	
	
		
			5.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import gym
 | |
| import numpy as np
 | |
| 
 | |
| 
 | |
class Atari:
    """Atari environment wrapper built on ``gym.envs.atari.AtariEnv``.

    The wrapper repeats each action for several emulator frames, max-pools
    the last two rendered screens, optionally resizes and grayscales the
    result, and reports life loss / time limits through the returned flags.

    Args:
        name: Game identifier passed to ``AtariEnv`` (``"james_bond"`` is
            rewritten to ``"jamesbond"``).
        action_repeat: Number of emulator frames each ``step`` call advances.
        size: ``(height, width)`` of the returned image; must be square.
        gray: If true, convert the image to a single luminance channel.
        noops: Exclusive upper bound on the random number of no-op frames
            played by ``reset``; ``0`` disables them.
        lives: ``"unused"`` ignores the life counter. ``"discount"`` flags a
            lost life as ``is_terminal``. ``"reset"`` additionally ends the
            episode (``is_last``) on a lost life.
        sticky: If true, the emulator repeats the previous action with
            probability 0.25.
        actions: ``"all"`` requests the full action space, ``"needed"`` the
            game's reduced one.
        length: Maximum number of emulator frames per episode; a falsy value
            disables the limit.
        resize: Resizing backend, ``"opencv"`` or ``"pillow"``.
        seed: Seed for the random generator that draws the no-op count.
    """

    # Shared by every instance in the process; guards emulator construction.
    # Created lazily on first use so importing the module stays cheap.
    LOCK = None

    def __init__(
        self,
        name,
        action_repeat=4,
        size=(84, 84),
        gray=True,
        noops=0,
        lives="unused",
        sticky=True,
        actions="all",
        length=108000,
        resize="opencv",
        seed=None,
    ):
        # Square only: cv2 takes (width, height) while the shape check in
        # _obs uses (height, width), so the two agree only when equal.
        assert size[0] == size[1]
        assert lives in ("unused", "discount", "reset"), lives
        assert actions in ("all", "needed"), actions
        assert resize in ("opencv", "pillow"), resize
        if Atari.LOCK is None:
            import multiprocessing

            # Store on the class, not the instance: assigning through
            # ``self`` would give every environment its own private lock.
            Atari.LOCK = multiprocessing.get_context("spawn").Lock()
        self._resize = resize
        # Import only the backend that was asked for so the other one does
        # not have to be installed.
        if self._resize == "opencv":
            import cv2

            self._cv2 = cv2
        if self._resize == "pillow":
            from PIL import Image

            self._image = Image
        import gym.envs.atari

        if name == "james_bond":
            name = "jamesbond"
        self._repeat = action_repeat
        # Normalise to a tuple so a list argument still compares equal to
        # ``ndarray.shape`` slices and can be concatenated with a tuple.
        self._size = tuple(size)
        self._gray = gray
        self._noops = noops
        self._lives = lives
        self._sticky = sticky
        self._length = length
        self._random = np.random.RandomState(seed)
        with self.LOCK:
            self._env = gym.envs.atari.AtariEnv(
                game=name,
                obs_type="image",
                # Frame skipping is done by ``step`` so that the
                # second-to-last frame can be captured for max-pooling.
                frameskip=1,
                repeat_action_probability=0.25 if sticky else 0.0,
                full_action_space=(actions == "all"),
            )
        # ``reset`` plays action 0 as the no-op.
        assert self._env.unwrapped.get_action_meanings()[0] == "NOOP"
        shape = self._env.observation_space.shape
        # [0] holds the latest screen, [1] the one before it.
        self._buffer = [np.zeros(shape, np.uint8) for _ in range(2)]
        self._ale = self._env.unwrapped.ale
        self._last_lives = None
        self._done = True
        self._step = 0

    @property
    def observation_space(self):
        """Dict space with a single uint8 ``"image"`` entry."""
        channels = (1,) if self._gray else (3,)
        img_shape = self._size + channels
        return gym.spaces.Dict(
            {
                "image": gym.spaces.Box(0, 255, img_shape, np.uint8),
            }
        )

    @property
    def action_space(self):
        """The wrapped env's action space, tagged with ``discrete = True``."""
        space = self._env.action_space
        space.discrete = True
        return space

    def step(self, action):
        """Apply ``action`` for up to ``action_repeat`` emulator frames.

        ``reset`` must be called first; this method does not reset on its own.

        Args:
            action: Either a scalar action index or an array-like of
                per-action scores, in which case the argmax is played.

        Returns:
            A tuple ``(obs, reward, is_last, info)`` where ``obs`` is the
            dict built by ``_obs``, ``reward`` is the sum over the repeated
            frames, ``is_last`` marks the end of the episode, and ``info``
            is an empty dict.
        """
        total = 0.0
        dead = False
        # Pre-set so the code after the loop is valid even when the loop
        # body never runs (action_repeat == 0).
        over = False
        # np.ndim also accepts plain Python ints and lists, which have no
        # ``.shape`` attribute.
        if np.ndim(action) >= 1:
            action = np.argmax(action)
        for repeat in range(self._repeat):
            _, reward, over, _ = self._env.step(action)
            self._step += 1
            total += reward
            # Keep the second-to-last frame for max-pooling in _obs.
            if repeat == self._repeat - 2:
                self._screen(self._buffer[1])
            if over:
                break
            if self._lives != "unused":
                current = self._ale.lives()
                if current < self._last_lives:
                    dead = True
                    self._last_lives = current
                    break
        if not self._repeat:
            self._buffer[1][:] = self._buffer[0][:]
        self._screen(self._buffer[0])
        # bool() keeps the flag a real boolean when ``length`` is 0 or None.
        self._done = bool(
            over or (self._length and self._step >= self._length)
        )
        return self._obs(
            total,
            is_last=self._done or (dead and self._lives == "reset"),
            is_terminal=dead or over,
        )

    def reset(self):
        """Start a new episode and return its first observation dict."""
        self._env.reset()
        if self._noops:
            # Draws a count in [0, noops) and plays that many no-op frames.
            for _ in range(self._random.randint(self._noops)):
                _, _, dead, _ = self._env.step(0)
                if dead:
                    self._env.reset()
        self._last_lives = self._ale.lives()
        self._screen(self._buffer[0])
        # No previous frame exists yet, so pooling sees only the current one.
        self._buffer[1].fill(0)

        self._done = False
        self._step = 0
        obs, _, _, _ = self._obs(0.0, is_first=True)
        return obs

    def _obs(self, reward, is_first=False, is_last=False, is_terminal=False):
        """Build the ``(obs, reward, is_last, info)`` tuple.

        The two screen buffers are merged with an element-wise maximum
        (written back into buffer 0), then resized and grayscaled as
        configured.
        """
        np.maximum(self._buffer[0], self._buffer[1], out=self._buffer[0])
        image = self._buffer[0]
        if image.shape[:2] != self._size:
            if self._resize == "opencv":
                image = self._cv2.resize(
                    image, self._size, interpolation=self._cv2.INTER_AREA
                )
            if self._resize == "pillow":
                image = self._image.fromarray(image)
                image = image.resize(self._size, self._image.NEAREST)
                image = np.array(image)
        if self._gray:
            # Weighted sum over the colour axis; the cast back to the
            # source dtype truncates the fractional part.
            weights = [0.299, 0.587, 1 - (0.299 + 0.587)]
            image = np.tensordot(image, weights, (-1, 0)).astype(image.dtype)
            image = image[:, :, None]
        return (
            {"image": image, "is_terminal": is_terminal, "is_first": is_first},
            reward,
            is_last,
            {},
        )

    def _screen(self, array):
        """Copy the emulator's current RGB screen into ``array`` in place."""
        self._ale.getScreenRGB2(array)

    def close(self):
        """Close the wrapped environment."""
        return self._env.close()
 |