Enable venvs.reset() concurrent execution (#517)

- change the internal API name of worker: send_action -> send, get_result -> recv (align with envpool)
- add a timing test for venvs.reset() to make sure the concurrent execution
- change venvs.reset() logic

Co-authored-by: Jiayi Weng <trinkle23897@gmail.com>
This commit is contained in:
Chengqi Duan 2022-02-08 00:40:01 +08:00 committed by GitHub
parent cd7654bfd5
commit 9c100e0705
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 85 additions and 48 deletions

View File

@ -48,6 +48,7 @@ preprocess
repo repo
ReLU ReLU
namespace namespace
recv
th th
utils utils
NaN NaN

View File

@ -81,6 +81,7 @@ class MyTestEnv(gym.Env):
def reset(self, state=0): def reset(self, state=0):
self.done = False self.done = False
self.do_sleep()
self.index = state self.index = state
return self._get_state() return self._get_state()
@ -116,16 +117,19 @@ class MyTestEnv(gym.Env):
else: else:
return np.array([self.index], dtype=np.float32) return np.array([self.index], dtype=np.float32)
def do_sleep(self):
if self.sleep > 0:
sleep_time = random.random() if self.random_sleep else 1
sleep_time *= self.sleep
time.sleep(sleep_time)
def step(self, action): def step(self, action):
self.steps += 1 self.steps += 1
if self._md_action: if self._md_action:
action = action[0] action = action[0]
if self.done: if self.done:
raise ValueError('step after done !!!') raise ValueError('step after done !!!')
if self.sleep > 0: self.do_sleep()
sleep_time = random.random() if self.random_sleep else 1
sleep_time *= self.sleep
time.sleep(sleep_time)
if self.index == self.size: if self.index == self.size:
self.done = True self.done = True
return self._get_state(), self._get_reward(), self.done, {} return self._get_state(), self._get_reward(), self.done, {}

View File

@ -96,7 +96,12 @@ def test_async_check_id(size=100, num=4, sleep=.2, timeout=.7):
for cls in test_cls: for cls in test_cls:
pass_check = 1 pass_check = 1
v = cls(env_fns, wait_num=num - 1, timeout=timeout) v = cls(env_fns, wait_num=num - 1, timeout=timeout)
t = time.time()
v.reset() v.reset()
t = time.time() - t
print(f"{cls} reset {t}")
if t > sleep * 9: # huge than maximum sleep time (7 sleep)
pass_check = 0
expect_result = [ expect_result = [
[0, 1], [0, 1],
[0, 1, 2], [0, 1, 2],

13
tianshou/env/venvs.py vendored
View File

@ -209,7 +209,10 @@ class BaseVectorEnv(gym.Env):
id = self._wrap_id(id) id = self._wrap_id(id)
if self.is_async: if self.is_async:
self._assert_id(id) self._assert_id(id)
obs_list = [self.workers[i].reset() for i in id] # send(None) == reset() in worker
for i in id:
self.workers[i].send(None)
obs_list = [self.workers[i].recv() for i in id]
try: try:
obs = np.stack(obs_list) obs = np.stack(obs_list)
except ValueError: # different len(obs) except ValueError: # different len(obs)
@ -258,10 +261,10 @@ class BaseVectorEnv(gym.Env):
if not self.is_async: if not self.is_async:
assert len(action) == len(id) assert len(action) == len(id)
for i, j in enumerate(id): for i, j in enumerate(id):
self.workers[j].send_action(action[i]) self.workers[j].send(action[i])
result = [] result = []
for j in id: for j in id:
obs, rew, done, info = self.workers[j].get_result() obs, rew, done, info = self.workers[j].recv()
info["env_id"] = j info["env_id"] = j
result.append((obs, rew, done, info)) result.append((obs, rew, done, info))
else: else:
@ -269,7 +272,7 @@ class BaseVectorEnv(gym.Env):
self._assert_id(id) self._assert_id(id)
assert len(action) == len(id) assert len(action) == len(id)
for act, env_id in zip(action, id): for act, env_id in zip(action, id):
self.workers[env_id].send_action(act) self.workers[env_id].send(act)
self.waiting_conn.append(self.workers[env_id]) self.waiting_conn.append(self.workers[env_id])
self.waiting_id.append(env_id) self.waiting_id.append(env_id)
self.ready_id = [x for x in self.ready_id if x not in id] self.ready_id = [x for x in self.ready_id if x not in id]
@ -283,7 +286,7 @@ class BaseVectorEnv(gym.Env):
waiting_index = self.waiting_conn.index(conn) waiting_index = self.waiting_conn.index(conn)
self.waiting_conn.pop(waiting_index) self.waiting_conn.pop(waiting_index)
env_id = self.waiting_id.pop(waiting_index) env_id = self.waiting_id.pop(waiting_index)
obs, rew, done, info = conn.get_result() obs, rew, done, info = conn.recv()
info["env_id"] = env_id info["env_id"] = env_id
result.append((obs, rew, done, info)) result.append((obs, rew, done, info))
self.ready_id.append(env_id) self.ready_id.append(env_id)

View File

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Callable, List, Optional, Tuple from typing import Any, Callable, List, Optional, Tuple, Union
import gym import gym
import numpy as np import numpy as np
@ -23,28 +23,41 @@ class EnvWorker(ABC):
pass pass
@abstractmethod @abstractmethod
def reset(self) -> Any: def send(self, action: Optional[np.ndarray]) -> None:
"""Send action signal to low-level worker.
When action is None, it indicates sending "reset" signal; otherwise
it indicates "step" signal. The paired return value from "recv"
function is determined by such kind of different signal.
"""
pass pass
@abstractmethod def recv(
def send_action(self, action: np.ndarray) -> None: self
pass ) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], np.ndarray]:
"""Receive result from low-level worker.
def get_result(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: If the last "send" function sends a NULL action, it only returns a
single observation; otherwise it returns a tuple of (obs, rew, done,
info).
"""
return self.result return self.result
def reset(self) -> np.ndarray:
self.send(None)
return self.recv() # type: ignore
def step( def step(
self, action: np.ndarray self, action: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""Perform one timestep of the environment's dynamic. """Perform one timestep of the environment's dynamic.
"send_action" and "get_result" are coupled in sync simulation, so "send" and "recv" are coupled in sync simulation, so users only call
typically users only call "step" function. But they can be called "step" function. But they can be called separately in async
separately in async simulation, i.e. someone calls "send_action" first, simulation, i.e. someone calls "send" first, and calls "recv" later.
and calls "get_result" later.
""" """
self.send_action(action) self.send(action)
return self.get_result() return self.recv() # type: ignore
@staticmethod @staticmethod
def wait( def wait(

View File

@ -29,8 +29,11 @@ class DummyEnvWorker(EnvWorker):
# Sequential EnvWorker objects are always ready # Sequential EnvWorker objects are always ready
return workers return workers
def send_action(self, action: np.ndarray) -> None: def send(self, action: Optional[np.ndarray]) -> None:
self.result = self.env.step(action) if action is None:
self.result = self.env.reset()
else:
self.result = self.env.step(action)
def seed(self, seed: Optional[int] = None) -> List[int]: def seed(self, seed: Optional[int] = None) -> List[int]:
super().seed(seed) super().seed(seed)

View File

@ -1,4 +1,4 @@
from typing import Any, Callable, List, Optional, Tuple from typing import Any, Callable, List, Optional, Tuple, Union
import gym import gym
import numpy as np import numpy as np
@ -44,11 +44,16 @@ class RayEnvWorker(EnvWorker):
ready_results, _ = ray.wait(results, num_returns=wait_num, timeout=timeout) ready_results, _ = ray.wait(results, num_returns=wait_num, timeout=timeout)
return [workers[results.index(result)] for result in ready_results] return [workers[results.index(result)] for result in ready_results]
def send_action(self, action: np.ndarray) -> None: def send(self, action: Optional[np.ndarray]) -> None:
# self.action is actually a handle # self.action is actually a handle
self.result = self.env.step.remote(action) if action is None:
self.result = self.env.reset.remote()
else:
self.result = self.env.step.remote(action)
def get_result(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: def recv(
self
) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], np.ndarray]:
return ray.get(self.result) return ray.get(self.result)
def seed(self, seed: Optional[int] = None) -> List[int]: def seed(self, seed: Optional[int] = None) -> List[int]:

View File

@ -86,17 +86,17 @@ def _worker(
p.close() p.close()
break break
if cmd == "step": if cmd == "step":
obs, reward, done, info = env.step(data) if data is None: # reset
obs = env.reset()
else:
obs, reward, done, info = env.step(data)
if obs_bufs is not None: if obs_bufs is not None:
_encode_obs(obs, obs_bufs) _encode_obs(obs, obs_bufs)
obs = None obs = None
p.send((obs, reward, done, info)) if data is None:
elif cmd == "reset": p.send(obs)
obs = env.reset() else:
if obs_bufs is not None: p.send((obs, reward, done, info))
_encode_obs(obs, obs_bufs)
obs = None
p.send(obs)
elif cmd == "close": elif cmd == "close":
p.send(env.close()) p.send(env.close())
p.close() p.close()
@ -140,6 +140,7 @@ class SubprocEnvWorker(EnvWorker):
self.process = Process(target=_worker, args=args, daemon=True) self.process = Process(target=_worker, args=args, daemon=True)
self.process.start() self.process.start()
self.child_remote.close() self.child_remote.close()
self.is_reset = False
super().__init__(env_fn) super().__init__(env_fn)
def get_env_attr(self, key: str) -> Any: def get_env_attr(self, key: str) -> Any:
@ -165,13 +166,6 @@ class SubprocEnvWorker(EnvWorker):
return decode_obs(self.buffer) return decode_obs(self.buffer)
def reset(self) -> Any:
self.parent_remote.send(["reset", None])
obs = self.parent_remote.recv()
if self.share_memory:
obs = self._decode_obs()
return obs
@staticmethod @staticmethod
def wait( # type: ignore def wait( # type: ignore
workers: List["SubprocEnvWorker"], workers: List["SubprocEnvWorker"],
@ -192,14 +186,23 @@ class SubprocEnvWorker(EnvWorker):
remain_conns = [conn for conn in remain_conns if conn not in ready_conns] remain_conns = [conn for conn in remain_conns if conn not in ready_conns]
return [workers[conns.index(con)] for con in ready_conns] return [workers[conns.index(con)] for con in ready_conns]
def send_action(self, action: np.ndarray) -> None: def send(self, action: Optional[np.ndarray]) -> None:
self.parent_remote.send(["step", action]) self.parent_remote.send(["step", action])
def get_result(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: def recv(
obs, rew, done, info = self.parent_remote.recv() self
if self.share_memory: ) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], np.ndarray]:
obs = self._decode_obs() result = self.parent_remote.recv()
return obs, rew, done, info if isinstance(result, tuple):
obs, rew, done, info = result
if self.share_memory:
obs = self._decode_obs()
return obs, rew, done, info
else:
obs = result
if self.share_memory:
obs = self._decode_obs()
return obs
def seed(self, seed: Optional[int] = None) -> Optional[List[int]]: def seed(self, seed: Optional[int] = None) -> Optional[List[int]]:
super().seed(seed) super().seed(seed)