fix vecenv action_space randomness (#300)
This commit is contained in:
		
							parent
							
								
									f22b539761
								
							
						
					
					
						commit
						31e7f445d1
					
				@ -33,7 +33,6 @@ def get_args():
 | 
			
		||||
    parser.add_argument('--step-per-epoch', type=int, default=40000)
 | 
			
		||||
    parser.add_argument('--step-per-collect', type=int, default=4)
 | 
			
		||||
    parser.add_argument('--update-per-step', type=float, default=0.25)
 | 
			
		||||
    parser.add_argument('--update-per-step', type=int, default=1)
 | 
			
		||||
    parser.add_argument('--pre-collect-step', type=int, default=10000)
 | 
			
		||||
    parser.add_argument('--batch-size', type=int, default=256)
 | 
			
		||||
    parser.add_argument('--hidden-sizes', type=int,
 | 
			
		||||
 | 
			
		||||
@ -454,7 +454,7 @@ class ReplayBufferManager(ReplayBuffer):
 | 
			
		||||
 | 
			
		||||
    def __init__(self, buffer_list: List[ReplayBuffer]) -> None:
 | 
			
		||||
        self.buffer_num = len(buffer_list)
 | 
			
		||||
        self.buffers = np.array(buffer_list)
 | 
			
		||||
        self.buffers = np.array(buffer_list, dtype=np.object)
 | 
			
		||||
        offset, size = [], 0
 | 
			
		||||
        buffer_type = type(self.buffers[0])
 | 
			
		||||
        kwargs = self.buffers[0].options
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										53
									
								
								tianshou/env/venvs.py
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										53
									
								
								tianshou/env/venvs.py
									
									
									
									
										vendored
									
									
								
							@ -43,8 +43,7 @@ class BaseVectorEnv(gym.Env):
 | 
			
		||||
 | 
			
		||||
        Otherwise, the outputs of these envs may be the same with each other.
 | 
			
		||||
 | 
			
		||||
    :param env_fns: a list of callable envs, ``env_fns[i]()`` generates the ith
 | 
			
		||||
        env.
 | 
			
		||||
    :param env_fns: a list of callable envs, ``env_fns[i]()`` generates the ith env.
 | 
			
		||||
    :param worker_fn: a callable worker, ``worker_fn(env_fns[i])`` generates a
 | 
			
		||||
        worker which contains the i-th env.
 | 
			
		||||
    :param int wait_num: use in asynchronous simulation if the time cost of
 | 
			
		||||
@ -75,13 +74,11 @@ class BaseVectorEnv(gym.Env):
 | 
			
		||||
 | 
			
		||||
        self.env_num = len(env_fns)
 | 
			
		||||
        self.wait_num = wait_num or len(env_fns)
 | 
			
		||||
        assert (
 | 
			
		||||
            1 <= self.wait_num <= len(env_fns)
 | 
			
		||||
        ), f"wait_num should be in [1, {len(env_fns)}], but got {wait_num}"
 | 
			
		||||
        assert 1 <= self.wait_num <= len(env_fns), \
 | 
			
		||||
            f"wait_num should be in [1, {len(env_fns)}], but got {wait_num}"
 | 
			
		||||
        self.timeout = timeout
 | 
			
		||||
        assert (
 | 
			
		||||
            self.timeout is None or self.timeout > 0
 | 
			
		||||
        ), f"timeout is {timeout}, it should be positive if provided!"
 | 
			
		||||
        assert self.timeout is None or self.timeout > 0, \
 | 
			
		||||
            f"timeout is {timeout}, it should be positive if provided!"
 | 
			
		||||
        self.is_async = self.wait_num != len(env_fns) or timeout is not None
 | 
			
		||||
        self.waiting_conn: List[EnvWorker] = []
 | 
			
		||||
        # environments in self.ready_id is actually ready
 | 
			
		||||
@ -94,9 +91,8 @@ class BaseVectorEnv(gym.Env):
 | 
			
		||||
        self.is_closed = False
 | 
			
		||||
 | 
			
		||||
    def _assert_is_not_closed(self) -> None:
 | 
			
		||||
        assert not self.is_closed, (
 | 
			
		||||
            f"Methods of {self.__class__.__name__} cannot be called after "
 | 
			
		||||
            "close.")
 | 
			
		||||
        assert not self.is_closed, \
 | 
			
		||||
            f"Methods of {self.__class__.__name__} cannot be called after close."
 | 
			
		||||
 | 
			
		||||
    def __len__(self) -> int:
 | 
			
		||||
        """Return len(self), which is the number of environments."""
 | 
			
		||||
@ -106,9 +102,8 @@ class BaseVectorEnv(gym.Env):
 | 
			
		||||
        """Switch the attribute getter depending on the key.
 | 
			
		||||
 | 
			
		||||
        Any class who inherits ``gym.Env`` will inherit some attributes, like
 | 
			
		||||
        ``action_space``. However, we would like the attribute lookup to go
 | 
			
		||||
        straight into the worker (in fact, this vector env's action_space is
 | 
			
		||||
        always None).
 | 
			
		||||
        ``action_space``. However, we would like the attribute lookup to go straight
 | 
			
		||||
        into the worker (in fact, this vector env's action_space is always None).
 | 
			
		||||
        """
 | 
			
		||||
        if key in ['metadata', 'reward_range', 'spec', 'action_space',
 | 
			
		||||
                   'observation_space']:  # reserved keys in gym.Env
 | 
			
		||||
@ -119,9 +114,8 @@ class BaseVectorEnv(gym.Env):
 | 
			
		||||
    def __getattr__(self, key: str) -> List[Any]:
 | 
			
		||||
        """Fetch a list of env attributes.
 | 
			
		||||
 | 
			
		||||
        This function tries to retrieve an attribute from each individual
 | 
			
		||||
        wrapped environment, if it does not belong to the wrapping vector
 | 
			
		||||
        environment class.
 | 
			
		||||
        This function tries to retrieve an attribute from each individual wrapped
 | 
			
		||||
        environment, if it does not belong to the wrapping vector environment class.
 | 
			
		||||
        """
 | 
			
		||||
        return [getattr(worker, key) for worker in self.workers]
 | 
			
		||||
 | 
			
		||||
@ -136,12 +130,10 @@ class BaseVectorEnv(gym.Env):
 | 
			
		||||
 | 
			
		||||
    def _assert_id(self, id: List[int]) -> None:
 | 
			
		||||
        for i in id:
 | 
			
		||||
            assert (
 | 
			
		||||
                i not in self.waiting_id
 | 
			
		||||
            ), f"Cannot interact with environment {i} which is stepping now."
 | 
			
		||||
            assert (
 | 
			
		||||
                i in self.ready_id
 | 
			
		||||
            ), f"Can only interact with ready environments {self.ready_id}."
 | 
			
		||||
            assert i not in self.waiting_id, \
 | 
			
		||||
                f"Cannot interact with environment {i} which is stepping now."
 | 
			
		||||
            assert i in self.ready_id, \
 | 
			
		||||
                f"Can only interact with ready environments {self.ready_id}."
 | 
			
		||||
 | 
			
		||||
    def reset(
 | 
			
		||||
        self, id: Optional[Union[int, List[int], np.ndarray]] = None
 | 
			
		||||
@ -178,8 +170,7 @@ class BaseVectorEnv(gym.Env):
 | 
			
		||||
 | 
			
		||||
        :return: A tuple including four items:
 | 
			
		||||
 | 
			
		||||
            * ``obs`` a numpy.ndarray, the agent's observation of current \
 | 
			
		||||
                environments
 | 
			
		||||
            * ``obs`` a numpy.ndarray, the agent's observation of current environments
 | 
			
		||||
            * ``rew`` a numpy.ndarray, the amount of rewards returned after \
 | 
			
		||||
                previous actions
 | 
			
		||||
            * ``done`` a numpy.ndarray, whether these episodes have ended, in \
 | 
			
		||||
@ -294,8 +285,7 @@ class DummyVectorEnv(BaseVectorEnv):
 | 
			
		||||
        wait_num: Optional[int] = None,
 | 
			
		||||
        timeout: Optional[float] = None,
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        super().__init__(
 | 
			
		||||
            env_fns, DummyEnvWorker, wait_num=wait_num, timeout=timeout)
 | 
			
		||||
        super().__init__(env_fns, DummyEnvWorker, wait_num=wait_num, timeout=timeout)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SubprocVectorEnv(BaseVectorEnv):
 | 
			
		||||
@ -316,8 +306,7 @@ class SubprocVectorEnv(BaseVectorEnv):
 | 
			
		||||
        def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker:
 | 
			
		||||
            return SubprocEnvWorker(fn, share_memory=False)
 | 
			
		||||
 | 
			
		||||
        super().__init__(
 | 
			
		||||
            env_fns, worker_fn, wait_num=wait_num, timeout=timeout)
 | 
			
		||||
        super().__init__(env_fns, worker_fn, wait_num=wait_num, timeout=timeout)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ShmemVectorEnv(BaseVectorEnv):
 | 
			
		||||
@ -340,8 +329,7 @@ class ShmemVectorEnv(BaseVectorEnv):
 | 
			
		||||
        def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker:
 | 
			
		||||
            return SubprocEnvWorker(fn, share_memory=True)
 | 
			
		||||
 | 
			
		||||
        super().__init__(
 | 
			
		||||
            env_fns, worker_fn, wait_num=wait_num, timeout=timeout)
 | 
			
		||||
        super().__init__(env_fns, worker_fn, wait_num=wait_num, timeout=timeout)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RayVectorEnv(BaseVectorEnv):
 | 
			
		||||
@ -369,5 +357,4 @@ class RayVectorEnv(BaseVectorEnv):
 | 
			
		||||
            ) from e
 | 
			
		||||
        if not ray.is_initialized():
 | 
			
		||||
            ray.init()
 | 
			
		||||
        super().__init__(
 | 
			
		||||
            env_fns, RayEnvWorker, wait_num=wait_num, timeout=timeout)
 | 
			
		||||
        super().__init__(env_fns, RayEnvWorker, wait_num=wait_num, timeout=timeout)
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										21
									
								
								tianshou/env/worker/subproc.py
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										21
									
								
								tianshou/env/worker/subproc.py
									
									
									
									
										vendored
									
									
								
							@ -31,10 +31,7 @@ class ShArray:
 | 
			
		||||
    """Wrapper of multiprocessing Array."""
 | 
			
		||||
 | 
			
		||||
    def __init__(self, dtype: np.generic, shape: Tuple[int]) -> None:
 | 
			
		||||
        self.arr = Array(
 | 
			
		||||
            _NP_TO_CT[dtype.type],
 | 
			
		||||
            int(np.prod(shape)),
 | 
			
		||||
        )
 | 
			
		||||
        self.arr = Array(_NP_TO_CT[dtype.type], int(np.prod(shape)))
 | 
			
		||||
        self.dtype = dtype
 | 
			
		||||
        self.shape = shape
 | 
			
		||||
 | 
			
		||||
@ -143,10 +140,14 @@ class SubprocEnvWorker(EnvWorker):
 | 
			
		||||
        self.process = Process(target=_worker, args=args, daemon=True)
 | 
			
		||||
        self.process.start()
 | 
			
		||||
        self.child_remote.close()
 | 
			
		||||
        self._seed = None
 | 
			
		||||
 | 
			
		||||
    def __getattr__(self, key: str) -> Any:
 | 
			
		||||
        self.parent_remote.send(["getattr", key])
 | 
			
		||||
        return self.parent_remote.recv()
 | 
			
		||||
        result = self.parent_remote.recv()
 | 
			
		||||
        if key == "action_space":  # issue #299
 | 
			
		||||
            result.seed(self._seed)
 | 
			
		||||
        return result
 | 
			
		||||
 | 
			
		||||
    def _decode_obs(self) -> Union[dict, tuple, np.ndarray]:
 | 
			
		||||
        def decode_obs(
 | 
			
		||||
@ -185,11 +186,9 @@ class SubprocEnvWorker(EnvWorker):
 | 
			
		||||
                if remain_time <= 0:
 | 
			
		||||
                    break
 | 
			
		||||
            # connection.wait hangs if the list is empty
 | 
			
		||||
            new_ready_conns = connection.wait(
 | 
			
		||||
                remain_conns, timeout=remain_time)
 | 
			
		||||
            new_ready_conns = connection.wait(remain_conns, timeout=remain_time)
 | 
			
		||||
            ready_conns.extend(new_ready_conns)  # type: ignore
 | 
			
		||||
            remain_conns = [
 | 
			
		||||
                conn for conn in remain_conns if conn not in ready_conns]
 | 
			
		||||
            remain_conns = [conn for conn in remain_conns if conn not in ready_conns]
 | 
			
		||||
        return [workers[conns.index(con)] for con in ready_conns]
 | 
			
		||||
 | 
			
		||||
    def send_action(self, action: np.ndarray) -> None:
 | 
			
		||||
@ -205,7 +204,9 @@ class SubprocEnvWorker(EnvWorker):
 | 
			
		||||
 | 
			
		||||
    def seed(self, seed: Optional[int] = None) -> Optional[List[int]]:
 | 
			
		||||
        self.parent_remote.send(["seed", seed])
 | 
			
		||||
        return self.parent_remote.recv()
 | 
			
		||||
        result = self.parent_remote.recv()
 | 
			
		||||
        self._seed = result[0] if result is not None else seed
 | 
			
		||||
        return result
 | 
			
		||||
 | 
			
		||||
    def render(self, **kwargs: Any) -> Any:
 | 
			
		||||
        self.parent_remote.send(["render", kwargs])
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user