import numpy as np


class Atari:

  LOCK = None

  def __init__(
      self, name, action_repeat=4, size=(84, 84), gray=True, noops=0,
      lives='unused', sticky=True, actions='all', length=108000,
      resize='opencv', seed=None):
    assert size[0] == size[1]
    assert lives in ('unused', 'discount', 'reset'), lives
    assert actions in ('all', 'needed'), actions
    assert resize in ('opencv', 'pillow'), resize
    if self.LOCK is None:
      # Create the lock once per process and store it on the class so all
      # instances share it; assigning to self would give every instance its
      # own lock and never actually serialize environment construction.
      import multiprocessing as mp
      mp = mp.get_context('spawn')
      type(self).LOCK = mp.Lock()
    self._resize = resize
    if self._resize == 'opencv':
      import cv2
      self._cv2 = cv2
    if self._resize == 'pillow':
      from PIL import Image
      self._image = Image
    import gym.envs.atari
    if name == 'james_bond':
      name = 'jamesbond'
    self._repeat = action_repeat
    self._size = size
    self._gray = gray
    self._noops = noops
    self._lives = lives
    self._sticky = sticky
    self._length = length
    self._random = np.random.RandomState(seed)
    # Constructing ALE instances concurrently is not safe, so serialize it.
    with self.LOCK:
      self._env = gym.envs.atari.AtariEnv(
          game=name, obs_type='image', frameskip=1,
          repeat_action_probability=0.25 if sticky else 0.0,
          full_action_space=(actions == 'all'))
    assert self._env.unwrapped.get_action_meanings()[0] == 'NOOP'
    shape = self._env.observation_space.shape
    # Two frame buffers for max-pooling over the last two frames of each
    # action repeat, which removes Atari sprite flicker.
    self._buffer = [np.zeros(shape, np.uint8) for _ in range(2)]
    self._ale = self._env.unwrapped.ale
    self._last_lives = None
    self._done = True
    self._step = 0

  @property
  def action_space(self):
    space = self._env.action_space
    space.discrete = True
    return space

  def step(self, action):
    total = 0.0
    dead = False
    # Actions may arrive as one-hot numpy vectors; convert them to ids.
    if len(action.shape) >= 1:
      action = np.argmax(action)
    for repeat in range(self._repeat):
      _, reward, over, info = self._env.step(action)
      self._step += 1
      total += reward
      # Capture the second-to-last frame of the repeat for flicker removal.
      if repeat == self._repeat - 2:
        self._screen(self._buffer[1])
      if over:
        break
      if self._lives != 'unused':
        current = self._ale.lives()
        if current < self._last_lives:
          dead = True
          self._last_lives = current
          break
    if not self._repeat:
      self._buffer[1][:] = self._buffer[0][:]
    self._screen(self._buffer[0])
    self._done = over or (self._length and self._step >= self._length) or dead
    return self._obs(
        total,
        is_last=self._done or (dead and self._lives == 'reset'),
        is_terminal=dead or over)

  def reset(self):
    self._env.reset()
    # Randomize the start state by taking up to `noops` no-op actions.
    if self._noops:
      for _ in range(self._random.randint(self._noops)):
        _, _, dead, _ = self._env.step(0)
        if dead:
          self._env.reset()
    self._last_lives = self._ale.lives()
    self._screen(self._buffer[0])
    self._buffer[1].fill(0)
    self._done = False
    self._step = 0
    obs, _, _, _ = self._obs(0.0, is_first=True)
    return obs

  def _obs(self, reward, is_first=False, is_last=False, is_terminal=False):
    # Max-pool over the last two frames to remove Atari sprite flicker.
    np.maximum(self._buffer[0], self._buffer[1], out=self._buffer[0])
    image = self._buffer[0]
    if image.shape[:2] != self._size:
      if self._resize == 'opencv':
        image = self._cv2.resize(
            image, self._size, interpolation=self._cv2.INTER_AREA)
      if self._resize == 'pillow':
        image = self._image.fromarray(image)
        image = image.resize(self._size, self._image.NEAREST)
        image = np.array(image)
    if self._gray:
      # ITU-R BT.601 luma weights for RGB to grayscale conversion.
      weights = [0.299, 0.587, 1 - (0.299 + 0.587)]
      image = np.tensordot(image, weights, (-1, 0)).astype(image.dtype)
      image = image[:, :, None]
    return {'image': image, 'is_terminal': is_terminal}, reward, is_last, {}

  def _screen(self, array):
    # Render the current ALE frame directly into the provided buffer.
    self._ale.getScreenRGB2(array)

  def close(self):
    return self._env.close()
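

# A minimal usage sketch (not part of the original file). It assumes the
# legacy `gym` + `atari-py` stack required by `gym.envs.atari.AtariEnv` is
# installed, plus `opencv-python` for the default resize backend.
if __name__ == '__main__':
  env = Atari('pong', action_repeat=4, size=(84, 84), gray=True, seed=0)
  obs = env.reset()
  print('image shape:', obs['image'].shape)  # (84, 84, 1) when gray=True
  # step() accepts one-hot action vectors and converts them via argmax.
  action = np.zeros(env.action_space.n, np.float32)
  action[0] = 1.0  # one-hot NOOP
  obs, reward, is_last, info = env.step(action)
  print('reward:', reward, 'is_last:', is_last)
  env.close()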