Introduce parallel processing for environments

This commit is contained in:
NM512 2023-07-23 21:58:46 +09:00
parent 12e6c68f6b
commit afa5ab988d
4 changed files with 218 additions and 2 deletions

View File

@ -4,7 +4,7 @@ import os
import pathlib import pathlib
import sys import sys
os.environ["MUJOCO_GL"] = "egl" os.environ["MUJOCO_GL"] = "osmesa"
import numpy as np import numpy as np
import ruamel.yaml as yaml import ruamel.yaml as yaml
@ -15,6 +15,7 @@ import exploration as expl
import models import models
import tools import tools
import envs.wrappers as wrappers import envs.wrappers as wrappers
from parallel import Parallel, Damy
import torch import torch
from torch import nn from torch import nn
@ -262,6 +263,12 @@ def main(config):
make = lambda mode: make_env(config, mode) make = lambda mode: make_env(config, mode)
train_envs = [make("train") for _ in range(config.envs)] train_envs = [make("train") for _ in range(config.envs)]
eval_envs = [make("eval") for _ in range(config.envs)] eval_envs = [make("eval") for _ in range(config.envs)]
if config.envs > 1:
train_envs = [Parallel(env, "process") for env in train_envs]
eval_envs = [Parallel(env, "process") for env in eval_envs]
else:
train_envs = [Damy(env) for env in train_envs]
eval_envs = [Damy(env) for env in eval_envs]
acts = train_envs[0].action_space acts = train_envs[0].action_space
config.num_actions = acts.n if hasattr(acts, "n") else acts.shape[0] config.num_actions = acts.n if hasattr(acts, "n") else acts.shape[0]

View File

@ -83,7 +83,6 @@ class MinecraftBase(gym.Env):
def step(self, action): def step(self, action):
action = action.copy() action = action.copy()
print(self._step, action)
action = self._action_values[action] action = self._action_values[action]
action = self._action(action) action = self._action(action)
following = self._noop_action.copy() following = self._noop_action.copy()

208
parallel.py Normal file
View File

@ -0,0 +1,208 @@
import atexit
import os
import sys
import time
import traceback
import enum
from functools import partial as bind
class Parallel:
    """Proxy that hosts an environment (or any object) inside a Worker.

    Attribute access is forwarded to the remote object: callable attributes
    are returned as dispatch functions yielding futures, while plain
    attributes are fetched eagerly and returned by value.
    """

    def __init__(self, ctor, strategy):
        # The worker keeps the wrapped object as its persistent state; every
        # request is served by the _respond dispatcher below.
        self.worker = Worker(bind(self._respond, ctor), strategy, state=True)
        # Cache of name -> bool so the "is it callable?" round-trip happens
        # at most once per attribute.
        self.callables = {}

    def __getattr__(self, name):
        # Private names are never forwarded; resolve them locally only.
        if name.startswith("_"):
            raise AttributeError(name)
        try:
            if name not in self.callables:
                self.callables[name] = self.worker(PMessage.CALLABLE, name)()
            if not self.callables[name]:
                # Plain attribute: one synchronous round-trip for its value.
                return self.worker(PMessage.READ, name)()
            # Callable attribute: hand back a dispatcher producing futures.
            return bind(self.worker, PMessage.CALL, name)
        except AttributeError:
            raise ValueError(name)

    def __len__(self):
        # len() bypasses __getattr__, so forward the dunder explicitly.
        return self.worker(PMessage.CALL, "__len__")()

    def close(self):
        self.worker.close()

    @staticmethod
    def _respond(ctor, state, message, name, *args, **kwargs):
        """Runs inside the worker: serve one request against the wrapped object.

        NOTE(review): despite the name, `ctor` is used as a value, not called —
        callers in this commit pass an already-constructed env instance.
        """
        state = state or ctor
        if message == PMessage.CALLABLE:
            assert not args and not kwargs, (args, kwargs)
            result = callable(getattr(state, name))
        elif message == PMessage.CALL:
            result = getattr(state, name)(*args, **kwargs)
        elif message == PMessage.READ:
            assert not args and not kwargs, (args, kwargs)
            result = getattr(state, name)
        return state, result
