dreamerv3-torch/parallel.py
2023-07-23 22:02:06 +09:00

210 lines
6.2 KiB
Python

import atexit
import os
import sys
import time
import traceback
import enum
from functools import partial as bind
class Parallel:
    """Proxy that forwards attribute reads and method calls to a worker.

    The wrapped object is held as the state of a `Worker`; every interaction
    is sent to it as a `PMessage` request handled by `_respond`.

    Args:
        ctor: Object that becomes the worker-side state on the first request.
            Despite the name it is used as-is; it is never called.
        strategy: Strategy name forwarded unchanged to `Worker`.
    """

    def __init__(self, ctor, strategy):
        self.worker = Worker(bind(self._respond, ctor), strategy, state=True)
        # Cache: attribute name -> whether the remote attribute is callable.
        self.callables = {}

    def __getattr__(self, name):
        """Resolve `name` on the remote object.

        Returns:
            For callable attributes, a partial of the worker that submits the
            call and returns the worker's promise. For other attributes, the
            attribute value itself (this blocks until the worker replies).

        Raises:
            AttributeError: If `name` starts with an underscore; private and
                dunder lookups are never forwarded.
            ValueError: If an `AttributeError` occurs while resolving `name`.
        """
        if name.startswith("_"):
            raise AttributeError(name)
        try:
            if name not in self.callables:
                # One blocking round trip per new name; the answer is cached.
                self.callables[name] = self.worker(PMessage.CALLABLE, name)()
            if self.callables[name]:
                return bind(self.worker, PMessage.CALL, name)
            return self.worker(PMessage.READ, name)()
        except AttributeError as err:
            raise ValueError(name) from err

    def __len__(self):
        """Return `len()` of the remote object (blocks for the reply)."""
        return self.worker(PMessage.CALL, "__len__")()

    def close(self):
        """Shut down the underlying worker."""
        self.worker.close()

    @staticmethod
    def _respond(ctor, state, message, name, *args, **kwargs):
        """Serve one request against the worker-side object.

        Args:
            ctor: Object to fall back to while `state` is still falsy.
            state: State carried over from the previous request.
            message: `PMessage` member selecting the operation.
            name: Attribute name the operation applies to.
            *args: Positional arguments for `PMessage.CALL`.
            **kwargs: Keyword arguments for `PMessage.CALL`.

        Returns:
            Tuple `(state, result)`; `state` is carried into the next request.

        Raises:
            KeyError: If `message` is not a known `PMessage` member.
        """
        state = state or ctor
        if message == PMessage.CALLABLE:
            assert not args and not kwargs, (args, kwargs)
            result = callable(getattr(state, name))
        elif message == PMessage.CALL:
            result = getattr(state, name)(*args, **kwargs)
        elif message == PMessage.READ:
            assert not args and not kwargs, (args, kwargs)
            result = getattr(state, name)
        else:
            # Fail with a clear error instead of an unbound `result` below.
            raise KeyError(f"Invalid message: {message}")
        return state, result
@enum.unique
class PMessage(enum.Enum):
    """Request kinds understood by `Parallel._respond`."""

    # Ask whether the named attribute of the remote object is callable.
    CALLABLE = 2
    # Invoke the named attribute with the supplied arguments.
    CALL = 3
    # Fetch the value of the named attribute.
    READ = 4
