fix vecenv action_space randomness (#300)

This commit is contained in:
n+e 2021-03-01 15:44:03 +08:00 committed by GitHub
parent f22b539761
commit 31e7f445d1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 32 additions and 45 deletions

View File

@ -33,7 +33,6 @@ def get_args():
parser.add_argument('--step-per-epoch', type=int, default=40000)
parser.add_argument('--step-per-collect', type=int, default=4)
parser.add_argument('--update-per-step', type=float, default=0.25)
parser.add_argument('--update-per-step', type=int, default=1)
parser.add_argument('--pre-collect-step', type=int, default=10000)
parser.add_argument('--batch-size', type=int, default=256)
parser.add_argument('--hidden-sizes', type=int,

View File

@ -454,7 +454,7 @@ class ReplayBufferManager(ReplayBuffer):
def __init__(self, buffer_list: List[ReplayBuffer]) -> None:
self.buffer_num = len(buffer_list)
self.buffers = np.array(buffer_list)
self.buffers = np.array(buffer_list, dtype=np.object)
offset, size = [], 0
buffer_type = type(self.buffers[0])
kwargs = self.buffers[0].options

53
tianshou/env/venvs.py vendored
View File

@ -43,8 +43,7 @@ class BaseVectorEnv(gym.Env):
Otherwise, the outputs of these envs may be the same with each other.
:param env_fns: a list of callable envs, ``env_fns[i]()`` generates the ith
env.
:param env_fns: a list of callable envs, ``env_fns[i]()`` generates the ith env.
:param worker_fn: a callable worker, ``worker_fn(env_fns[i])`` generates a
worker which contains the i-th env.
:param int wait_num: use in asynchronous simulation if the time cost of
@ -75,13 +74,11 @@ class BaseVectorEnv(gym.Env):
self.env_num = len(env_fns)
self.wait_num = wait_num or len(env_fns)
assert (
1 <= self.wait_num <= len(env_fns)
), f"wait_num should be in [1, {len(env_fns)}], but got {wait_num}"
assert 1 <= self.wait_num <= len(env_fns), \
f"wait_num should be in [1, {len(env_fns)}], but got {wait_num}"
self.timeout = timeout
assert (
self.timeout is None or self.timeout > 0
), f"timeout is {timeout}, it should be positive if provided!"
assert self.timeout is None or self.timeout > 0, \
f"timeout is {timeout}, it should be positive if provided!"
self.is_async = self.wait_num != len(env_fns) or timeout is not None
self.waiting_conn: List[EnvWorker] = []
# environments in self.ready_id is actually ready
@ -94,9 +91,8 @@ class BaseVectorEnv(gym.Env):
self.is_closed = False
def _assert_is_not_closed(self) -> None:
assert not self.is_closed, (
f"Methods of {self.__class__.__name__} cannot be called after "
"close.")
assert not self.is_closed, \
f"Methods of {self.__class__.__name__} cannot be called after close."
def __len__(self) -> int:
"""Return len(self), which is the number of environments."""
@ -106,9 +102,8 @@ class BaseVectorEnv(gym.Env):
"""Switch the attribute getter depending on the key.
Any class who inherits ``gym.Env`` will inherit some attributes, like
``action_space``. However, we would like the attribute lookup to go
straight into the worker (in fact, this vector env's action_space is
always None).
``action_space``. However, we would like the attribute lookup to go straight
into the worker (in fact, this vector env's action_space is always None).
"""
if key in ['metadata', 'reward_range', 'spec', 'action_space',
'observation_space']: # reserved keys in gym.Env
@ -119,9 +114,8 @@ class BaseVectorEnv(gym.Env):
def __getattr__(self, key: str) -> List[Any]:
"""Fetch a list of env attributes.
This function tries to retrieve an attribute from each individual
wrapped environment, if it does not belong to the wrapping vector
environment class.
This function tries to retrieve an attribute from each individual wrapped
environment, if it does not belong to the wrapping vector environment class.
"""
return [getattr(worker, key) for worker in self.workers]
@ -136,12 +130,10 @@ class BaseVectorEnv(gym.Env):
def _assert_id(self, id: List[int]) -> None:
for i in id:
assert (
i not in self.waiting_id
), f"Cannot interact with environment {i} which is stepping now."
assert (
i in self.ready_id
), f"Can only interact with ready environments {self.ready_id}."
assert i not in self.waiting_id, \
f"Cannot interact with environment {i} which is stepping now."
assert i in self.ready_id, \
f"Can only interact with ready environments {self.ready_id}."
def reset(
self, id: Optional[Union[int, List[int], np.ndarray]] = None
@ -178,8 +170,7 @@ class BaseVectorEnv(gym.Env):
:return: A tuple including four items:
* ``obs`` a numpy.ndarray, the agent's observation of current \
environments
* ``obs`` a numpy.ndarray, the agent's observation of current environments
* ``rew`` a numpy.ndarray, the amount of rewards returned after \
previous actions
* ``done`` a numpy.ndarray, whether these episodes have ended, in \
@ -294,8 +285,7 @@ class DummyVectorEnv(BaseVectorEnv):
wait_num: Optional[int] = None,
timeout: Optional[float] = None,
) -> None:
super().__init__(
env_fns, DummyEnvWorker, wait_num=wait_num, timeout=timeout)
super().__init__(env_fns, DummyEnvWorker, wait_num=wait_num, timeout=timeout)
class SubprocVectorEnv(BaseVectorEnv):
@ -316,8 +306,7 @@ class SubprocVectorEnv(BaseVectorEnv):
def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker:
return SubprocEnvWorker(fn, share_memory=False)
super().__init__(
env_fns, worker_fn, wait_num=wait_num, timeout=timeout)
super().__init__(env_fns, worker_fn, wait_num=wait_num, timeout=timeout)
class ShmemVectorEnv(BaseVectorEnv):
@ -340,8 +329,7 @@ class ShmemVectorEnv(BaseVectorEnv):
def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker:
return SubprocEnvWorker(fn, share_memory=True)
super().__init__(
env_fns, worker_fn, wait_num=wait_num, timeout=timeout)
super().__init__(env_fns, worker_fn, wait_num=wait_num, timeout=timeout)
class RayVectorEnv(BaseVectorEnv):
@ -369,5 +357,4 @@ class RayVectorEnv(BaseVectorEnv):
) from e
if not ray.is_initialized():
ray.init()
super().__init__(
env_fns, RayEnvWorker, wait_num=wait_num, timeout=timeout)
super().__init__(env_fns, RayEnvWorker, wait_num=wait_num, timeout=timeout)

View File

@ -31,10 +31,7 @@ class ShArray:
"""Wrapper of multiprocessing Array."""
def __init__(self, dtype: np.generic, shape: Tuple[int]) -> None:
self.arr = Array(
_NP_TO_CT[dtype.type],
int(np.prod(shape)),
)
self.arr = Array(_NP_TO_CT[dtype.type], int(np.prod(shape)))
self.dtype = dtype
self.shape = shape
@ -143,10 +140,14 @@ class SubprocEnvWorker(EnvWorker):
self.process = Process(target=_worker, args=args, daemon=True)
self.process.start()
self.child_remote.close()
self._seed = None
def __getattr__(self, key: str) -> Any:
self.parent_remote.send(["getattr", key])
return self.parent_remote.recv()
result = self.parent_remote.recv()
if key == "action_space": # issue #299
result.seed(self._seed)
return result
def _decode_obs(self) -> Union[dict, tuple, np.ndarray]:
def decode_obs(
@ -185,11 +186,9 @@ class SubprocEnvWorker(EnvWorker):
if remain_time <= 0:
break
# connection.wait hangs if the list is empty
new_ready_conns = connection.wait(
remain_conns, timeout=remain_time)
new_ready_conns = connection.wait(remain_conns, timeout=remain_time)
ready_conns.extend(new_ready_conns) # type: ignore
remain_conns = [
conn for conn in remain_conns if conn not in ready_conns]
remain_conns = [conn for conn in remain_conns if conn not in ready_conns]
return [workers[conns.index(con)] for con in ready_conns]
def send_action(self, action: np.ndarray) -> None:
@ -205,7 +204,9 @@ class SubprocEnvWorker(EnvWorker):
def seed(self, seed: Optional[int] = None) -> Optional[List[int]]:
self.parent_remote.send(["seed", seed])
return self.parent_remote.recv()
result = self.parent_remote.recv()
self._seed = result[0] if result is not None else seed
return result
def render(self, **kwargs: Any) -> Any:
self.parent_remote.send(["render", kwargs])