19 lines
622 B
Python
Raw Normal View History

2020-01-17 12:30:26 +08:00
import numpy as np
from abc import ABC, abstractmethod
class AbstractEnvRunner(ABC):
    """Base class for objects that collect experience from a batched env with a model.

    Subclasses implement :meth:`run`. The constructor sets up the shared state
    every runner needs:

    * ``env`` / ``model`` -- the objects passed in.
    * ``nsteps`` -- number of steps per run, as given by the caller.
    * ``batch_ob_shape`` -- ``(num_envs * nsteps,) + observation_shape``, the
      shape of a flattened batch of observations.
    * ``obs`` -- a ``(num_envs,) + observation_shape`` array, allocated with the
      observation space's dtype and filled with the result of ``env.reset()``.
    * ``states`` -- ``model.initial_state`` (presumably a recurrent state, or
      ``None`` for stateless models -- verify against the model classes).
    * ``dones`` -- one ``False`` flag per environment.
    """

    def __init__(self, *, env, model, nsteps):
        self.env = env
        self.model = model

        n_envs = env.num_envs
        obs_space = env.observation_space
        ob_shape = obs_space.shape

        # Shape of nsteps worth of observations from all envs, flattened on axis 0.
        self.batch_ob_shape = (n_envs * nsteps,) + ob_shape

        # Allocate with the space's dtype first, then copy the reset observations
        # in, so the buffer's dtype follows the space rather than reset()'s output.
        self.obs = np.zeros((n_envs,) + ob_shape, dtype=obs_space.dtype.name)
        self.obs[:] = env.reset()

        self.nsteps = nsteps
        self.states = model.initial_state
        self.dones = [False] * n_envs

    @abstractmethod
    def run(self):
        """Collect a rollout; must be provided by concrete subclasses."""
        raise NotImplementedError