class Worker:
    """Run a function in a separate process and hand back promises.

    Args:
        fn: Function to execute in the worker. With `state=True` it must
            accept the carried state as first argument and return a
            `(state, result)` tuple; otherwise it is called with the request
            arguments only and the state is passed through untouched.
        strategy: `"process"` or `"daemon"` (the latter starts the process
            as a daemon). The default `"thread"` is kept for interface
            compatibility but is not among the supported strategies, so a
            supported name has to be passed explicitly.
        state: Whether `fn` manages the carried state itself.

    Raises:
        KeyError: If `strategy` is not a supported strategy name.
    """

    # Callables forwarded to every process worker as its initializers.
    initializers = []

    def __init__(self, fn, strategy="thread", state=False):
        supported = ("process", "daemon")
        if strategy not in supported:
            # Same exception type as the former dict lookup, clearer message.
            raise KeyError(
                f"Unsupported worker strategy {strategy!r}; "
                f"expected one of {supported}."
            )
        if not state:
            fn = self._carry_state(fn)
        self.impl = ProcessPipeWorker(
            fn,
            initializers=self.initializers,
            daemon=(strategy == "daemon"),
        )
        self.promise = None

    @staticmethod
    def _carry_state(fn):
        """Adapt a stateless `fn` to the `(state, ...) -> (state, result)` form."""

        def stateful(state, *args, **kwargs):
            return state, fn(*args, **kwargs)

        return stateful

    def __call__(self, *args, **kwargs):
        """Submit a request and return its promise.

        The previous promise, if any, is resolved first so that an error
        raised by the earlier request surfaces here.
        """
        if self.promise:
            self.promise()  # Raise previous exception if any.
        self.promise = self.impl(*args, **kwargs)
        return self.promise

    def wait(self):
        """Delegate to the implementation's `wait()` and return its result."""
        return self.impl.wait()

    def close(self):
        """Shut down the underlying implementation."""
        self.impl.close()
class ProcessPipeWorker:
    """Execute a function in a spawned child process, talking over a pipe.

    Every request is tagged with an increasing call id and answered with a
    `Future` that reads from the pipe until the matching reply is available.

    Args:
        fn: Callable invoked in the child as `fn(state, *args, **kwargs)`;
            it must return a `(state, result)` tuple.
        initializers: Iterable of zero-argument callables run in the child
            before it starts serving requests.
        daemon: Whether the child process is started as a daemon.

    Raises:
        RuntimeError: If the child does not acknowledge the startup handshake.
    """

    def __init__(self, fn, initializers=(), daemon=False):
        import multiprocessing

        import cloudpickle

        self._context = multiprocessing.get_context("spawn")
        self._pipe, pipe = self._context.Pipe()
        # Pre-serialize with cloudpickle; the child decodes these in `_loop`.
        fn = cloudpickle.dumps(fn)
        initializers = cloudpickle.dumps(initializers)
        self._process = self._context.Process(
            target=self._loop, args=(pipe, fn, initializers), daemon=daemon
        )
        self._process.start()
        self._nextid = 0
        # Replies received while waiting for a different call id.
        self._results = {}
        # Startup handshake. Kept out of an `assert` so the round trip still
        # happens when Python runs with assertions disabled (`-O`).
        if not self._submit(Message.OK)():
            raise RuntimeError("Worker process failed to acknowledge startup.")
        atexit.register(self.close)

    def __call__(self, *args, **kwargs):
        """Submit a RUN request and return its `Future`."""
        return self._submit(Message.RUN, (args, kwargs))

    def wait(self):
        """No-op; present so the worker interface has a `wait()` method."""
        pass

    def close(self):
        """Ask the child to stop, then force-kill it if it is still alive.

        Best effort: a missing or already closed pipe/process is ignored.
        """
        try:
            self._pipe.send((Message.STOP, self._nextid, None))
            self._pipe.close()
        except (AttributeError, IOError):
            pass  # The connection was already closed.
        try:
            self._process.join(0.1)
            if self._process.exitcode is None:
                try:
                    os.kill(self._process.pid, 9)
                    time.sleep(0.1)
                except Exception:
                    pass  # Deliberate best effort during shutdown.
        except (AttributeError, AssertionError):
            pass

    def _submit(self, message, payload=None):
        """Send one request and return a `Future` for its reply."""
        callid = self._nextid
        self._nextid += 1
        self._pipe.send((message, callid, payload))
        return Future(self._receive, callid)

    def _receive(self, callid):
        """Block until the reply for `callid` is available and return it.

        Replies belonging to other call ids are buffered in `_results`.

        Raises:
            RuntimeError: If the pipe to the child is closed or broken.
            Exception: If the child reported an error; the message is the
                stack trace captured in the child.
        """
        while callid not in self._results:
            try:
                # Use a separate name for the received id: rebinding `callid`
                # would end the loop on the first reply of any id and return
                # that payload for the wrong request.
                message, received_id, payload = self._pipe.recv()
            except (OSError, EOFError) as err:
                raise RuntimeError("Lost connection to worker.") from err
            if message == Message.ERROR:
                raise Exception(payload)
            assert message == Message.RESULT, message
            self._results[received_id] = payload
        return self._results.pop(callid)

    @staticmethod
    def _loop(pipe, function, initializers):
        """Child-process main loop: decode the function and serve requests.

        Args:
            pipe: Child end of the pipe.
            function: cloudpickle-serialized request handler.
            initializers: cloudpickle-serialized iterable of callables that
                are run once before requests are served.
        """
        try:
            callid = None
            state = None
            import cloudpickle

            initializers = cloudpickle.loads(initializers)
            function = cloudpickle.loads(function)
            for initialize in initializers:
                initialize()
            while True:
                if not pipe.poll(0.1):
                    continue  # Wake up for keyboard interrupts.
                message, callid, payload = pipe.recv()
                if message == Message.OK:
                    pipe.send((Message.RESULT, callid, True))
                elif message == Message.STOP:
                    return
                elif message == Message.RUN:
                    args, kwargs = payload
                    state, result = function(state, *args, **kwargs)
                    pipe.send((Message.RESULT, callid, result))
                else:
                    raise KeyError(f"Invalid message: {message}")
        except (EOFError, KeyboardInterrupt):
            return
        except Exception:
            # Report the failure to the parent, tagged with the last call id.
            stacktrace = "".join(traceback.format_exception(*sys.exc_info()))
            print(f"Error inside process worker: {stacktrace}.", flush=True)
            pipe.send((Message.ERROR, callid, stacktrace))
            return
        finally:
            try:
                pipe.close()
            except Exception:
                pass
