Tianshou/tianshou/env/common.py
2020-03-25 14:08:28 +08:00

50 lines
1.2 KiB
Python

import numpy as np
from collections import deque
class EnvWrapper:
    """Thin pass-through wrapper around an environment object.

    Every method delegates to the wrapped ``env``. ``seed``, ``render`` and
    ``close`` are optional on the wrapped object: when the attribute is
    missing the call is a no-op that returns ``None``.
    """

    def __init__(self, env):
        """Store the environment that all calls are forwarded to."""
        self.env = env

    def step(self, action):
        """Forward ``action`` to ``env.step`` and return its result unchanged."""
        return self.env.step(action)

    def reset(self):
        """Reset the wrapped environment and return whatever it returns."""
        return self.env.reset()

    def seed(self, seed=None):
        """Seed the wrapped environment if it exposes ``seed``.

        Returns the result of ``env.seed(seed)``, or ``None`` when the
        wrapped environment has no ``seed`` attribute.
        """
        if hasattr(self.env, 'seed'):
            return self.env.seed(seed)
        return None

    def render(self, **kwargs):
        """Render the wrapped environment if it exposes ``render``.

        Keyword arguments are passed through untouched. Returns ``None``
        when the wrapped environment has no ``render`` attribute.
        """
        if hasattr(self.env, 'render'):
            return self.env.render(**kwargs)
        return None

    def close(self):
        """Close the wrapped environment if it exposes ``close``."""
        # Guarded like seed()/render() so a minimal environment that those
        # methods tolerate does not blow up on shutdown.
        if hasattr(self.env, 'close'):
            self.env.close()
class FrameStack(EnvWrapper):
    """Environment wrapper that returns the last ``stack_num`` observations.

    Observations are kept in a bounded deque and joined along the last axis
    each time an observation is returned. ``reset`` must be called before
    ``step``; otherwise the deque holds fewer than ``stack_num`` frames and
    the returned observation is correspondingly smaller.
    """

    def __init__(self, env, stack_num):
        """Stack last k frames.

        :param env: environment to wrap; its ``step`` must return a
            ``(obs, reward, done, info)`` 4-tuple.
        :param stack_num: number of most recent observations to keep.
        :raises ValueError: if ``stack_num`` is smaller than 1.
        """
        # Fail fast: an empty window would only surface later as an opaque
        # numpy error when the first observation is assembled.
        if stack_num < 1:
            raise ValueError(
                'stack_num must be at least 1, got {}'.format(stack_num))
        super().__init__(env)
        self.stack_num = stack_num
        # maxlen makes the deque drop the oldest frame automatically.
        self._frames = deque([], maxlen=stack_num)

    def step(self, action):
        """Step the wrapped environment and return the stacked observation.

        :return: ``(stacked_obs, reward, done, info)`` where the last three
            items are passed through from the wrapped environment.
        """
        obs, reward, done, info = self.env.step(action)
        self._frames.append(obs)
        return self._get_obs(), reward, done, info

    def reset(self):
        """Reset the wrapped environment and fill the window with its
        first observation.

        :return: the stacked observation built from ``stack_num`` copies of
            the initial observation.
        """
        obs = self.env.reset()
        # Appending stack_num times also evicts any frames left over from
        # the previous episode, because the deque is bounded.
        for _ in range(self.stack_num):
            self._frames.append(obs)
        return self._get_obs()

    def _get_obs(self):
        """Join the buffered frames along the last axis."""
        try:
            return np.concatenate(self._frames, axis=-1)
        except ValueError:
            # np.concatenate rejects inputs it cannot join along an existing
            # axis (e.g. zero-dimensional / scalar observations); stacking
            # creates the new trailing axis instead.
            return np.stack(self._frames, axis=-1)