class PMessage(enum.Enum):
    """Request types sent from a Parallel proxy to its worker."""
    CALLABLE = 2  # Query whether the named attribute is callable.
    CALL = 3  # Invoke the named callable attribute with args/kwargs.
    READ = 4  # Fetch the value of a plain (non-callable) attribute.
class Worker:
    """Wraps a function in a background worker and returns futures for calls.

    Each call submits work and returns a promise (a callable future). The
    previous promise is polled before submitting new work so that a pending
    remote exception surfaces on the next call instead of being dropped.
    """

    # Callables executed inside every newly spawned worker process (e.g. to
    # set environment variables). Class-level, shared across instances.
    initializers = []

    def __init__(self, fn, strategy="thread", state=False):
        """Create a worker running `fn`.

        Args:
            fn: Function to run remotely. With state=True it must follow the
                (state, *args, **kwargs) -> (state, result) protocol.
            strategy: Execution backend; only "process" and "daemon" are
                implemented here.
            state: Whether `fn` already follows the stateful protocol.

        Raises:
            ValueError: If `strategy` names a backend with no implementation
                (including the historical default "thread"), instead of the
                opaque KeyError the dict lookup used to raise.
        """
        if not state:
            # Adapt a stateless function to the stateful protocol, threading
            # the (unused) state through untouched.
            fn = lambda s, *args, fn=fn, **kwargs: (s, fn(*args, **kwargs))
        inits = self.initializers
        impls = {
            "process": bind(ProcessPipeWorker, initializers=inits),
            "daemon": bind(ProcessPipeWorker, initializers=inits, daemon=True),
        }
        if strategy not in impls:
            raise ValueError(
                f"Unsupported worker strategy {strategy!r}; "
                f"expected one of {sorted(impls)}."
            )
        self.impl = impls[strategy](fn)
        self.promise = None

    def __call__(self, *args, **kwargs):
        self.promise and self.promise()  # Raise previous exception if any.
        self.promise = self.impl(*args, **kwargs)
        return self.promise

    def wait(self):
        return self.impl.wait()

    def close(self):
        self.impl.close()
class ProcessPipeWorker:
    """Runs a cloudpickled function in a spawned process over a duplex pipe.

    Every submission gets a monotonically increasing call id, and results may
    be collected out of submission order through the returned Future objects.
    """

    def __init__(self, fn, initializers=(), daemon=False):
        import multiprocessing
        import cloudpickle

        # "spawn" avoids inheriting unpicklable parent state (threads, locks)
        # at the cost of serializing `fn` explicitly via cloudpickle.
        self._context = multiprocessing.get_context("spawn")
        self._pipe, pipe = self._context.Pipe()
        fn = cloudpickle.dumps(fn)
        initializers = cloudpickle.dumps(initializers)
        self._process = self._context.Process(
            target=self._loop, args=(pipe, fn, initializers), daemon=daemon
        )
        self._process.start()
        self._nextid = 0
        self._results = {}  # callid -> payload, buffers results received early.
        # Handshake: block until the child loop is up and responding.
        assert self._submit(Message.OK)()
        atexit.register(self.close)

    def __call__(self, *args, **kwargs):
        """Submit a call; returns a Future that blocks on first retrieval."""
        return self._submit(Message.RUN, (args, kwargs))

    def wait(self):
        pass

    def close(self):
        """Ask the child to stop; force-kill it if it does not exit promptly."""
        try:
            self._pipe.send((Message.STOP, self._nextid, None))
            self._pipe.close()
        except (AttributeError, IOError):
            pass  # The connection was already closed.
        try:
            self._process.join(0.1)
            if self._process.exitcode is None:
                try:
                    # Process.kill() sends SIGKILL portably; replaces the
                    # POSIX-only os.kill(pid, 9).
                    self._process.kill()
                    time.sleep(0.1)
                except Exception:
                    pass
        except (AttributeError, AssertionError):
            pass

    def _submit(self, message, payload=None):
        callid = self._nextid
        self._nextid += 1
        self._pipe.send((message, callid, payload))
        return Future(self._receive, callid)

    def _receive(self, callid):
        """Drain the pipe until the result for `callid` arrives, buffering others.

        Bug fix: the received message's call id previously shadowed the
        `callid` parameter, so collecting futures out of submission order
        would return another call's result. The received id is now unpacked
        into a separate name.
        """
        while callid not in self._results:
            try:
                message, other, payload = self._pipe.recv()
            except (OSError, EOFError):
                raise RuntimeError("Lost connection to worker.")
            if message == Message.ERROR:
                raise Exception(payload)
            assert message == Message.RESULT, message
            self._results[other] = payload
        return self._results.pop(callid)

    @staticmethod
    def _loop(pipe, function, initializers):
        """Child-process main loop: deserialize, run initializers, serve requests."""
        try:
            callid = None
            state = None
            import cloudpickle

            initializers = cloudpickle.loads(initializers)
            function = cloudpickle.loads(function)
            [fn() for fn in initializers]
            while True:
                if not pipe.poll(0.1):
                    continue  # Wake up for keyboard interrupts.
                message, callid, payload = pipe.recv()
                if message == Message.OK:
                    # Startup handshake; confirm the loop is alive.
                    pipe.send((Message.RESULT, callid, True))
                elif message == Message.STOP:
                    return
                elif message == Message.RUN:
                    args, kwargs = payload
                    # `function` threads `state` through so the wrapped object
                    # persists across calls.
                    state, result = function(state, *args, **kwargs)
                    pipe.send((Message.RESULT, callid, result))
                else:
                    raise KeyError(f"Invalid message: {message}")
        except (EOFError, KeyboardInterrupt):
            return
        except Exception:
            # Report the formatted traceback to the parent, which re-raises it
            # from _receive on the next future retrieval.
            stacktrace = "".join(traceback.format_exception(*sys.exc_info()))
            print(f"Error inside process worker: {stacktrace}.", flush=True)
            pipe.send((Message.ERROR, callid, stacktrace))
            return
        finally:
            try:
                pipe.close()
            except Exception:
                pass