@enum.unique
class Message(enum.Enum):
    """Wire-level message kinds exchanged with the process worker."""

    # Handshake ping; the worker loop answers with RESULT and payload True.
    OK = 1
    # Execute the function; the payload is an (args, kwargs) tuple.
    RUN = 2
    # Reply carrying the outcome of an OK or RUN request.
    RESULT = 3
    # Ask the worker loop to return.
    STOP = 4
    # Reply carrying the stack trace of a failure inside the worker loop.
    ERROR = 5
class Future:
    """Lazily resolved result of a submitted call.

    Args:
        receive: Callable taking the call id and returning the result.
        callid: Identifier passed to `receive` on first resolution.
    """

    def __init__(self, receive, callid):
        self._receive = receive
        self._callid = callid
        self._complete = False
        self._result = None

    def __call__(self):
        """Return the result, fetching it through `receive` only once.

        If `receive` raises, the future stays unresolved and the next call
        tries again.
        """
        if self._complete:
            return self._result
        self._result = self._receive(self._callid)
        self._complete = True
        return self._result
class Damy:
    """Thin wrapper whose `step`/`reset` return deferred calls.

    Instead of running immediately, `step` and `reset` hand back a
    zero-argument callable that performs the call on the wrapped object when
    invoked. All other attributes are looked up on the wrapped object.
    NOTE(review): presumably this lets a local env be used where a
    `Parallel` proxy is expected -- confirm against the callers.
    """

    def __init__(self, env):
        self._env = env

    def __getattr__(self, name):
        # Only reached for names not found on the wrapper itself.
        return getattr(self._env, name)

    def step(self, action):
        """Return a callable that runs `env.step(action)` when called."""

        def deferred():
            return self._env.step(action)

        return deferred

    def reset(self):
        """Return a callable that runs `env.reset()` when called."""

        def deferred():
            return self._env.reset()

        return deferred