fix vecenv action_space randomness (#300)

This commit is contained in:
n+e 2021-03-01 15:44:03 +08:00 committed by GitHub
parent f22b539761
commit 31e7f445d1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 32 additions and 45 deletions

View File

@ -33,7 +33,6 @@ def get_args():
parser.add_argument('--step-per-epoch', type=int, default=40000) parser.add_argument('--step-per-epoch', type=int, default=40000)
parser.add_argument('--step-per-collect', type=int, default=4) parser.add_argument('--step-per-collect', type=int, default=4)
parser.add_argument('--update-per-step', type=float, default=0.25) parser.add_argument('--update-per-step', type=float, default=0.25)
parser.add_argument('--update-per-step', type=int, default=1)
parser.add_argument('--pre-collect-step', type=int, default=10000) parser.add_argument('--pre-collect-step', type=int, default=10000)
parser.add_argument('--batch-size', type=int, default=256) parser.add_argument('--batch-size', type=int, default=256)
parser.add_argument('--hidden-sizes', type=int, parser.add_argument('--hidden-sizes', type=int,

View File

@ -454,7 +454,7 @@ class ReplayBufferManager(ReplayBuffer):
def __init__(self, buffer_list: List[ReplayBuffer]) -> None: def __init__(self, buffer_list: List[ReplayBuffer]) -> None:
self.buffer_num = len(buffer_list) self.buffer_num = len(buffer_list)
self.buffers = np.array(buffer_list) self.buffers = np.array(buffer_list, dtype=np.object)
offset, size = [], 0 offset, size = [], 0
buffer_type = type(self.buffers[0]) buffer_type = type(self.buffers[0])
kwargs = self.buffers[0].options kwargs = self.buffers[0].options

53
tianshou/env/venvs.py vendored
View File

@ -43,8 +43,7 @@ class BaseVectorEnv(gym.Env):
Otherwise, the outputs of these envs may be the same with each other. Otherwise, the outputs of these envs may be the same with each other.
:param env_fns: a list of callable envs, ``env_fns[i]()`` generates the ith :param env_fns: a list of callable envs, ``env_fns[i]()`` generates the ith env.
env.
:param worker_fn: a callable worker, ``worker_fn(env_fns[i])`` generates a :param worker_fn: a callable worker, ``worker_fn(env_fns[i])`` generates a
worker which contains the i-th env. worker which contains the i-th env.
:param int wait_num: use in asynchronous simulation if the time cost of :param int wait_num: use in asynchronous simulation if the time cost of
@ -75,13 +74,11 @@ class BaseVectorEnv(gym.Env):
self.env_num = len(env_fns) self.env_num = len(env_fns)
self.wait_num = wait_num or len(env_fns) self.wait_num = wait_num or len(env_fns)
assert ( assert 1 <= self.wait_num <= len(env_fns), \
1 <= self.wait_num <= len(env_fns) f"wait_num should be in [1, {len(env_fns)}], but got {wait_num}"
), f"wait_num should be in [1, {len(env_fns)}], but got {wait_num}"
self.timeout = timeout self.timeout = timeout
assert ( assert self.timeout is None or self.timeout > 0, \
self.timeout is None or self.timeout > 0 f"timeout is {timeout}, it should be positive if provided!"
), f"timeout is {timeout}, it should be positive if provided!"
self.is_async = self.wait_num != len(env_fns) or timeout is not None self.is_async = self.wait_num != len(env_fns) or timeout is not None
self.waiting_conn: List[EnvWorker] = [] self.waiting_conn: List[EnvWorker] = []
# environments in self.ready_id is actually ready # environments in self.ready_id is actually ready
@ -94,9 +91,8 @@ class BaseVectorEnv(gym.Env):
self.is_closed = False self.is_closed = False
def _assert_is_not_closed(self) -> None: def _assert_is_not_closed(self) -> None:
assert not self.is_closed, ( assert not self.is_closed, \
f"Methods of {self.__class__.__name__} cannot be called after " f"Methods of {self.__class__.__name__} cannot be called after close."
"close.")
def __len__(self) -> int: def __len__(self) -> int:
"""Return len(self), which is the number of environments.""" """Return len(self), which is the number of environments."""
@ -106,9 +102,8 @@ class BaseVectorEnv(gym.Env):
"""Switch the attribute getter depending on the key. """Switch the attribute getter depending on the key.
Any class who inherits ``gym.Env`` will inherit some attributes, like Any class who inherits ``gym.Env`` will inherit some attributes, like
``action_space``. However, we would like the attribute lookup to go ``action_space``. However, we would like the attribute lookup to go straight
straight into the worker (in fact, this vector env's action_space is into the worker (in fact, this vector env's action_space is always None).
always None).
""" """
if key in ['metadata', 'reward_range', 'spec', 'action_space', if key in ['metadata', 'reward_range', 'spec', 'action_space',
'observation_space']: # reserved keys in gym.Env 'observation_space']: # reserved keys in gym.Env
@ -119,9 +114,8 @@ class BaseVectorEnv(gym.Env):
def __getattr__(self, key: str) -> List[Any]: def __getattr__(self, key: str) -> List[Any]:
"""Fetch a list of env attributes. """Fetch a list of env attributes.
This function tries to retrieve an attribute from each individual This function tries to retrieve an attribute from each individual wrapped
wrapped environment, if it does not belong to the wrapping vector environment, if it does not belong to the wrapping vector environment class.
environment class.
""" """
return [getattr(worker, key) for worker in self.workers] return [getattr(worker, key) for worker in self.workers]
@ -136,12 +130,10 @@ class BaseVectorEnv(gym.Env):
def _assert_id(self, id: List[int]) -> None: def _assert_id(self, id: List[int]) -> None:
for i in id: for i in id:
assert ( assert i not in self.waiting_id, \
i not in self.waiting_id f"Cannot interact with environment {i} which is stepping now."
), f"Cannot interact with environment {i} which is stepping now." assert i in self.ready_id, \
assert ( f"Can only interact with ready environments {self.ready_id}."
i in self.ready_id
), f"Can only interact with ready environments {self.ready_id}."
def reset( def reset(
self, id: Optional[Union[int, List[int], np.ndarray]] = None self, id: Optional[Union[int, List[int], np.ndarray]] = None
@ -178,8 +170,7 @@ class BaseVectorEnv(gym.Env):
:return: A tuple including four items: :return: A tuple including four items:
* ``obs`` a numpy.ndarray, the agent's observation of current \ * ``obs`` a numpy.ndarray, the agent's observation of current environments
environments
* ``rew`` a numpy.ndarray, the amount of rewards returned after \ * ``rew`` a numpy.ndarray, the amount of rewards returned after \
previous actions previous actions
* ``done`` a numpy.ndarray, whether these episodes have ended, in \ * ``done`` a numpy.ndarray, whether these episodes have ended, in \
@ -294,8 +285,7 @@ class DummyVectorEnv(BaseVectorEnv):
wait_num: Optional[int] = None, wait_num: Optional[int] = None,
timeout: Optional[float] = None, timeout: Optional[float] = None,
) -> None: ) -> None:
super().__init__( super().__init__(env_fns, DummyEnvWorker, wait_num=wait_num, timeout=timeout)
env_fns, DummyEnvWorker, wait_num=wait_num, timeout=timeout)
class SubprocVectorEnv(BaseVectorEnv): class SubprocVectorEnv(BaseVectorEnv):
@ -316,8 +306,7 @@ class SubprocVectorEnv(BaseVectorEnv):
def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker: def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker:
return SubprocEnvWorker(fn, share_memory=False) return SubprocEnvWorker(fn, share_memory=False)
super().__init__( super().__init__(env_fns, worker_fn, wait_num=wait_num, timeout=timeout)
env_fns, worker_fn, wait_num=wait_num, timeout=timeout)
class ShmemVectorEnv(BaseVectorEnv): class ShmemVectorEnv(BaseVectorEnv):
@ -340,8 +329,7 @@ class ShmemVectorEnv(BaseVectorEnv):
def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker: def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker:
return SubprocEnvWorker(fn, share_memory=True) return SubprocEnvWorker(fn, share_memory=True)
super().__init__( super().__init__(env_fns, worker_fn, wait_num=wait_num, timeout=timeout)
env_fns, worker_fn, wait_num=wait_num, timeout=timeout)
class RayVectorEnv(BaseVectorEnv): class RayVectorEnv(BaseVectorEnv):
@ -369,5 +357,4 @@ class RayVectorEnv(BaseVectorEnv):
) from e ) from e
if not ray.is_initialized(): if not ray.is_initialized():
ray.init() ray.init()
super().__init__( super().__init__(env_fns, RayEnvWorker, wait_num=wait_num, timeout=timeout)
env_fns, RayEnvWorker, wait_num=wait_num, timeout=timeout)

View File

@ -31,10 +31,7 @@ class ShArray:
"""Wrapper of multiprocessing Array.""" """Wrapper of multiprocessing Array."""
def __init__(self, dtype: np.generic, shape: Tuple[int]) -> None: def __init__(self, dtype: np.generic, shape: Tuple[int]) -> None:
self.arr = Array( self.arr = Array(_NP_TO_CT[dtype.type], int(np.prod(shape)))
_NP_TO_CT[dtype.type],
int(np.prod(shape)),
)
self.dtype = dtype self.dtype = dtype
self.shape = shape self.shape = shape
@ -143,10 +140,14 @@ class SubprocEnvWorker(EnvWorker):
self.process = Process(target=_worker, args=args, daemon=True) self.process = Process(target=_worker, args=args, daemon=True)
self.process.start() self.process.start()
self.child_remote.close() self.child_remote.close()
self._seed = None
def __getattr__(self, key: str) -> Any: def __getattr__(self, key: str) -> Any:
self.parent_remote.send(["getattr", key]) self.parent_remote.send(["getattr", key])
return self.parent_remote.recv() result = self.parent_remote.recv()
if key == "action_space": # issue #299
result.seed(self._seed)
return result
def _decode_obs(self) -> Union[dict, tuple, np.ndarray]: def _decode_obs(self) -> Union[dict, tuple, np.ndarray]:
def decode_obs( def decode_obs(
@ -185,11 +186,9 @@ class SubprocEnvWorker(EnvWorker):
if remain_time <= 0: if remain_time <= 0:
break break
# connection.wait hangs if the list is empty # connection.wait hangs if the list is empty
new_ready_conns = connection.wait( new_ready_conns = connection.wait(remain_conns, timeout=remain_time)
remain_conns, timeout=remain_time)
ready_conns.extend(new_ready_conns) # type: ignore ready_conns.extend(new_ready_conns) # type: ignore
remain_conns = [ remain_conns = [conn for conn in remain_conns if conn not in ready_conns]
conn for conn in remain_conns if conn not in ready_conns]
return [workers[conns.index(con)] for con in ready_conns] return [workers[conns.index(con)] for con in ready_conns]
def send_action(self, action: np.ndarray) -> None: def send_action(self, action: np.ndarray) -> None:
@ -205,7 +204,9 @@ class SubprocEnvWorker(EnvWorker):
def seed(self, seed: Optional[int] = None) -> Optional[List[int]]: def seed(self, seed: Optional[int] = None) -> Optional[List[int]]:
self.parent_remote.send(["seed", seed]) self.parent_remote.send(["seed", seed])
return self.parent_remote.recv() result = self.parent_remote.recv()
self._seed = result[0] if result is not None else seed
return result
def render(self, **kwargs: Any) -> Any: def render(self, **kwargs: Any) -> Any:
self.parent_remote.send(["render", kwargs]) self.parent_remote.send(["render", kwargs])