Enable venvs.reset() concurrent execution (#517)
- rename the internal worker API: send_action -> send, get_result -> recv (aligned with envpool) - add a timing test for venvs.reset() to make sure the resets execute concurrently - change the venvs.reset() logic Co-authored-by: Jiayi Weng <trinkle23897@gmail.com>
This commit is contained in:
parent
cd7654bfd5
commit
9c100e0705
@ -48,6 +48,7 @@ preprocess
|
||||
repo
|
||||
ReLU
|
||||
namespace
|
||||
recv
|
||||
th
|
||||
utils
|
||||
NaN
|
||||
|
@ -81,6 +81,7 @@ class MyTestEnv(gym.Env):
|
||||
|
||||
def reset(self, state=0):
|
||||
self.done = False
|
||||
self.do_sleep()
|
||||
self.index = state
|
||||
return self._get_state()
|
||||
|
||||
@ -116,16 +117,19 @@ class MyTestEnv(gym.Env):
|
||||
else:
|
||||
return np.array([self.index], dtype=np.float32)
|
||||
|
||||
def do_sleep(self):
|
||||
if self.sleep > 0:
|
||||
sleep_time = random.random() if self.random_sleep else 1
|
||||
sleep_time *= self.sleep
|
||||
time.sleep(sleep_time)
|
||||
|
||||
def step(self, action):
|
||||
self.steps += 1
|
||||
if self._md_action:
|
||||
action = action[0]
|
||||
if self.done:
|
||||
raise ValueError('step after done !!!')
|
||||
if self.sleep > 0:
|
||||
sleep_time = random.random() if self.random_sleep else 1
|
||||
sleep_time *= self.sleep
|
||||
time.sleep(sleep_time)
|
||||
self.do_sleep()
|
||||
if self.index == self.size:
|
||||
self.done = True
|
||||
return self._get_state(), self._get_reward(), self.done, {}
|
||||
|
@ -96,7 +96,12 @@ def test_async_check_id(size=100, num=4, sleep=.2, timeout=.7):
|
||||
for cls in test_cls:
|
||||
pass_check = 1
|
||||
v = cls(env_fns, wait_num=num - 1, timeout=timeout)
|
||||
t = time.time()
|
||||
v.reset()
|
||||
t = time.time() - t
|
||||
print(f"{cls} reset {t}")
|
||||
if t > sleep * 9: # greater than the maximum sleep time (7 sleeps)
|
||||
pass_check = 0
|
||||
expect_result = [
|
||||
[0, 1],
|
||||
[0, 1, 2],
|
||||
|
13
tianshou/env/venvs.py
vendored
13
tianshou/env/venvs.py
vendored
@ -209,7 +209,10 @@ class BaseVectorEnv(gym.Env):
|
||||
id = self._wrap_id(id)
|
||||
if self.is_async:
|
||||
self._assert_id(id)
|
||||
obs_list = [self.workers[i].reset() for i in id]
|
||||
# send(None) == reset() in worker
|
||||
for i in id:
|
||||
self.workers[i].send(None)
|
||||
obs_list = [self.workers[i].recv() for i in id]
|
||||
try:
|
||||
obs = np.stack(obs_list)
|
||||
except ValueError: # different len(obs)
|
||||
@ -258,10 +261,10 @@ class BaseVectorEnv(gym.Env):
|
||||
if not self.is_async:
|
||||
assert len(action) == len(id)
|
||||
for i, j in enumerate(id):
|
||||
self.workers[j].send_action(action[i])
|
||||
self.workers[j].send(action[i])
|
||||
result = []
|
||||
for j in id:
|
||||
obs, rew, done, info = self.workers[j].get_result()
|
||||
obs, rew, done, info = self.workers[j].recv()
|
||||
info["env_id"] = j
|
||||
result.append((obs, rew, done, info))
|
||||
else:
|
||||
@ -269,7 +272,7 @@ class BaseVectorEnv(gym.Env):
|
||||
self._assert_id(id)
|
||||
assert len(action) == len(id)
|
||||
for act, env_id in zip(action, id):
|
||||
self.workers[env_id].send_action(act)
|
||||
self.workers[env_id].send(act)
|
||||
self.waiting_conn.append(self.workers[env_id])
|
||||
self.waiting_id.append(env_id)
|
||||
self.ready_id = [x for x in self.ready_id if x not in id]
|
||||
@ -283,7 +286,7 @@ class BaseVectorEnv(gym.Env):
|
||||
waiting_index = self.waiting_conn.index(conn)
|
||||
self.waiting_conn.pop(waiting_index)
|
||||
env_id = self.waiting_id.pop(waiting_index)
|
||||
obs, rew, done, info = conn.get_result()
|
||||
obs, rew, done, info = conn.recv()
|
||||
info["env_id"] = env_id
|
||||
result.append((obs, rew, done, info))
|
||||
self.ready_id.append(env_id)
|
||||
|
37
tianshou/env/worker/base.py
vendored
37
tianshou/env/worker/base.py
vendored
@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, List, Optional, Tuple
|
||||
from typing import Any, Callable, List, Optional, Tuple, Union
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
@ -23,28 +23,41 @@ class EnvWorker(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def reset(self) -> Any:
|
||||
def send(self, action: Optional[np.ndarray]) -> None:
|
||||
"""Send action signal to low-level worker.
|
||||
|
||||
When action is None, it sends a "reset" signal; otherwise
|
||||
it sends a "step" signal. The value returned by the paired "recv"
|
||||
call depends on which of these two signals was sent.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def send_action(self, action: np.ndarray) -> None:
|
||||
pass
|
||||
def recv(
|
||||
self
|
||||
) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], np.ndarray]:
|
||||
"""Receive result from low-level worker.
|
||||
|
||||
def get_result(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
If the last "send" call sent a None action, it only returns a
|
||||
single observation; otherwise it returns a tuple of (obs, rew, done,
|
||||
info).
|
||||
"""
|
||||
return self.result
|
||||
|
||||
def reset(self) -> np.ndarray:
|
||||
self.send(None)
|
||||
return self.recv() # type: ignore
|
||||
|
||||
def step(
|
||||
self, action: np.ndarray
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""Perform one timestep of the environment's dynamic.
|
||||
|
||||
"send_action" and "get_result" are coupled in sync simulation, so
|
||||
typically users only call "step" function. But they can be called
|
||||
separately in async simulation, i.e. someone calls "send_action" first,
|
||||
and calls "get_result" later.
|
||||
"send" and "recv" are coupled in sync simulation, so users only call
|
||||
"step" function. But they can be called separately in async
|
||||
simulation, i.e. someone calls "send" first, and calls "recv" later.
|
||||
"""
|
||||
self.send_action(action)
|
||||
return self.get_result()
|
||||
self.send(action)
|
||||
return self.recv() # type: ignore
|
||||
|
||||
@staticmethod
|
||||
def wait(
|
||||
|
5
tianshou/env/worker/dummy.py
vendored
5
tianshou/env/worker/dummy.py
vendored
@ -29,7 +29,10 @@ class DummyEnvWorker(EnvWorker):
|
||||
# Sequential EnvWorker objects are always ready
|
||||
return workers
|
||||
|
||||
def send_action(self, action: np.ndarray) -> None:
|
||||
def send(self, action: Optional[np.ndarray]) -> None:
|
||||
if action is None:
|
||||
self.result = self.env.reset()
|
||||
else:
|
||||
self.result = self.env.step(action)
|
||||
|
||||
def seed(self, seed: Optional[int] = None) -> List[int]:
|
||||
|
11
tianshou/env/worker/ray.py
vendored
11
tianshou/env/worker/ray.py
vendored
@ -1,4 +1,4 @@
|
||||
from typing import Any, Callable, List, Optional, Tuple
|
||||
from typing import Any, Callable, List, Optional, Tuple, Union
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
@ -44,11 +44,16 @@ class RayEnvWorker(EnvWorker):
|
||||
ready_results, _ = ray.wait(results, num_returns=wait_num, timeout=timeout)
|
||||
return [workers[results.index(result)] for result in ready_results]
|
||||
|
||||
def send_action(self, action: np.ndarray) -> None:
|
||||
def send(self, action: Optional[np.ndarray]) -> None:
|
||||
# self.result is actually a handle, not the result itself
|
||||
if action is None:
|
||||
self.result = self.env.reset.remote()
|
||||
else:
|
||||
self.result = self.env.step.remote(action)
|
||||
|
||||
def get_result(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
def recv(
|
||||
self
|
||||
) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], np.ndarray]:
|
||||
return ray.get(self.result)
|
||||
|
||||
def seed(self, seed: Optional[int] = None) -> List[int]:
|
||||
|
35
tianshou/env/worker/subproc.py
vendored
35
tianshou/env/worker/subproc.py
vendored
@ -86,17 +86,17 @@ def _worker(
|
||||
p.close()
|
||||
break
|
||||
if cmd == "step":
|
||||
if data is None: # reset
|
||||
obs = env.reset()
|
||||
else:
|
||||
obs, reward, done, info = env.step(data)
|
||||
if obs_bufs is not None:
|
||||
_encode_obs(obs, obs_bufs)
|
||||
obs = None
|
||||
p.send((obs, reward, done, info))
|
||||
elif cmd == "reset":
|
||||
obs = env.reset()
|
||||
if obs_bufs is not None:
|
||||
_encode_obs(obs, obs_bufs)
|
||||
obs = None
|
||||
if data is None:
|
||||
p.send(obs)
|
||||
else:
|
||||
p.send((obs, reward, done, info))
|
||||
elif cmd == "close":
|
||||
p.send(env.close())
|
||||
p.close()
|
||||
@ -140,6 +140,7 @@ class SubprocEnvWorker(EnvWorker):
|
||||
self.process = Process(target=_worker, args=args, daemon=True)
|
||||
self.process.start()
|
||||
self.child_remote.close()
|
||||
self.is_reset = False
|
||||
super().__init__(env_fn)
|
||||
|
||||
def get_env_attr(self, key: str) -> Any:
|
||||
@ -165,13 +166,6 @@ class SubprocEnvWorker(EnvWorker):
|
||||
|
||||
return decode_obs(self.buffer)
|
||||
|
||||
def reset(self) -> Any:
|
||||
self.parent_remote.send(["reset", None])
|
||||
obs = self.parent_remote.recv()
|
||||
if self.share_memory:
|
||||
obs = self._decode_obs()
|
||||
return obs
|
||||
|
||||
@staticmethod
|
||||
def wait( # type: ignore
|
||||
workers: List["SubprocEnvWorker"],
|
||||
@ -192,14 +186,23 @@ class SubprocEnvWorker(EnvWorker):
|
||||
remain_conns = [conn for conn in remain_conns if conn not in ready_conns]
|
||||
return [workers[conns.index(con)] for con in ready_conns]
|
||||
|
||||
def send_action(self, action: np.ndarray) -> None:
|
||||
def send(self, action: Optional[np.ndarray]) -> None:
|
||||
self.parent_remote.send(["step", action])
|
||||
|
||||
def get_result(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
obs, rew, done, info = self.parent_remote.recv()
|
||||
def recv(
|
||||
self
|
||||
) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], np.ndarray]:
|
||||
result = self.parent_remote.recv()
|
||||
if isinstance(result, tuple):
|
||||
obs, rew, done, info = result
|
||||
if self.share_memory:
|
||||
obs = self._decode_obs()
|
||||
return obs, rew, done, info
|
||||
else:
|
||||
obs = result
|
||||
if self.share_memory:
|
||||
obs = self._decode_obs()
|
||||
return obs
|
||||
|
||||
def seed(self, seed: Optional[int] = None) -> Optional[List[int]]:
|
||||
super().seed(seed)
|
||||
|
Loading…
x
Reference in New Issue
Block a user