class Message(enum.Enum):
    """Wire protocol between ProcessPipeWorker and its child process."""
    OK = 1  # Startup handshake ping; child replies RESULT True.
    RUN = 2  # Execute the function; payload is (args, kwargs).
    RESULT = 3  # Successful result; payload is the return value.
    STOP = 4  # Ask the child loop to exit.
    ERROR = 5  # Child-side failure; payload is the formatted stacktrace.
class Future:
    """Lazily retrieves a pending call result and memoizes it."""

    def __init__(self, receive, callid):
        # `receive` blocks until the result for `callid` is available.
        self._receive = receive
        self._callid = callid
        self._complete = False
        self._result = None

    def __call__(self):
        """Block until the result is available; subsequent calls return the cache."""
        if self._complete:
            return self._result
        self._result = self._receive(self._callid)
        self._complete = True
        return self._result
class Damy:
    """Synchronous stand-in for Parallel with the same future-style API.

    step() and reset() return zero-argument callables so callers can treat
    single-env and multi-process setups uniformly.
    """

    def __init__(self, env):
        self._env = env

    def __getattr__(self, name):
        # Everything else is forwarded straight to the wrapped environment.
        return getattr(self._env, name)

    def step(self, action):
        # Defer execution: the env only steps when the returned thunk runs.
        return bind(self._env.step, action)

    def reset(self):
        # The bound method itself is already a zero-argument thunk.
        return self._env.reset

View File

@ -138,6 +138,7 @@ def simulate(agent, envs, cache, directory, logger, is_eval=False, limit=None, s
if done.any(): if done.any():
indices = [index for index, d in enumerate(done) if d] indices = [index for index, d in enumerate(done) if d]
results = [envs[i].reset() for i in indices] results = [envs[i].reset() for i in indices]
results = [r() for r in results]
for i in indices: for i in indices:
t = results[i].copy() t = results[i].copy()
t = {k: convert(v) for k, v in t.items()} t = {k: convert(v) for k, v in t.items()}
@ -161,6 +162,7 @@ def simulate(agent, envs, cache, directory, logger, is_eval=False, limit=None, s
assert len(action) == len(envs) assert len(action) == len(envs)
# step envs # step envs
results = [e.step(a) for e, a in zip(envs, action)] results = [e.step(a) for e, a in zip(envs, action)]
results = [r() for r in results]
obs, reward, done = zip(*[p[:3] for p in results]) obs, reward, done = zip(*[p[:3] for p in results])
obs = list(obs) obs = list(obs)
reward = list(reward) reward = list(reward)