Enable venvs.reset() concurrent execution (#517)

- rename the worker's internal API: send_action -> send, get_result -> recv (to align with envpool)
- add a timing test for venvs.reset() to make sure it executes concurrently
- change venvs.reset() logic

Co-authored-by: Jiayi Weng <trinkle23897@gmail.com>
This commit is contained in:
Chengqi Duan 2022-02-08 00:40:01 +08:00 committed by GitHub
parent cd7654bfd5
commit 9c100e0705
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 85 additions and 48 deletions

View File

@ -48,6 +48,7 @@ preprocess
repo
ReLU
namespace
recv
th
utils
NaN

View File

@ -81,6 +81,7 @@ class MyTestEnv(gym.Env):
def reset(self, state=0):
self.done = False
self.do_sleep()
self.index = state
return self._get_state()
@ -116,16 +117,19 @@ class MyTestEnv(gym.Env):
else:
return np.array([self.index], dtype=np.float32)
def do_sleep(self):
if self.sleep > 0:
sleep_time = random.random() if self.random_sleep else 1
sleep_time *= self.sleep
time.sleep(sleep_time)
def step(self, action):
self.steps += 1
if self._md_action:
action = action[0]
if self.done:
raise ValueError('step after done !!!')
if self.sleep > 0:
sleep_time = random.random() if self.random_sleep else 1
sleep_time *= self.sleep
time.sleep(sleep_time)
self.do_sleep()
if self.index == self.size:
self.done = True
return self._get_state(), self._get_reward(), self.done, {}

View File

@ -96,7 +96,12 @@ def test_async_check_id(size=100, num=4, sleep=.2, timeout=.7):
for cls in test_cls:
pass_check = 1
v = cls(env_fns, wait_num=num - 1, timeout=timeout)
t = time.time()
v.reset()
t = time.time() - t
print(f"{cls} reset {t}")
if t > sleep * 9: # larger than maximum sleep time (7 sleep)
pass_check = 0
expect_result = [
[0, 1],
[0, 1, 2],

13
tianshou/env/venvs.py vendored
View File

@ -209,7 +209,10 @@ class BaseVectorEnv(gym.Env):
id = self._wrap_id(id)
if self.is_async:
self._assert_id(id)
obs_list = [self.workers[i].reset() for i in id]
# send(None) == reset() in worker
for i in id:
self.workers[i].send(None)
obs_list = [self.workers[i].recv() for i in id]
try:
obs = np.stack(obs_list)
except ValueError: # different len(obs)
@ -258,10 +261,10 @@ class BaseVectorEnv(gym.Env):
if not self.is_async:
assert len(action) == len(id)
for i, j in enumerate(id):
self.workers[j].send_action(action[i])
self.workers[j].send(action[i])
result = []
for j in id:
obs, rew, done, info = self.workers[j].get_result()
obs, rew, done, info = self.workers[j].recv()
info["env_id"] = j
result.append((obs, rew, done, info))
else:
@ -269,7 +272,7 @@ class BaseVectorEnv(gym.Env):
self._assert_id(id)
assert len(action) == len(id)
for act, env_id in zip(action, id):
self.workers[env_id].send_action(act)
self.workers[env_id].send(act)
self.waiting_conn.append(self.workers[env_id])
self.waiting_id.append(env_id)
self.ready_id = [x for x in self.ready_id if x not in id]
@ -283,7 +286,7 @@ class BaseVectorEnv(gym.Env):
waiting_index = self.waiting_conn.index(conn)
self.waiting_conn.pop(waiting_index)
env_id = self.waiting_id.pop(waiting_index)
obs, rew, done, info = conn.get_result()
obs, rew, done, info = conn.recv()
info["env_id"] = env_id
result.append((obs, rew, done, info))
self.ready_id.append(env_id)

View File

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Callable, List, Optional, Tuple
from typing import Any, Callable, List, Optional, Tuple, Union
import gym
import numpy as np
@ -23,28 +23,41 @@ class EnvWorker(ABC):
pass
@abstractmethod
def reset(self) -> Any:
def send(self, action: Optional[np.ndarray]) -> None:
"""Send action signal to low-level worker.
When action is None, a "reset" signal is sent; otherwise a "step"
signal is sent. The value returned by the paired "recv" call depends
on which of the two signals was sent.
"""
pass
@abstractmethod
def send_action(self, action: np.ndarray) -> None:
pass
def recv(
self
) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], np.ndarray]:
"""Receive result from low-level worker.
def get_result(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
If the last "send" call passed None as the action, it returns only a
single observation; otherwise it returns a tuple of (obs, rew, done,
info).
"""
return self.result
def reset(self) -> np.ndarray:
self.send(None)
return self.recv() # type: ignore
def step(
self, action: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""Perform one timestep of the environment's dynamics.
"send_action" and "get_result" are coupled in sync simulation, so
typically users only call "step" function. But they can be called
separately in async simulation, i.e. someone calls "send_action" first,
and calls "get_result" later.
"send" and "recv" are coupled in sync simulation, so users only call
"step" function. But they can be called separately in async
simulation, i.e. someone calls "send" first, and calls "recv" later.
"""
self.send_action(action)
return self.get_result()
self.send(action)
return self.recv() # type: ignore
@staticmethod
def wait(

View File

@ -29,7 +29,10 @@ class DummyEnvWorker(EnvWorker):
# Sequential EnvWorker objects are always ready
return workers
def send_action(self, action: np.ndarray) -> None:
def send(self, action: Optional[np.ndarray]) -> None:
if action is None:
self.result = self.env.reset()
else:
self.result = self.env.step(action)
def seed(self, seed: Optional[int] = None) -> List[int]:

View File

@ -1,4 +1,4 @@
from typing import Any, Callable, List, Optional, Tuple
from typing import Any, Callable, List, Optional, Tuple, Union
import gym
import numpy as np
@ -44,11 +44,16 @@ class RayEnvWorker(EnvWorker):
ready_results, _ = ray.wait(results, num_returns=wait_num, timeout=timeout)
return [workers[results.index(result)] for result in ready_results]
def send_action(self, action: np.ndarray) -> None:
def send(self, action: Optional[np.ndarray]) -> None:
# self.result is actually a handle
if action is None:
self.result = self.env.reset.remote()
else:
self.result = self.env.step.remote(action)
def get_result(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
def recv(
self
) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], np.ndarray]:
return ray.get(self.result)
def seed(self, seed: Optional[int] = None) -> List[int]:

View File

@ -86,17 +86,17 @@ def _worker(
p.close()
break
if cmd == "step":
if data is None: # reset
obs = env.reset()
else:
obs, reward, done, info = env.step(data)
if obs_bufs is not None:
_encode_obs(obs, obs_bufs)
obs = None
p.send((obs, reward, done, info))
elif cmd == "reset":
obs = env.reset()
if obs_bufs is not None:
_encode_obs(obs, obs_bufs)
obs = None
if data is None:
p.send(obs)
else:
p.send((obs, reward, done, info))
elif cmd == "close":
p.send(env.close())
p.close()
@ -140,6 +140,7 @@ class SubprocEnvWorker(EnvWorker):
self.process = Process(target=_worker, args=args, daemon=True)
self.process.start()
self.child_remote.close()
self.is_reset = False
super().__init__(env_fn)
def get_env_attr(self, key: str) -> Any:
@ -165,13 +166,6 @@ class SubprocEnvWorker(EnvWorker):
return decode_obs(self.buffer)
def reset(self) -> Any:
self.parent_remote.send(["reset", None])
obs = self.parent_remote.recv()
if self.share_memory:
obs = self._decode_obs()
return obs
@staticmethod
def wait( # type: ignore
workers: List["SubprocEnvWorker"],
@ -192,14 +186,23 @@ class SubprocEnvWorker(EnvWorker):
remain_conns = [conn for conn in remain_conns if conn not in ready_conns]
return [workers[conns.index(con)] for con in ready_conns]
def send_action(self, action: np.ndarray) -> None:
def send(self, action: Optional[np.ndarray]) -> None:
self.parent_remote.send(["step", action])
def get_result(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
obs, rew, done, info = self.parent_remote.recv()
def recv(
self
) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], np.ndarray]:
result = self.parent_remote.recv()
if isinstance(result, tuple):
obs, rew, done, info = result
if self.share_memory:
obs = self._decode_obs()
return obs, rew, done, info
else:
obs = result
if self.share_memory:
obs = self._decode_obs()
return obs
def seed(self, seed: Optional[int] = None) -> Optional[List[int]]:
super().seed(seed)