69 lines
2.3 KiB
Python
69 lines
2.3 KiB
Python
import numpy as np
|
|
from gym import spaces
|
|
from collections import OrderedDict
|
|
from . import VecEnv
|
|
|
|
class DummyVecEnv(VecEnv):
    """Vectorized environment that runs its sub-environments sequentially.

    Every environment is created in the calling process and stepped one
    after another inside ``step_wait``; no worker processes are used.
    Observations, rewards, dones and infos are gathered into per-batch
    buffers whose first axis indexes the sub-environment.
    """

    def __init__(self, env_fns):
        """Create one environment per factory and allocate the batch buffers.

        Args:
            env_fns: Sequence of zero-argument callables, each returning one
                environment. The first environment's observation and action
                spaces are passed on to ``VecEnv.__init__``.
        """
        self.envs = [fn() for fn in env_fns]
        env = self.envs[0]
        VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space)

        obs_space = env.observation_space
        if isinstance(obs_space, spaces.Dict):
            # Key order defines the layout of self.keys, so it must be stable.
            assert isinstance(obs_space.spaces, OrderedDict)
            subspaces = obs_space.spaces
        else:
            # A non-dict observation space is stored under the single key None.
            subspaces = {None: obs_space}

        self.keys = list(subspaces)
        self.buf_obs = {
            key: np.zeros((self.num_envs,) + tuple(box.shape), dtype=box.dtype)
            for key, box in subspaces.items()
        }
        # Builtin bool instead of np.bool: that alias was deprecated in
        # NumPy 1.20 and removed in 1.24.
        self.buf_dones = np.zeros((self.num_envs,), dtype=bool)
        self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32)
        self.buf_infos = [{} for _ in range(self.num_envs)]
        self.actions = None

    def step_async(self, actions):
        """Store ``actions`` (indexed per environment) for ``step_wait``."""
        self.actions = actions

    def step_wait(self):
        """Step every environment with the stored actions.

        An environment that reports done is reset immediately, so the
        observation stored for it is the first one of the new episode.

        Returns:
            Tuple ``(obs, rews, dones, infos)``. ``rews`` and ``dones`` are
            copies of the internal buffers, ``infos`` is a shallow copy of
            the per-environment info list and ``obs`` is what
            ``_obs_from_buf`` returns.
        """
        # TODO: consider a multiprocess implementation; ppo2/run_mujoco.py
        # only puts a single env in the list, so a sequential loop is used.
        for e, env in enumerate(self.envs):
            obs, self.buf_rews[e], self.buf_dones[e], self.buf_infos[e] = env.step(self.actions[e])
            if self.buf_dones[e]:
                obs = env.reset()
            self._save_obs(e, obs)
        return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones),
                self.buf_infos.copy())

    def reset(self):
        """Reset every environment and return the batched observations."""
        for e, env in enumerate(self.envs):
            self._save_obs(e, env.reset())
        return self._obs_from_buf()

    def close(self):
        """Close every sub-environment created in ``__init__``."""
        for env in self.envs:
            env.close()

    def render(self, mode='human'):
        """Render every environment; return the list of per-env results."""
        return [e.render(mode=mode) for e in self.envs]

    def _save_obs(self, e, obs):
        """Write the observation of environment ``e`` into the buffers.

        For a dict observation space ``obs`` is indexed by key; otherwise
        the whole observation is stored under the key ``None``.
        """
        for k in self.keys:
            self.buf_obs[k][e] = obs if k is None else obs[k]

    def _obs_from_buf(self):
        """Return a snapshot of the buffered observations.

        The arrays are copied so that values already handed to the caller
        are not overwritten in place by the next ``step_wait`` / ``reset``.

        Returns:
            A single array when the observation space is not a dict,
            otherwise a dict that maps each key to its array.
        """
        if self.keys == [None]:
            return np.copy(self.buf_obs[None])
        return {k: np.copy(buf) for k, buf in self.buf_obs.items()}
|