fix vecenv action_space randomness (#300)
This commit is contained in:
parent
f22b539761
commit
31e7f445d1
@ -33,7 +33,6 @@ def get_args():
|
|||||||
parser.add_argument('--step-per-epoch', type=int, default=40000)
|
parser.add_argument('--step-per-epoch', type=int, default=40000)
|
||||||
parser.add_argument('--step-per-collect', type=int, default=4)
|
parser.add_argument('--step-per-collect', type=int, default=4)
|
||||||
parser.add_argument('--update-per-step', type=float, default=0.25)
|
parser.add_argument('--update-per-step', type=float, default=0.25)
|
||||||
parser.add_argument('--update-per-step', type=int, default=1)
|
|
||||||
parser.add_argument('--pre-collect-step', type=int, default=10000)
|
parser.add_argument('--pre-collect-step', type=int, default=10000)
|
||||||
parser.add_argument('--batch-size', type=int, default=256)
|
parser.add_argument('--batch-size', type=int, default=256)
|
||||||
parser.add_argument('--hidden-sizes', type=int,
|
parser.add_argument('--hidden-sizes', type=int,
|
||||||
|
|||||||
@ -454,7 +454,7 @@ class ReplayBufferManager(ReplayBuffer):
|
|||||||
|
|
||||||
def __init__(self, buffer_list: List[ReplayBuffer]) -> None:
|
def __init__(self, buffer_list: List[ReplayBuffer]) -> None:
|
||||||
self.buffer_num = len(buffer_list)
|
self.buffer_num = len(buffer_list)
|
||||||
self.buffers = np.array(buffer_list)
|
self.buffers = np.array(buffer_list, dtype=np.object)
|
||||||
offset, size = [], 0
|
offset, size = [], 0
|
||||||
buffer_type = type(self.buffers[0])
|
buffer_type = type(self.buffers[0])
|
||||||
kwargs = self.buffers[0].options
|
kwargs = self.buffers[0].options
|
||||||
|
|||||||
53
tianshou/env/venvs.py
vendored
53
tianshou/env/venvs.py
vendored
@ -43,8 +43,7 @@ class BaseVectorEnv(gym.Env):
|
|||||||
|
|
||||||
Otherwise, the outputs of these envs may be the same with each other.
|
Otherwise, the outputs of these envs may be the same with each other.
|
||||||
|
|
||||||
:param env_fns: a list of callable envs, ``env_fns[i]()`` generates the ith
|
:param env_fns: a list of callable envs, ``env_fns[i]()`` generates the ith env.
|
||||||
env.
|
|
||||||
:param worker_fn: a callable worker, ``worker_fn(env_fns[i])`` generates a
|
:param worker_fn: a callable worker, ``worker_fn(env_fns[i])`` generates a
|
||||||
worker which contains the i-th env.
|
worker which contains the i-th env.
|
||||||
:param int wait_num: use in asynchronous simulation if the time cost of
|
:param int wait_num: use in asynchronous simulation if the time cost of
|
||||||
@ -75,13 +74,11 @@ class BaseVectorEnv(gym.Env):
|
|||||||
|
|
||||||
self.env_num = len(env_fns)
|
self.env_num = len(env_fns)
|
||||||
self.wait_num = wait_num or len(env_fns)
|
self.wait_num = wait_num or len(env_fns)
|
||||||
assert (
|
assert 1 <= self.wait_num <= len(env_fns), \
|
||||||
1 <= self.wait_num <= len(env_fns)
|
f"wait_num should be in [1, {len(env_fns)}], but got {wait_num}"
|
||||||
), f"wait_num should be in [1, {len(env_fns)}], but got {wait_num}"
|
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
assert (
|
assert self.timeout is None or self.timeout > 0, \
|
||||||
self.timeout is None or self.timeout > 0
|
f"timeout is {timeout}, it should be positive if provided!"
|
||||||
), f"timeout is {timeout}, it should be positive if provided!"
|
|
||||||
self.is_async = self.wait_num != len(env_fns) or timeout is not None
|
self.is_async = self.wait_num != len(env_fns) or timeout is not None
|
||||||
self.waiting_conn: List[EnvWorker] = []
|
self.waiting_conn: List[EnvWorker] = []
|
||||||
# environments in self.ready_id is actually ready
|
# environments in self.ready_id is actually ready
|
||||||
@ -94,9 +91,8 @@ class BaseVectorEnv(gym.Env):
|
|||||||
self.is_closed = False
|
self.is_closed = False
|
||||||
|
|
||||||
def _assert_is_not_closed(self) -> None:
|
def _assert_is_not_closed(self) -> None:
|
||||||
assert not self.is_closed, (
|
assert not self.is_closed, \
|
||||||
f"Methods of {self.__class__.__name__} cannot be called after "
|
f"Methods of {self.__class__.__name__} cannot be called after close."
|
||||||
"close.")
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
"""Return len(self), which is the number of environments."""
|
"""Return len(self), which is the number of environments."""
|
||||||
@ -106,9 +102,8 @@ class BaseVectorEnv(gym.Env):
|
|||||||
"""Switch the attribute getter depending on the key.
|
"""Switch the attribute getter depending on the key.
|
||||||
|
|
||||||
Any class who inherits ``gym.Env`` will inherit some attributes, like
|
Any class who inherits ``gym.Env`` will inherit some attributes, like
|
||||||
``action_space``. However, we would like the attribute lookup to go
|
``action_space``. However, we would like the attribute lookup to go straight
|
||||||
straight into the worker (in fact, this vector env's action_space is
|
into the worker (in fact, this vector env's action_space is always None).
|
||||||
always None).
|
|
||||||
"""
|
"""
|
||||||
if key in ['metadata', 'reward_range', 'spec', 'action_space',
|
if key in ['metadata', 'reward_range', 'spec', 'action_space',
|
||||||
'observation_space']: # reserved keys in gym.Env
|
'observation_space']: # reserved keys in gym.Env
|
||||||
@ -119,9 +114,8 @@ class BaseVectorEnv(gym.Env):
|
|||||||
def __getattr__(self, key: str) -> List[Any]:
|
def __getattr__(self, key: str) -> List[Any]:
|
||||||
"""Fetch a list of env attributes.
|
"""Fetch a list of env attributes.
|
||||||
|
|
||||||
This function tries to retrieve an attribute from each individual
|
This function tries to retrieve an attribute from each individual wrapped
|
||||||
wrapped environment, if it does not belong to the wrapping vector
|
environment, if it does not belong to the wrapping vector environment class.
|
||||||
environment class.
|
|
||||||
"""
|
"""
|
||||||
return [getattr(worker, key) for worker in self.workers]
|
return [getattr(worker, key) for worker in self.workers]
|
||||||
|
|
||||||
@ -136,12 +130,10 @@ class BaseVectorEnv(gym.Env):
|
|||||||
|
|
||||||
def _assert_id(self, id: List[int]) -> None:
|
def _assert_id(self, id: List[int]) -> None:
|
||||||
for i in id:
|
for i in id:
|
||||||
assert (
|
assert i not in self.waiting_id, \
|
||||||
i not in self.waiting_id
|
f"Cannot interact with environment {i} which is stepping now."
|
||||||
), f"Cannot interact with environment {i} which is stepping now."
|
assert i in self.ready_id, \
|
||||||
assert (
|
f"Can only interact with ready environments {self.ready_id}."
|
||||||
i in self.ready_id
|
|
||||||
), f"Can only interact with ready environments {self.ready_id}."
|
|
||||||
|
|
||||||
def reset(
|
def reset(
|
||||||
self, id: Optional[Union[int, List[int], np.ndarray]] = None
|
self, id: Optional[Union[int, List[int], np.ndarray]] = None
|
||||||
@ -178,8 +170,7 @@ class BaseVectorEnv(gym.Env):
|
|||||||
|
|
||||||
:return: A tuple including four items:
|
:return: A tuple including four items:
|
||||||
|
|
||||||
* ``obs`` a numpy.ndarray, the agent's observation of current \
|
* ``obs`` a numpy.ndarray, the agent's observation of current environments
|
||||||
environments
|
|
||||||
* ``rew`` a numpy.ndarray, the amount of rewards returned after \
|
* ``rew`` a numpy.ndarray, the amount of rewards returned after \
|
||||||
previous actions
|
previous actions
|
||||||
* ``done`` a numpy.ndarray, whether these episodes have ended, in \
|
* ``done`` a numpy.ndarray, whether these episodes have ended, in \
|
||||||
@ -294,8 +285,7 @@ class DummyVectorEnv(BaseVectorEnv):
|
|||||||
wait_num: Optional[int] = None,
|
wait_num: Optional[int] = None,
|
||||||
timeout: Optional[float] = None,
|
timeout: Optional[float] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(
|
super().__init__(env_fns, DummyEnvWorker, wait_num=wait_num, timeout=timeout)
|
||||||
env_fns, DummyEnvWorker, wait_num=wait_num, timeout=timeout)
|
|
||||||
|
|
||||||
|
|
||||||
class SubprocVectorEnv(BaseVectorEnv):
|
class SubprocVectorEnv(BaseVectorEnv):
|
||||||
@ -316,8 +306,7 @@ class SubprocVectorEnv(BaseVectorEnv):
|
|||||||
def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker:
|
def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker:
|
||||||
return SubprocEnvWorker(fn, share_memory=False)
|
return SubprocEnvWorker(fn, share_memory=False)
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(env_fns, worker_fn, wait_num=wait_num, timeout=timeout)
|
||||||
env_fns, worker_fn, wait_num=wait_num, timeout=timeout)
|
|
||||||
|
|
||||||
|
|
||||||
class ShmemVectorEnv(BaseVectorEnv):
|
class ShmemVectorEnv(BaseVectorEnv):
|
||||||
@ -340,8 +329,7 @@ class ShmemVectorEnv(BaseVectorEnv):
|
|||||||
def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker:
|
def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker:
|
||||||
return SubprocEnvWorker(fn, share_memory=True)
|
return SubprocEnvWorker(fn, share_memory=True)
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(env_fns, worker_fn, wait_num=wait_num, timeout=timeout)
|
||||||
env_fns, worker_fn, wait_num=wait_num, timeout=timeout)
|
|
||||||
|
|
||||||
|
|
||||||
class RayVectorEnv(BaseVectorEnv):
|
class RayVectorEnv(BaseVectorEnv):
|
||||||
@ -369,5 +357,4 @@ class RayVectorEnv(BaseVectorEnv):
|
|||||||
) from e
|
) from e
|
||||||
if not ray.is_initialized():
|
if not ray.is_initialized():
|
||||||
ray.init()
|
ray.init()
|
||||||
super().__init__(
|
super().__init__(env_fns, RayEnvWorker, wait_num=wait_num, timeout=timeout)
|
||||||
env_fns, RayEnvWorker, wait_num=wait_num, timeout=timeout)
|
|
||||||
|
|||||||
21
tianshou/env/worker/subproc.py
vendored
21
tianshou/env/worker/subproc.py
vendored
@ -31,10 +31,7 @@ class ShArray:
|
|||||||
"""Wrapper of multiprocessing Array."""
|
"""Wrapper of multiprocessing Array."""
|
||||||
|
|
||||||
def __init__(self, dtype: np.generic, shape: Tuple[int]) -> None:
|
def __init__(self, dtype: np.generic, shape: Tuple[int]) -> None:
|
||||||
self.arr = Array(
|
self.arr = Array(_NP_TO_CT[dtype.type], int(np.prod(shape)))
|
||||||
_NP_TO_CT[dtype.type],
|
|
||||||
int(np.prod(shape)),
|
|
||||||
)
|
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
self.shape = shape
|
self.shape = shape
|
||||||
|
|
||||||
@ -143,10 +140,14 @@ class SubprocEnvWorker(EnvWorker):
|
|||||||
self.process = Process(target=_worker, args=args, daemon=True)
|
self.process = Process(target=_worker, args=args, daemon=True)
|
||||||
self.process.start()
|
self.process.start()
|
||||||
self.child_remote.close()
|
self.child_remote.close()
|
||||||
|
self._seed = None
|
||||||
|
|
||||||
def __getattr__(self, key: str) -> Any:
|
def __getattr__(self, key: str) -> Any:
|
||||||
self.parent_remote.send(["getattr", key])
|
self.parent_remote.send(["getattr", key])
|
||||||
return self.parent_remote.recv()
|
result = self.parent_remote.recv()
|
||||||
|
if key == "action_space": # issue #299
|
||||||
|
result.seed(self._seed)
|
||||||
|
return result
|
||||||
|
|
||||||
def _decode_obs(self) -> Union[dict, tuple, np.ndarray]:
|
def _decode_obs(self) -> Union[dict, tuple, np.ndarray]:
|
||||||
def decode_obs(
|
def decode_obs(
|
||||||
@ -185,11 +186,9 @@ class SubprocEnvWorker(EnvWorker):
|
|||||||
if remain_time <= 0:
|
if remain_time <= 0:
|
||||||
break
|
break
|
||||||
# connection.wait hangs if the list is empty
|
# connection.wait hangs if the list is empty
|
||||||
new_ready_conns = connection.wait(
|
new_ready_conns = connection.wait(remain_conns, timeout=remain_time)
|
||||||
remain_conns, timeout=remain_time)
|
|
||||||
ready_conns.extend(new_ready_conns) # type: ignore
|
ready_conns.extend(new_ready_conns) # type: ignore
|
||||||
remain_conns = [
|
remain_conns = [conn for conn in remain_conns if conn not in ready_conns]
|
||||||
conn for conn in remain_conns if conn not in ready_conns]
|
|
||||||
return [workers[conns.index(con)] for con in ready_conns]
|
return [workers[conns.index(con)] for con in ready_conns]
|
||||||
|
|
||||||
def send_action(self, action: np.ndarray) -> None:
|
def send_action(self, action: np.ndarray) -> None:
|
||||||
@ -205,7 +204,9 @@ class SubprocEnvWorker(EnvWorker):
|
|||||||
|
|
||||||
def seed(self, seed: Optional[int] = None) -> Optional[List[int]]:
|
def seed(self, seed: Optional[int] = None) -> Optional[List[int]]:
|
||||||
self.parent_remote.send(["seed", seed])
|
self.parent_remote.send(["seed", seed])
|
||||||
return self.parent_remote.recv()
|
result = self.parent_remote.recv()
|
||||||
|
self._seed = result[0] if result is not None else seed
|
||||||
|
return result
|
||||||
|
|
||||||
def render(self, **kwargs: Any) -> Any:
|
def render(self, **kwargs: Any) -> Any:
|
||||||
self.parent_remote.send(["render", kwargs])
|
self.parent_remote.send(["render", kwargs])
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user