Enable venvs.reset() concurrent execution (#517)
- change the internal API name of worker: send_action -> send, get_result -> recv (align with envpool) - add a timing test for venvs.reset() to make sure the concurrent execution - change venvs.reset() logic Co-authored-by: Jiayi Weng <trinkle23897@gmail.com>
This commit is contained in:
parent
cd7654bfd5
commit
9c100e0705
@ -48,6 +48,7 @@ preprocess
|
|||||||
repo
|
repo
|
||||||
ReLU
|
ReLU
|
||||||
namespace
|
namespace
|
||||||
|
recv
|
||||||
th
|
th
|
||||||
utils
|
utils
|
||||||
NaN
|
NaN
|
||||||
|
@ -81,6 +81,7 @@ class MyTestEnv(gym.Env):
|
|||||||
|
|
||||||
def reset(self, state=0):
|
def reset(self, state=0):
|
||||||
self.done = False
|
self.done = False
|
||||||
|
self.do_sleep()
|
||||||
self.index = state
|
self.index = state
|
||||||
return self._get_state()
|
return self._get_state()
|
||||||
|
|
||||||
@ -116,16 +117,19 @@ class MyTestEnv(gym.Env):
|
|||||||
else:
|
else:
|
||||||
return np.array([self.index], dtype=np.float32)
|
return np.array([self.index], dtype=np.float32)
|
||||||
|
|
||||||
|
def do_sleep(self):
|
||||||
|
if self.sleep > 0:
|
||||||
|
sleep_time = random.random() if self.random_sleep else 1
|
||||||
|
sleep_time *= self.sleep
|
||||||
|
time.sleep(sleep_time)
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
self.steps += 1
|
self.steps += 1
|
||||||
if self._md_action:
|
if self._md_action:
|
||||||
action = action[0]
|
action = action[0]
|
||||||
if self.done:
|
if self.done:
|
||||||
raise ValueError('step after done !!!')
|
raise ValueError('step after done !!!')
|
||||||
if self.sleep > 0:
|
self.do_sleep()
|
||||||
sleep_time = random.random() if self.random_sleep else 1
|
|
||||||
sleep_time *= self.sleep
|
|
||||||
time.sleep(sleep_time)
|
|
||||||
if self.index == self.size:
|
if self.index == self.size:
|
||||||
self.done = True
|
self.done = True
|
||||||
return self._get_state(), self._get_reward(), self.done, {}
|
return self._get_state(), self._get_reward(), self.done, {}
|
||||||
|
@ -96,7 +96,12 @@ def test_async_check_id(size=100, num=4, sleep=.2, timeout=.7):
|
|||||||
for cls in test_cls:
|
for cls in test_cls:
|
||||||
pass_check = 1
|
pass_check = 1
|
||||||
v = cls(env_fns, wait_num=num - 1, timeout=timeout)
|
v = cls(env_fns, wait_num=num - 1, timeout=timeout)
|
||||||
|
t = time.time()
|
||||||
v.reset()
|
v.reset()
|
||||||
|
t = time.time() - t
|
||||||
|
print(f"{cls} reset {t}")
|
||||||
|
if t > sleep * 9: # huge than maximum sleep time (7 sleep)
|
||||||
|
pass_check = 0
|
||||||
expect_result = [
|
expect_result = [
|
||||||
[0, 1],
|
[0, 1],
|
||||||
[0, 1, 2],
|
[0, 1, 2],
|
||||||
|
13
tianshou/env/venvs.py
vendored
13
tianshou/env/venvs.py
vendored
@ -209,7 +209,10 @@ class BaseVectorEnv(gym.Env):
|
|||||||
id = self._wrap_id(id)
|
id = self._wrap_id(id)
|
||||||
if self.is_async:
|
if self.is_async:
|
||||||
self._assert_id(id)
|
self._assert_id(id)
|
||||||
obs_list = [self.workers[i].reset() for i in id]
|
# send(None) == reset() in worker
|
||||||
|
for i in id:
|
||||||
|
self.workers[i].send(None)
|
||||||
|
obs_list = [self.workers[i].recv() for i in id]
|
||||||
try:
|
try:
|
||||||
obs = np.stack(obs_list)
|
obs = np.stack(obs_list)
|
||||||
except ValueError: # different len(obs)
|
except ValueError: # different len(obs)
|
||||||
@ -258,10 +261,10 @@ class BaseVectorEnv(gym.Env):
|
|||||||
if not self.is_async:
|
if not self.is_async:
|
||||||
assert len(action) == len(id)
|
assert len(action) == len(id)
|
||||||
for i, j in enumerate(id):
|
for i, j in enumerate(id):
|
||||||
self.workers[j].send_action(action[i])
|
self.workers[j].send(action[i])
|
||||||
result = []
|
result = []
|
||||||
for j in id:
|
for j in id:
|
||||||
obs, rew, done, info = self.workers[j].get_result()
|
obs, rew, done, info = self.workers[j].recv()
|
||||||
info["env_id"] = j
|
info["env_id"] = j
|
||||||
result.append((obs, rew, done, info))
|
result.append((obs, rew, done, info))
|
||||||
else:
|
else:
|
||||||
@ -269,7 +272,7 @@ class BaseVectorEnv(gym.Env):
|
|||||||
self._assert_id(id)
|
self._assert_id(id)
|
||||||
assert len(action) == len(id)
|
assert len(action) == len(id)
|
||||||
for act, env_id in zip(action, id):
|
for act, env_id in zip(action, id):
|
||||||
self.workers[env_id].send_action(act)
|
self.workers[env_id].send(act)
|
||||||
self.waiting_conn.append(self.workers[env_id])
|
self.waiting_conn.append(self.workers[env_id])
|
||||||
self.waiting_id.append(env_id)
|
self.waiting_id.append(env_id)
|
||||||
self.ready_id = [x for x in self.ready_id if x not in id]
|
self.ready_id = [x for x in self.ready_id if x not in id]
|
||||||
@ -283,7 +286,7 @@ class BaseVectorEnv(gym.Env):
|
|||||||
waiting_index = self.waiting_conn.index(conn)
|
waiting_index = self.waiting_conn.index(conn)
|
||||||
self.waiting_conn.pop(waiting_index)
|
self.waiting_conn.pop(waiting_index)
|
||||||
env_id = self.waiting_id.pop(waiting_index)
|
env_id = self.waiting_id.pop(waiting_index)
|
||||||
obs, rew, done, info = conn.get_result()
|
obs, rew, done, info = conn.recv()
|
||||||
info["env_id"] = env_id
|
info["env_id"] = env_id
|
||||||
result.append((obs, rew, done, info))
|
result.append((obs, rew, done, info))
|
||||||
self.ready_id.append(env_id)
|
self.ready_id.append(env_id)
|
||||||
|
37
tianshou/env/worker/base.py
vendored
37
tianshou/env/worker/base.py
vendored
@ -1,5 +1,5 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Callable, List, Optional, Tuple
|
from typing import Any, Callable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import gym
|
import gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -23,28 +23,41 @@ class EnvWorker(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def reset(self) -> Any:
|
def send(self, action: Optional[np.ndarray]) -> None:
|
||||||
|
"""Send action signal to low-level worker.
|
||||||
|
|
||||||
|
When action is None, it indicates sending "reset" signal; otherwise
|
||||||
|
it indicates "step" signal. The paired return value from "recv"
|
||||||
|
function is determined by such kind of different signal.
|
||||||
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
def recv(
|
||||||
def send_action(self, action: np.ndarray) -> None:
|
self
|
||||||
pass
|
) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], np.ndarray]:
|
||||||
|
"""Receive result from low-level worker.
|
||||||
|
|
||||||
def get_result(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
If the last "send" function sends a NULL action, it only returns a
|
||||||
|
single observation; otherwise it returns a tuple of (obs, rew, done,
|
||||||
|
info).
|
||||||
|
"""
|
||||||
return self.result
|
return self.result
|
||||||
|
|
||||||
|
def reset(self) -> np.ndarray:
|
||||||
|
self.send(None)
|
||||||
|
return self.recv() # type: ignore
|
||||||
|
|
||||||
def step(
|
def step(
|
||||||
self, action: np.ndarray
|
self, action: np.ndarray
|
||||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||||
"""Perform one timestep of the environment's dynamic.
|
"""Perform one timestep of the environment's dynamic.
|
||||||
|
|
||||||
"send_action" and "get_result" are coupled in sync simulation, so
|
"send" and "recv" are coupled in sync simulation, so users only call
|
||||||
typically users only call "step" function. But they can be called
|
"step" function. But they can be called separately in async
|
||||||
separately in async simulation, i.e. someone calls "send_action" first,
|
simulation, i.e. someone calls "send" first, and calls "recv" later.
|
||||||
and calls "get_result" later.
|
|
||||||
"""
|
"""
|
||||||
self.send_action(action)
|
self.send(action)
|
||||||
return self.get_result()
|
return self.recv() # type: ignore
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def wait(
|
def wait(
|
||||||
|
7
tianshou/env/worker/dummy.py
vendored
7
tianshou/env/worker/dummy.py
vendored
@ -29,8 +29,11 @@ class DummyEnvWorker(EnvWorker):
|
|||||||
# Sequential EnvWorker objects are always ready
|
# Sequential EnvWorker objects are always ready
|
||||||
return workers
|
return workers
|
||||||
|
|
||||||
def send_action(self, action: np.ndarray) -> None:
|
def send(self, action: Optional[np.ndarray]) -> None:
|
||||||
self.result = self.env.step(action)
|
if action is None:
|
||||||
|
self.result = self.env.reset()
|
||||||
|
else:
|
||||||
|
self.result = self.env.step(action)
|
||||||
|
|
||||||
def seed(self, seed: Optional[int] = None) -> List[int]:
|
def seed(self, seed: Optional[int] = None) -> List[int]:
|
||||||
super().seed(seed)
|
super().seed(seed)
|
||||||
|
13
tianshou/env/worker/ray.py
vendored
13
tianshou/env/worker/ray.py
vendored
@ -1,4 +1,4 @@
|
|||||||
from typing import Any, Callable, List, Optional, Tuple
|
from typing import Any, Callable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import gym
|
import gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -44,11 +44,16 @@ class RayEnvWorker(EnvWorker):
|
|||||||
ready_results, _ = ray.wait(results, num_returns=wait_num, timeout=timeout)
|
ready_results, _ = ray.wait(results, num_returns=wait_num, timeout=timeout)
|
||||||
return [workers[results.index(result)] for result in ready_results]
|
return [workers[results.index(result)] for result in ready_results]
|
||||||
|
|
||||||
def send_action(self, action: np.ndarray) -> None:
|
def send(self, action: Optional[np.ndarray]) -> None:
|
||||||
# self.action is actually a handle
|
# self.action is actually a handle
|
||||||
self.result = self.env.step.remote(action)
|
if action is None:
|
||||||
|
self.result = self.env.reset.remote()
|
||||||
|
else:
|
||||||
|
self.result = self.env.step.remote(action)
|
||||||
|
|
||||||
def get_result(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
def recv(
|
||||||
|
self
|
||||||
|
) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], np.ndarray]:
|
||||||
return ray.get(self.result)
|
return ray.get(self.result)
|
||||||
|
|
||||||
def seed(self, seed: Optional[int] = None) -> List[int]:
|
def seed(self, seed: Optional[int] = None) -> List[int]:
|
||||||
|
45
tianshou/env/worker/subproc.py
vendored
45
tianshou/env/worker/subproc.py
vendored
@ -86,17 +86,17 @@ def _worker(
|
|||||||
p.close()
|
p.close()
|
||||||
break
|
break
|
||||||
if cmd == "step":
|
if cmd == "step":
|
||||||
obs, reward, done, info = env.step(data)
|
if data is None: # reset
|
||||||
|
obs = env.reset()
|
||||||
|
else:
|
||||||
|
obs, reward, done, info = env.step(data)
|
||||||
if obs_bufs is not None:
|
if obs_bufs is not None:
|
||||||
_encode_obs(obs, obs_bufs)
|
_encode_obs(obs, obs_bufs)
|
||||||
obs = None
|
obs = None
|
||||||
p.send((obs, reward, done, info))
|
if data is None:
|
||||||
elif cmd == "reset":
|
p.send(obs)
|
||||||
obs = env.reset()
|
else:
|
||||||
if obs_bufs is not None:
|
p.send((obs, reward, done, info))
|
||||||
_encode_obs(obs, obs_bufs)
|
|
||||||
obs = None
|
|
||||||
p.send(obs)
|
|
||||||
elif cmd == "close":
|
elif cmd == "close":
|
||||||
p.send(env.close())
|
p.send(env.close())
|
||||||
p.close()
|
p.close()
|
||||||
@ -140,6 +140,7 @@ class SubprocEnvWorker(EnvWorker):
|
|||||||
self.process = Process(target=_worker, args=args, daemon=True)
|
self.process = Process(target=_worker, args=args, daemon=True)
|
||||||
self.process.start()
|
self.process.start()
|
||||||
self.child_remote.close()
|
self.child_remote.close()
|
||||||
|
self.is_reset = False
|
||||||
super().__init__(env_fn)
|
super().__init__(env_fn)
|
||||||
|
|
||||||
def get_env_attr(self, key: str) -> Any:
|
def get_env_attr(self, key: str) -> Any:
|
||||||
@ -165,13 +166,6 @@ class SubprocEnvWorker(EnvWorker):
|
|||||||
|
|
||||||
return decode_obs(self.buffer)
|
return decode_obs(self.buffer)
|
||||||
|
|
||||||
def reset(self) -> Any:
|
|
||||||
self.parent_remote.send(["reset", None])
|
|
||||||
obs = self.parent_remote.recv()
|
|
||||||
if self.share_memory:
|
|
||||||
obs = self._decode_obs()
|
|
||||||
return obs
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def wait( # type: ignore
|
def wait( # type: ignore
|
||||||
workers: List["SubprocEnvWorker"],
|
workers: List["SubprocEnvWorker"],
|
||||||
@ -192,14 +186,23 @@ class SubprocEnvWorker(EnvWorker):
|
|||||||
remain_conns = [conn for conn in remain_conns if conn not in ready_conns]
|
remain_conns = [conn for conn in remain_conns if conn not in ready_conns]
|
||||||
return [workers[conns.index(con)] for con in ready_conns]
|
return [workers[conns.index(con)] for con in ready_conns]
|
||||||
|
|
||||||
def send_action(self, action: np.ndarray) -> None:
|
def send(self, action: Optional[np.ndarray]) -> None:
|
||||||
self.parent_remote.send(["step", action])
|
self.parent_remote.send(["step", action])
|
||||||
|
|
||||||
def get_result(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
def recv(
|
||||||
obs, rew, done, info = self.parent_remote.recv()
|
self
|
||||||
if self.share_memory:
|
) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], np.ndarray]:
|
||||||
obs = self._decode_obs()
|
result = self.parent_remote.recv()
|
||||||
return obs, rew, done, info
|
if isinstance(result, tuple):
|
||||||
|
obs, rew, done, info = result
|
||||||
|
if self.share_memory:
|
||||||
|
obs = self._decode_obs()
|
||||||
|
return obs, rew, done, info
|
||||||
|
else:
|
||||||
|
obs = result
|
||||||
|
if self.share_memory:
|
||||||
|
obs = self._decode_obs()
|
||||||
|
return obs
|
||||||
|
|
||||||
def seed(self, seed: Optional[int] = None) -> Optional[List[int]]:
|
def seed(self, seed: Optional[int] = None) -> Optional[List[int]]:
|
||||||
super().seed(seed)
|
super().seed(seed)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user