youkaichao a9f9940d17
code refactor for venv (#179)
- Refactor code to remove duplicate code

- Enable async simulation for all vector envs

- Remove `collector.close` and rename `VectorEnv` to `DummyVectorEnv`

The abstraction of the vector env has changed.

Prior to this PR, each vector env was almost independent.

After this PR, each env is wrapped into a worker, and vector envs differ only in their worker type. In fact, users can simply use `BaseVectorEnv` with different workers; `SubprocVectorEnv` and `ShmemVectorEnv` are kept for backward compatibility.
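
A rough sketch of the new abstraction (a minimal example; it assumes the workers are exposed as `tianshou.env.worker.DummyEnvWorker` / `SubprocEnvWorker` and that `BaseVectorEnv` accepts a worker factory — check the merged code for the exact signatures):

```python
import gym
import numpy as np
from tianshou.env import BaseVectorEnv
from tianshou.env.worker import DummyEnvWorker, SubprocEnvWorker

env_fns = [lambda: gym.make('CartPole-v0') for _ in range(4)]

# run the envs in-process -- what DummyVectorEnv does under the hood
venv = BaseVectorEnv(env_fns, worker_fn=DummyEnvWorker)
# or run each env in its own subprocess, as SubprocVectorEnv does:
# venv = BaseVectorEnv(env_fns, worker_fn=SubprocEnvWorker)

obs = venv.reset()                                  # (4, *obs_shape)
obs, rew, done, info = venv.step(np.zeros(4, int))  # one action per env
```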

Co-authored-by: n+e <463003665@qq.com>
Co-authored-by: magicly <magicly007@gmail.com>
2020-08-19 15:00:24 +08:00


import cv2
import gym
import numpy as np
from gym.spaces.box import Box

from tianshou.data import Batch

SIZE = 84   # height/width of the resized observation
FRAME = 4   # number of stacked frames per observation
def create_atari_environment(name=None, sticky_actions=True,
                             max_episode_steps=2000):
    game_version = 'v0' if sticky_actions else 'v4'
    name = '{}NoFrameskip-{}'.format(name, game_version)
    env = gym.make(name)
    env = env.env  # strip gym's default TimeLimit wrapper
    env = preprocessing(env, max_episode_steps=max_episode_steps)
    return env
def preprocess_fn(obs=None, act=None, rew=None, done=None,
                  obs_next=None, info=None, policy=None, **kwargs):
    """Resize a batch of raw frames to (batch, SIZE, SIZE, FRAME)."""
    if obs_next is not None:
        # collapse the batch and frame axes into one leading axis, then
        # move it last so cv2.resize rescales all frames in a single call
        obs_next = np.reshape(obs_next, (-1, *obs_next.shape[2:]))
        obs_next = np.moveaxis(obs_next, 0, -1)
        obs_next = cv2.resize(obs_next, (SIZE, SIZE))
        obs_next = np.asanyarray(obs_next, dtype=np.uint8)
        # split batch and frame axes again and make the result channel-last
        obs_next = np.reshape(obs_next, (-1, FRAME, SIZE, SIZE))
        obs_next = np.moveaxis(obs_next, 1, -1)
    elif obs is not None:
        obs = np.reshape(obs, (-1, *obs.shape[2:]))
        obs = np.moveaxis(obs, 0, -1)
        obs = cv2.resize(obs, (SIZE, SIZE))
        obs = np.asanyarray(obs, dtype=np.uint8)
        obs = np.reshape(obs, (-1, FRAME, SIZE, SIZE))
        obs = np.moveaxis(obs, 1, -1)
    return Batch(obs=obs, act=act, rew=rew, done=done,
                 obs_next=obs_next, info=info)
class preprocessing(object):
    """Atari preprocessing: frame skipping with max-pooling over the last
    two screens to remove flicker, grayscale observations, and an episode
    step limit."""

    def __init__(self, env, frame_skip=4, terminal_on_life_loss=False,
                 size=84, max_episode_steps=2000):
        self.max_episode_steps = max_episode_steps
        self.env = env
        self.terminal_on_life_loss = terminal_on_life_loss
        self.frame_skip = frame_skip
        self.size = size
        self.count = 0
        obs_dims = self.env.observation_space
        # two buffers holding the last two raw grayscale screens,
        # max-pooled in _pool_and_resize to remove Atari flickering
        self.screen_buffer = [
            np.empty((obs_dims.shape[0], obs_dims.shape[1]), dtype=np.uint8),
            np.empty((obs_dims.shape[0], obs_dims.shape[1]), dtype=np.uint8),
        ]
        self.game_over = False
        self.lives = 0
    @property
    def observation_space(self):
        return Box(low=0, high=255,
                   shape=(self.size, self.size, self.frame_skip),
                   dtype=np.uint8)

    @property
    def action_space(self):
        return self.env.action_space

    @property
    def reward_range(self):
        return self.env.reward_range

    @property
    def metadata(self):
        return self.env.metadata

    def close(self):
        return self.env.close()
    def reset(self):
        self.count = 0
        self.env.reset()
        self.lives = self.env.ale.lives()
        self._grayscale_obs(self.screen_buffer[0])
        self.screen_buffer[1].fill(0)
        # the initial observation repeats the first screen frame_skip times
        return np.array([self._pool_and_resize()
                         for _ in range(self.frame_skip)])

    def render(self, mode='human'):
        return self.env.render(mode)
    def step(self, action):
        total_reward = 0.
        observation = []
        for t in range(self.frame_skip):
            self.count += 1
            _, reward, terminal, info = self.env.step(action)
            total_reward += reward
            if self.terminal_on_life_loss:
                lives = self.env.ale.lives()
                is_terminal = terminal or lives < self.lives
                self.lives = lives
            else:
                is_terminal = terminal
            if is_terminal:
                break
            elif t >= self.frame_skip - 2:
                # record the last two screens for max-pooling
                t_ = t - (self.frame_skip - 2)
                self._grayscale_obs(self.screen_buffer[t_])
                observation.append(self._pool_and_resize())
        if len(observation) == 0:
            observation = [self._pool_and_resize()
                           for _ in range(self.frame_skip)]
        # pad with the last frame if the episode ended mid-skip
        while 0 < len(observation) < self.frame_skip:
            observation.append(observation[-1])
        terminal = self.count >= self.max_episode_steps
        return np.array(observation), total_reward, \
            (terminal or is_terminal), info
    def _grayscale_obs(self, output):
        # fill `output` with the current screen in grayscale, in place
        self.env.ale.getScreenGrayscale(output)
        return output

    def _pool_and_resize(self):
        # max-pool the two most recent screens to remove flickering; the
        # actual resize to SIZE x SIZE happens later in preprocess_fn
        if self.frame_skip > 1:
            np.maximum(self.screen_buffer[0], self.screen_buffer[1],
                       out=self.screen_buffer[0])
        return self.screen_buffer[0]
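

# A minimal usage sketch (not part of the original file): build the wrapped
# env and check observation shapes. Assumes an Atari ROM such as Pong is
# available via gym[atari].
if __name__ == '__main__':
    env = create_atari_environment('Pong')
    obs = env.reset()                    # (FRAME, 210, 160) raw frames
    obs, rew, done, info = env.step(0)   # action 0 is NOOP in ALE
    print(obs.shape)                     # still full-size here; resizing to
                                         # (SIZE, SIZE) happens in preprocess_fn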