fix vecenv action_space randomness (#300)
This commit is contained in:
parent
f22b539761
commit
31e7f445d1
@ -33,7 +33,6 @@ def get_args():
|
||||
parser.add_argument('--step-per-epoch', type=int, default=40000)
|
||||
parser.add_argument('--step-per-collect', type=int, default=4)
|
||||
parser.add_argument('--update-per-step', type=float, default=0.25)
|
||||
parser.add_argument('--update-per-step', type=int, default=1)
|
||||
parser.add_argument('--pre-collect-step', type=int, default=10000)
|
||||
parser.add_argument('--batch-size', type=int, default=256)
|
||||
parser.add_argument('--hidden-sizes', type=int,
|
||||
|
@ -454,7 +454,7 @@ class ReplayBufferManager(ReplayBuffer):
|
||||
|
||||
def __init__(self, buffer_list: List[ReplayBuffer]) -> None:
|
||||
self.buffer_num = len(buffer_list)
|
||||
self.buffers = np.array(buffer_list)
|
||||
self.buffers = np.array(buffer_list, dtype=np.object)
|
||||
offset, size = [], 0
|
||||
buffer_type = type(self.buffers[0])
|
||||
kwargs = self.buffers[0].options
|
||||
|
53
tianshou/env/venvs.py
vendored
53
tianshou/env/venvs.py
vendored
@ -43,8 +43,7 @@ class BaseVectorEnv(gym.Env):
|
||||
|
||||
Otherwise, the outputs of these envs may be the same with each other.
|
||||
|
||||
:param env_fns: a list of callable envs, ``env_fns[i]()`` generates the ith
|
||||
env.
|
||||
:param env_fns: a list of callable envs, ``env_fns[i]()`` generates the ith env.
|
||||
:param worker_fn: a callable worker, ``worker_fn(env_fns[i])`` generates a
|
||||
worker which contains the i-th env.
|
||||
:param int wait_num: use in asynchronous simulation if the time cost of
|
||||
@ -75,13 +74,11 @@ class BaseVectorEnv(gym.Env):
|
||||
|
||||
self.env_num = len(env_fns)
|
||||
self.wait_num = wait_num or len(env_fns)
|
||||
assert (
|
||||
1 <= self.wait_num <= len(env_fns)
|
||||
), f"wait_num should be in [1, {len(env_fns)}], but got {wait_num}"
|
||||
assert 1 <= self.wait_num <= len(env_fns), \
|
||||
f"wait_num should be in [1, {len(env_fns)}], but got {wait_num}"
|
||||
self.timeout = timeout
|
||||
assert (
|
||||
self.timeout is None or self.timeout > 0
|
||||
), f"timeout is {timeout}, it should be positive if provided!"
|
||||
assert self.timeout is None or self.timeout > 0, \
|
||||
f"timeout is {timeout}, it should be positive if provided!"
|
||||
self.is_async = self.wait_num != len(env_fns) or timeout is not None
|
||||
self.waiting_conn: List[EnvWorker] = []
|
||||
# environments in self.ready_id is actually ready
|
||||
@ -94,9 +91,8 @@ class BaseVectorEnv(gym.Env):
|
||||
self.is_closed = False
|
||||
|
||||
def _assert_is_not_closed(self) -> None:
|
||||
assert not self.is_closed, (
|
||||
f"Methods of {self.__class__.__name__} cannot be called after "
|
||||
"close.")
|
||||
assert not self.is_closed, \
|
||||
f"Methods of {self.__class__.__name__} cannot be called after close."
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return len(self), which is the number of environments."""
|
||||
@ -106,9 +102,8 @@ class BaseVectorEnv(gym.Env):
|
||||
"""Switch the attribute getter depending on the key.
|
||||
|
||||
Any class who inherits ``gym.Env`` will inherit some attributes, like
|
||||
``action_space``. However, we would like the attribute lookup to go
|
||||
straight into the worker (in fact, this vector env's action_space is
|
||||
always None).
|
||||
``action_space``. However, we would like the attribute lookup to go straight
|
||||
into the worker (in fact, this vector env's action_space is always None).
|
||||
"""
|
||||
if key in ['metadata', 'reward_range', 'spec', 'action_space',
|
||||
'observation_space']: # reserved keys in gym.Env
|
||||
@ -119,9 +114,8 @@ class BaseVectorEnv(gym.Env):
|
||||
def __getattr__(self, key: str) -> List[Any]:
|
||||
"""Fetch a list of env attributes.
|
||||
|
||||
This function tries to retrieve an attribute from each individual
|
||||
wrapped environment, if it does not belong to the wrapping vector
|
||||
environment class.
|
||||
This function tries to retrieve an attribute from each individual wrapped
|
||||
environment, if it does not belong to the wrapping vector environment class.
|
||||
"""
|
||||
return [getattr(worker, key) for worker in self.workers]
|
||||
|
||||
@ -136,12 +130,10 @@ class BaseVectorEnv(gym.Env):
|
||||
|
||||
def _assert_id(self, id: List[int]) -> None:
|
||||
for i in id:
|
||||
assert (
|
||||
i not in self.waiting_id
|
||||
), f"Cannot interact with environment {i} which is stepping now."
|
||||
assert (
|
||||
i in self.ready_id
|
||||
), f"Can only interact with ready environments {self.ready_id}."
|
||||
assert i not in self.waiting_id, \
|
||||
f"Cannot interact with environment {i} which is stepping now."
|
||||
assert i in self.ready_id, \
|
||||
f"Can only interact with ready environments {self.ready_id}."
|
||||
|
||||
def reset(
|
||||
self, id: Optional[Union[int, List[int], np.ndarray]] = None
|
||||
@ -178,8 +170,7 @@ class BaseVectorEnv(gym.Env):
|
||||
|
||||
:return: A tuple including four items:
|
||||
|
||||
* ``obs`` a numpy.ndarray, the agent's observation of current \
|
||||
environments
|
||||
* ``obs`` a numpy.ndarray, the agent's observation of current environments
|
||||
* ``rew`` a numpy.ndarray, the amount of rewards returned after \
|
||||
previous actions
|
||||
* ``done`` a numpy.ndarray, whether these episodes have ended, in \
|
||||
@ -294,8 +285,7 @@ class DummyVectorEnv(BaseVectorEnv):
|
||||
wait_num: Optional[int] = None,
|
||||
timeout: Optional[float] = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
env_fns, DummyEnvWorker, wait_num=wait_num, timeout=timeout)
|
||||
super().__init__(env_fns, DummyEnvWorker, wait_num=wait_num, timeout=timeout)
|
||||
|
||||
|
||||
class SubprocVectorEnv(BaseVectorEnv):
|
||||
@ -316,8 +306,7 @@ class SubprocVectorEnv(BaseVectorEnv):
|
||||
def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker:
|
||||
return SubprocEnvWorker(fn, share_memory=False)
|
||||
|
||||
super().__init__(
|
||||
env_fns, worker_fn, wait_num=wait_num, timeout=timeout)
|
||||
super().__init__(env_fns, worker_fn, wait_num=wait_num, timeout=timeout)
|
||||
|
||||
|
||||
class ShmemVectorEnv(BaseVectorEnv):
|
||||
@ -340,8 +329,7 @@ class ShmemVectorEnv(BaseVectorEnv):
|
||||
def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker:
|
||||
return SubprocEnvWorker(fn, share_memory=True)
|
||||
|
||||
super().__init__(
|
||||
env_fns, worker_fn, wait_num=wait_num, timeout=timeout)
|
||||
super().__init__(env_fns, worker_fn, wait_num=wait_num, timeout=timeout)
|
||||
|
||||
|
||||
class RayVectorEnv(BaseVectorEnv):
|
||||
@ -369,5 +357,4 @@ class RayVectorEnv(BaseVectorEnv):
|
||||
) from e
|
||||
if not ray.is_initialized():
|
||||
ray.init()
|
||||
super().__init__(
|
||||
env_fns, RayEnvWorker, wait_num=wait_num, timeout=timeout)
|
||||
super().__init__(env_fns, RayEnvWorker, wait_num=wait_num, timeout=timeout)
|
||||
|
21
tianshou/env/worker/subproc.py
vendored
21
tianshou/env/worker/subproc.py
vendored
@ -31,10 +31,7 @@ class ShArray:
|
||||
"""Wrapper of multiprocessing Array."""
|
||||
|
||||
def __init__(self, dtype: np.generic, shape: Tuple[int]) -> None:
|
||||
self.arr = Array(
|
||||
_NP_TO_CT[dtype.type],
|
||||
int(np.prod(shape)),
|
||||
)
|
||||
self.arr = Array(_NP_TO_CT[dtype.type], int(np.prod(shape)))
|
||||
self.dtype = dtype
|
||||
self.shape = shape
|
||||
|
||||
@ -143,10 +140,14 @@ class SubprocEnvWorker(EnvWorker):
|
||||
self.process = Process(target=_worker, args=args, daemon=True)
|
||||
self.process.start()
|
||||
self.child_remote.close()
|
||||
self._seed = None
|
||||
|
||||
def __getattr__(self, key: str) -> Any:
|
||||
self.parent_remote.send(["getattr", key])
|
||||
return self.parent_remote.recv()
|
||||
result = self.parent_remote.recv()
|
||||
if key == "action_space": # issue #299
|
||||
result.seed(self._seed)
|
||||
return result
|
||||
|
||||
def _decode_obs(self) -> Union[dict, tuple, np.ndarray]:
|
||||
def decode_obs(
|
||||
@ -185,11 +186,9 @@ class SubprocEnvWorker(EnvWorker):
|
||||
if remain_time <= 0:
|
||||
break
|
||||
# connection.wait hangs if the list is empty
|
||||
new_ready_conns = connection.wait(
|
||||
remain_conns, timeout=remain_time)
|
||||
new_ready_conns = connection.wait(remain_conns, timeout=remain_time)
|
||||
ready_conns.extend(new_ready_conns) # type: ignore
|
||||
remain_conns = [
|
||||
conn for conn in remain_conns if conn not in ready_conns]
|
||||
remain_conns = [conn for conn in remain_conns if conn not in ready_conns]
|
||||
return [workers[conns.index(con)] for con in ready_conns]
|
||||
|
||||
def send_action(self, action: np.ndarray) -> None:
|
||||
@ -205,7 +204,9 @@ class SubprocEnvWorker(EnvWorker):
|
||||
|
||||
def seed(self, seed: Optional[int] = None) -> Optional[List[int]]:
|
||||
self.parent_remote.send(["seed", seed])
|
||||
return self.parent_remote.recv()
|
||||
result = self.parent_remote.recv()
|
||||
self._seed = result[0] if result is not None else seed
|
||||
return result
|
||||
|
||||
def render(self, **kwargs: Any) -> Any:
|
||||
self.parent_remote.send(["render", kwargs])
|
||||
|
Loading…
x
Reference in New Issue
Block a user