diff --git a/dreamer.py b/dreamer.py index 05ec8ed..62f9631 100644 --- a/dreamer.py +++ b/dreamer.py @@ -4,7 +4,7 @@ import os import pathlib import sys -os.environ["MUJOCO_GL"] = "egl" +os.environ["MUJOCO_GL"] = "osmesa" import numpy as np import ruamel.yaml as yaml @@ -15,6 +15,7 @@ import exploration as expl import models import tools import envs.wrappers as wrappers +from parallel import Parallel, Damy import torch from torch import nn @@ -262,6 +263,12 @@ def main(config): make = lambda mode: make_env(config, mode) train_envs = [make("train") for _ in range(config.envs)] eval_envs = [make("eval") for _ in range(config.envs)] + if config.envs > 1: + train_envs = [Parallel(env, "process") for env in train_envs] + eval_envs = [Parallel(env, "process") for env in eval_envs] + else: + train_envs = [Damy(env) for env in train_envs] + eval_envs = [Damy(env) for env in eval_envs] acts = train_envs[0].action_space config.num_actions = acts.n if hasattr(acts, "n") else acts.shape[0] diff --git a/envs/minecraft_base.py b/envs/minecraft_base.py index 47a3a56..55f6e29 100644 --- a/envs/minecraft_base.py +++ b/envs/minecraft_base.py @@ -83,7 +83,6 @@ class MinecraftBase(gym.Env): def step(self, action): action = action.copy() - print(self._step, action) action = self._action_values[action] action = self._action(action) following = self._noop_action.copy() diff --git a/parallel.py b/parallel.py new file mode 100644 index 0000000..792c5b0 --- /dev/null +++ b/parallel.py @@ -0,0 +1,208 @@ +import atexit +import os +import sys +import time +import traceback +import enum +from functools import partial as bind + + +class Parallel: + def __init__(self, ctor, strategy): + self.worker = Worker(bind(self._respond, ctor), strategy, state=True) + self.callables = {} + + def __getattr__(self, name): + if name.startswith("_"): + raise AttributeError(name) + try: + if name not in self.callables: + self.callables[name] = self.worker(PMessage.CALLABLE, name)() + if self.callables[name]: + return bind(self.worker, PMessage.CALL, name) + else: + return self.worker(PMessage.READ, name)() + except AttributeError: + raise ValueError(name) + + def __len__(self): + return self.worker(PMessage.CALL, "__len__")() + + def close(self): + self.worker.close() + + @staticmethod + def _respond(ctor, state, message, name, *args, **kwargs): + state = state or ctor + if message == PMessage.CALLABLE: + assert not args and not kwargs, (args, kwargs) + result = callable(getattr(state, name)) + elif message == PMessage.CALL: + result = getattr(state, name)(*args, **kwargs) + elif message == PMessage.READ: + assert not args and not kwargs, (args, kwargs) + result = getattr(state, name) + return state, result + + +class PMessage(enum.Enum): + CALLABLE = 2 + CALL = 3 + READ = 4 + + +class Worker: + initializers = [] + + def __init__(self, fn, strategy="thread", state=False): + if not state: + fn = lambda s, *args, fn=fn, **kwargs: (s, fn(*args, **kwargs)) + inits = self.initializers + self.impl = { + "process": bind(ProcessPipeWorker, initializers=inits), + "daemon": bind(ProcessPipeWorker, initializers=inits, daemon=True), + }[strategy](fn) + self.promise = None + + def __call__(self, *args, **kwargs): + self.promise and self.promise() # Raise previous exception if any. + self.promise = self.impl(*args, **kwargs) + return self.promise + + def wait(self): + return self.impl.wait() + + def close(self): + self.impl.close() + + +class ProcessPipeWorker: + def __init__(self, fn, initializers=(), daemon=False): + import multiprocessing + import cloudpickle + + self._context = multiprocessing.get_context("spawn") + self._pipe, pipe = self._context.Pipe() + fn = cloudpickle.dumps(fn) + initializers = cloudpickle.dumps(initializers) + self._process = self._context.Process( + target=self._loop, args=(pipe, fn, initializers), daemon=daemon + ) + self._process.start() + self._nextid = 0 + self._results = {} + assert self._submit(Message.OK)() + atexit.register(self.close) + + def __call__(self, *args, **kwargs): + return self._submit(Message.RUN, (args, kwargs)) + + def wait(self): + pass + + def close(self): + try: + self._pipe.send((Message.STOP, self._nextid, None)) + self._pipe.close() + except (AttributeError, IOError): + pass # The connection was already closed. + try: + self._process.join(0.1) + if self._process.exitcode is None: + try: + os.kill(self._process.pid, 9) + time.sleep(0.1) + except Exception: + pass + except (AttributeError, AssertionError): + pass + + def _submit(self, message, payload=None): + callid = self._nextid + self._nextid += 1 + self._pipe.send((message, callid, payload)) + return Future(self._receive, callid) + + def _receive(self, callid): + while callid not in self._results: + try: + message, callid, payload = self._pipe.recv() + except (OSError, EOFError): + raise RuntimeError("Lost connection to worker.") + if message == Message.ERROR: + raise Exception(payload) + assert message == Message.RESULT, message + self._results[callid] = payload + return self._results.pop(callid) + + @staticmethod + def _loop(pipe, function, initializers): + try: + callid = None + state = None + import cloudpickle + + initializers = cloudpickle.loads(initializers) + function = cloudpickle.loads(function) + [fn() for fn in initializers] + while True: + if not pipe.poll(0.1): + continue # Wake up for keyboard interrupts. + message, callid, payload = pipe.recv() + if message == Message.OK: + pipe.send((Message.RESULT, callid, True)) + elif message == Message.STOP: + return + elif message == Message.RUN: + args, kwargs = payload + state, result = function(state, *args, **kwargs) + pipe.send((Message.RESULT, callid, result)) + else: + raise KeyError(f"Invalid message: {message}") + except (EOFError, KeyboardInterrupt): + return + except Exception: + stacktrace = "".join(traceback.format_exception(*sys.exc_info())) + print(f"Error inside process worker: {stacktrace}.", flush=True) + pipe.send((Message.ERROR, callid, stacktrace)) + return + finally: + try: + pipe.close() + except Exception: + pass + + +class Message(enum.Enum): + OK = 1 + RUN = 2 + RESULT = 3 + STOP = 4 + ERROR = 5 + + +class Future: + def __init__(self, receive, callid): + self._receive = receive + self._callid = callid + self._result = None + self._complete = False + + def __call__(self): + if not self._complete: + self._result = self._receive(self._callid) + self._complete = True + return self._result + +class Damy: + def __init__(self, env): + self._env = env + + def __getattr__(self, name): + return getattr(self._env, name) + + def step(self, action): + return lambda :self._env.step(action) + + def reset(self): + return lambda :self._env.reset() \ No newline at end of file diff --git a/tools.py b/tools.py index 4be88ef..4b14efa 100644 --- a/tools.py +++ b/tools.py @@ -138,6 +138,7 @@ def simulate(agent, envs, cache, directory, logger, is_eval=False, limit=None, s if done.any(): indices = [index for index, d in enumerate(done) if d] results = [envs[i].reset() for i in indices] + results = [r() for r in results] for i in indices: t = results[i].copy() t = {k: convert(v) for k, v in t.items()} @@ -161,6 +162,7 @@ def simulate(agent, envs, cache, directory, logger, is_eval=False, limit=None, s assert len(action) == len(envs) # step envs results = [e.step(a) for e, a in zip(envs, action)] + results = [r() for r in results] obs, reward, done = zip(*[p[:3] for p in results]) obs = list(obs) reward = list(reward)