introduced parallel processing for envs
This commit is contained in:
parent
12e6c68f6b
commit
afa5ab988d
@ -4,7 +4,7 @@ import os
|
|||||||
import pathlib
|
import pathlib
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
os.environ["MUJOCO_GL"] = "egl"
|
os.environ["MUJOCO_GL"] = "osmesa"
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import ruamel.yaml as yaml
|
import ruamel.yaml as yaml
|
||||||
@ -15,6 +15,7 @@ import exploration as expl
|
|||||||
import models
|
import models
|
||||||
import tools
|
import tools
|
||||||
import envs.wrappers as wrappers
|
import envs.wrappers as wrappers
|
||||||
|
from parallel import Parallel, Damy
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@ -262,6 +263,12 @@ def main(config):
|
|||||||
make = lambda mode: make_env(config, mode)
|
make = lambda mode: make_env(config, mode)
|
||||||
train_envs = [make("train") for _ in range(config.envs)]
|
train_envs = [make("train") for _ in range(config.envs)]
|
||||||
eval_envs = [make("eval") for _ in range(config.envs)]
|
eval_envs = [make("eval") for _ in range(config.envs)]
|
||||||
|
if config.envs > 1:
|
||||||
|
train_envs = [Parallel(env, "process") for env in train_envs]
|
||||||
|
eval_envs = [Parallel(env, "process") for env in eval_envs]
|
||||||
|
else:
|
||||||
|
train_envs = [Damy(env) for env in train_envs]
|
||||||
|
eval_envs = [Damy(env) for env in eval_envs]
|
||||||
acts = train_envs[0].action_space
|
acts = train_envs[0].action_space
|
||||||
config.num_actions = acts.n if hasattr(acts, "n") else acts.shape[0]
|
config.num_actions = acts.n if hasattr(acts, "n") else acts.shape[0]
|
||||||
|
|
||||||
|
|||||||
@ -83,7 +83,6 @@ class MinecraftBase(gym.Env):
|
|||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
action = action.copy()
|
action = action.copy()
|
||||||
print(self._step, action)
|
|
||||||
action = self._action_values[action]
|
action = self._action_values[action]
|
||||||
action = self._action(action)
|
action = self._action(action)
|
||||||
following = self._noop_action.copy()
|
following = self._noop_action.copy()
|
||||||
|
|||||||
208
parallel.py
Normal file
208
parallel.py
Normal file
@ -0,0 +1,208 @@
|
|||||||
|
import atexit
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import traceback
|
||||||
|
import enum
|
||||||
|
from functools import partial as bind
|
||||||
|
|
||||||
|
|
||||||
|
class Parallel:
|
||||||
|
def __init__(self, ctor, strategy):
|
||||||
|
self.worker = Worker(bind(self._respond, ctor), strategy, state=True)
|
||||||
|
self.callables = {}
|
||||||
|
|
||||||
|
def __getattr__(self, name):
|
||||||
|
if name.startswith("_"):
|
||||||
|
raise AttributeError(name)
|
||||||
|
try:
|
||||||
|
if name not in self.callables:
|
||||||
|
self.callables[name] = self.worker(PMessage.CALLABLE, name)()
|
||||||
|
if self.callables[name]:
|
||||||
|
return bind(self.worker, PMessage.CALL, name)
|
||||||
|
else:
|
||||||
|
return self.worker(PMessage.READ, name)()
|
||||||
|
except AttributeError:
|
||||||
|
raise ValueError(name)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.worker(PMessage.CALL, "__len__")()
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
self.worker.close()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _respond(ctor, state, message, name, *args, **kwargs):
|
||||||
|
state = state or ctor
|
||||||
|
if message == PMessage.CALLABLE:
|
||||||
|
assert not args and not kwargs, (args, kwargs)
|
||||||
|
result = callable(getattr(state, name))
|
||||||
|
elif message == PMessage.CALL:
|
||||||
|
result = getattr(state, name)(*args, **kwargs)
|
||||||
|
elif message == PMessage.READ:
|
||||||
|
assert not args and not kwargs, (args, kwargs)
|
||||||
|
result = getattr(state, name)
|
||||||
|
return state, result
|
||||||
|
|
||||||
|
|
||||||
|
class PMessage(enum.Enum):
|
||||||
|
CALLABLE = 2
|
||||||
|
CALL = 3
|
||||||
|
READ = 4
|
||||||
|
|
||||||
|
|
||||||
|
class Worker:
|
||||||
|
initializers = []
|
||||||
|
|
||||||
|
def __init__(self, fn, strategy="thread", state=False):
|
||||||
|
if not state:
|
||||||
|
fn = lambda s, *args, fn=fn, **kwargs: (s, fn(*args, **kwargs))
|
||||||
|
inits = self.initializers
|
||||||
|
self.impl = {
|
||||||
|
"process": bind(ProcessPipeWorker, initializers=inits),
|
||||||
|
"daemon": bind(ProcessPipeWorker, initializers=inits, daemon=True),
|
||||||
|
}[strategy](fn)
|
||||||
|
self.promise = None
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
self.promise and self.promise() # Raise previous exception if any.
|
||||||
|
self.promise = self.impl(*args, **kwargs)
|
||||||
|
return self.promise
|
||||||
|
|
||||||
|
def wait(self):
|
||||||
|
return self.impl.wait()
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
self.impl.close()
|
||||||
|
|
||||||
|
|
||||||
|
class ProcessPipeWorker:
|
||||||
|
def __init__(self, fn, initializers=(), daemon=False):
|
||||||
|
import multiprocessing
|
||||||
|
import cloudpickle
|
||||||
|
|
||||||
|
self._context = multiprocessing.get_context("spawn")
|
||||||
|
self._pipe, pipe = self._context.Pipe()
|
||||||
|
fn = cloudpickle.dumps(fn)
|
||||||
|
initializers = cloudpickle.dumps(initializers)
|
||||||
|
self._process = self._context.Process(
|
||||||
|
target=self._loop, args=(pipe, fn, initializers), daemon=daemon
|
||||||
|
)
|
||||||
|
self._process.start()
|
||||||
|
self._nextid = 0
|
||||||
|
self._results = {}
|
||||||
|
assert self._submit(Message.OK)()
|
||||||
|
atexit.register(self.close)
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
return self._submit(Message.RUN, (args, kwargs))
|
||||||
|
|
||||||
|
def wait(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
try:
|
||||||
|
self._pipe.send((Message.STOP, self._nextid, None))
|
||||||
|
self._pipe.close()
|
||||||
|
except (AttributeError, IOError):
|
||||||
|
pass # The connection was already closed.
|
||||||
|
try:
|
||||||
|
self._process.join(0.1)
|
||||||
|
if self._process.exitcode is None:
|
||||||
|
try:
|
||||||
|
os.kill(self._process.pid, 9)
|
||||||
|
time.sleep(0.1)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
except (AttributeError, AssertionError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _submit(self, message, payload=None):
|
||||||
|
callid = self._nextid
|
||||||
|
self._nextid += 1
|
||||||
|
self._pipe.send((message, callid, payload))
|
||||||
|
return Future(self._receive, callid)
|
||||||
|
|
||||||
|
def _receive(self, callid):
|
||||||
|
while callid not in self._results:
|
||||||
|
try:
|
||||||
|
message, callid, payload = self._pipe.recv()
|
||||||
|
except (OSError, EOFError):
|
||||||
|
raise RuntimeError("Lost connection to worker.")
|
||||||
|
if message == Message.ERROR:
|
||||||
|
raise Exception(payload)
|
||||||
|
assert message == Message.RESULT, message
|
||||||
|
self._results[callid] = payload
|
||||||
|
return self._results.pop(callid)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _loop(pipe, function, initializers):
|
||||||
|
try:
|
||||||
|
callid = None
|
||||||
|
state = None
|
||||||
|
import cloudpickle
|
||||||
|
|
||||||
|
initializers = cloudpickle.loads(initializers)
|
||||||
|
function = cloudpickle.loads(function)
|
||||||
|
[fn() for fn in initializers]
|
||||||
|
while True:
|
||||||
|
if not pipe.poll(0.1):
|
||||||
|
continue # Wake up for keyboard interrupts.
|
||||||
|
message, callid, payload = pipe.recv()
|
||||||
|
if message == Message.OK:
|
||||||
|
pipe.send((Message.RESULT, callid, True))
|
||||||
|
elif message == Message.STOP:
|
||||||
|
return
|
||||||
|
elif message == Message.RUN:
|
||||||
|
args, kwargs = payload
|
||||||
|
state, result = function(state, *args, **kwargs)
|
||||||
|
pipe.send((Message.RESULT, callid, result))
|
||||||
|
else:
|
||||||
|
raise KeyError(f"Invalid message: {message}")
|
||||||
|
except (EOFError, KeyboardInterrupt):
|
||||||
|
return
|
||||||
|
except Exception:
|
||||||
|
stacktrace = "".join(traceback.format_exception(*sys.exc_info()))
|
||||||
|
print(f"Error inside process worker: {stacktrace}.", flush=True)
|
||||||
|
pipe.send((Message.ERROR, callid, stacktrace))
|
||||||
|
return
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
pipe.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class Message(enum.Enum):
|
||||||
|
OK = 1
|
||||||
|
RUN = 2
|
||||||
|
RESULT = 3
|
||||||
|
STOP = 4
|
||||||
|
ERROR = 5
|
||||||
|
|
||||||
|
|
||||||
|
class Future:
|
||||||
|
def __init__(self, receive, callid):
|
||||||
|
self._receive = receive
|
||||||
|
self._callid = callid
|
||||||
|
self._result = None
|
||||||
|
self._complete = False
|
||||||
|
|
||||||
|
def __call__(self):
|
||||||
|
if not self._complete:
|
||||||
|
self._result = self._receive(self._callid)
|
||||||
|
self._complete = True
|
||||||
|
return self._result
|
||||||
|
|
||||||
|
class Damy:
|
||||||
|
def __init__(self, env):
|
||||||
|
self._env = env
|
||||||
|
|
||||||
|
def __getattr__(self, name):
|
||||||
|
return getattr(self._env, name)
|
||||||
|
|
||||||
|
def step(self, action):
|
||||||
|
return lambda :self._env.step(action)
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
return lambda :self._env.reset()
|
||||||
2
tools.py
2
tools.py
@ -138,6 +138,7 @@ def simulate(agent, envs, cache, directory, logger, is_eval=False, limit=None, s
|
|||||||
if done.any():
|
if done.any():
|
||||||
indices = [index for index, d in enumerate(done) if d]
|
indices = [index for index, d in enumerate(done) if d]
|
||||||
results = [envs[i].reset() for i in indices]
|
results = [envs[i].reset() for i in indices]
|
||||||
|
results = [r() for r in results]
|
||||||
for i in indices:
|
for i in indices:
|
||||||
t = results[i].copy()
|
t = results[i].copy()
|
||||||
t = {k: convert(v) for k, v in t.items()}
|
t = {k: convert(v) for k, v in t.items()}
|
||||||
@ -161,6 +162,7 @@ def simulate(agent, envs, cache, directory, logger, is_eval=False, limit=None, s
|
|||||||
assert len(action) == len(envs)
|
assert len(action) == len(envs)
|
||||||
# step envs
|
# step envs
|
||||||
results = [e.step(a) for e, a in zip(envs, action)]
|
results = [e.step(a) for e, a in zip(envs, action)]
|
||||||
|
results = [r() for r in results]
|
||||||
obs, reward, done = zip(*[p[:3] for p in results])
|
obs, reward, done = zip(*[p[:3] for p in results])
|
||||||
obs = list(obs)
|
obs = list(obs)
|
||||||
reward = list(reward)
|
reward = list(reward)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user