diff --git a/tianshou/env/worker/subproc.py b/tianshou/env/worker/subproc.py index 2637df4..74193e9 100644 --- a/tianshou/env/worker/subproc.py +++ b/tianshou/env/worker/subproc.py @@ -3,7 +3,7 @@ import multiprocessing import time from collections import OrderedDict from collections.abc import Callable -from multiprocessing import Pipe, connection +from multiprocessing import connection from multiprocessing.context import BaseContext from typing import Any, Literal @@ -149,11 +149,11 @@ class SubprocEnvWorker(EnvWorker): share_memory: bool = False, context: BaseContext | Literal["fork", "spawn"] | None = None, ) -> None: - self.parent_remote, self.child_remote = Pipe() - self.share_memory = share_memory - self.buffer: dict | tuple | ShArray | None = None if not isinstance(context, BaseContext): context = multiprocessing.get_context(context) + self.parent_remote, self.child_remote = context.Pipe() + self.share_memory = share_memory + self.buffer: dict | tuple | ShArray | None = None assert hasattr(context, "Process") # for mypy if self.share_memory: dummy = env_fn()