diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py
index b498f45..d2ca217 100644
--- a/tianshou/data/collector.py
+++ b/tianshou/data/collector.py
@@ -237,7 +237,10 @@ class BaseCollector(ABC):
         self,
         gym_reset_kwargs: dict[str, Any] | None = None,
     ) -> tuple[np.ndarray, np.ndarray]:
-        """Reset the environments and the initial obs, info, and hidden state of the collector."""
+        """Reset the environments and the initial obs, info, and hidden state of the collector.
+
+        :return: The initial observation and info from the (vectorized) environment.
+        """
         gym_reset_kwargs = gym_reset_kwargs or {}
         obs_NO, info_N = self.env.reset(**gym_reset_kwargs)
         # TODO: hack, wrap envpool envs such that they don't return a dict
@@ -368,16 +371,17 @@ class Collector(BaseCollector):
         self._is_closed = False
         self.collect_step, self.collect_episode, self.collect_time = 0, 0, 0.0
 
+    @override
     def close(self) -> None:
         super().close()
         self._pre_collect_obs_RO = None
         self._pre_collect_info_R = None
 
+    @override
     def reset_env(
         self,
         gym_reset_kwargs: dict[str, Any] | None = None,
     ) -> tuple[np.ndarray, np.ndarray]:
-        """Reset the environments and the initial obs, info, and hidden state of the collector."""
         obs_NO, info_N = super().reset_env(gym_reset_kwargs=gym_reset_kwargs)
         # We assume that R = N when reset is called.
         # TODO: there is currently no mechanism that ensures this and it's a public method!
@@ -457,6 +461,8 @@ class Collector(BaseCollector):
             ready_env_ids_R = np.arange(self.env_num)
         elif n_episode is not None:
             ready_env_ids_R = np.arange(min(self.env_num, n_episode))
+        else:
+            raise ValueError("Either n_step or n_episode should be set.")
         start_time = time.time()
 
         if self._pre_collect_obs_RO is None or self._pre_collect_info_R is None:
@@ -645,8 +651,8 @@ class Collector(BaseCollector):
             collect_speed=step_count / collect_time,
         )
 
+    @staticmethod
     def _reset_hidden_state_based_on_type(
-        self,
         env_ind_local_D: np.ndarray,
         last_hidden_state_RH: np.ndarray | torch.Tensor | Batch | None,
     ) -> None:
@@ -700,21 +706,13 @@ class AsyncCollector(Collector):
         self._current_action_in_all_envs_EA: np.ndarray = np.empty(self.env_num)
         self._current_policy_in_all_envs_E: Batch | None = None
 
+    @override
     def reset(
         self,
         reset_buffer: bool = True,
         reset_stats: bool = True,
         gym_reset_kwargs: dict[str, Any] | None = None,
     ) -> tuple[np.ndarray, np.ndarray]:
-        """Reset the environment, statistics, and data needed to start the collection.
-
-        :param reset_buffer: if true, reset the replay buffer attached
-            to the collector.
-        :param reset_stats: if true, reset the statistics attached to the collector.
-        :param gym_reset_kwargs: extra keyword arguments to pass into the environment's
-            reset function. Defaults to None (extra keyword arguments)
-        :return: The initial observation and info from the environment.
-        """
         # This sets the _pre_collect attrs
         result = super().reset(
             reset_buffer=reset_buffer,
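
Note (not part of the patch): the @@ -457,6 +461,8 @@ hunk adds an explicit else branch so that
calling collect() with neither n_step nor n_episode fails fast with a ValueError instead of leaving
ready_env_ids_R unassigned in the code shown. A minimal, standalone sketch of that control flow;
the helper name is hypothetical and only numpy is assumed:

    import numpy as np

    def choose_ready_env_ids(
        env_num: int,
        n_step: int | None = None,
        n_episode: int | None = None,
    ) -> np.ndarray:
        # Mirrors the branching added in the hunk above.
        if n_step is not None:
            return np.arange(env_num)
        if n_episode is not None:
            return np.arange(min(env_num, n_episode))
        # Without this branch, the result would be left undefined when both
        # arguments are omitted; the patch turns that misuse into a clear error.
        raise ValueError("Either n_step or n_episode should be set.")

    choose_ready_env_ids(4, n_step=100)   # array([0, 1, 2, 3])
    choose_ready_env_ids(4, n_episode=2)  # array([0, 1])
    # choose_ready_env_ids(4)             # ValueError: Either n_step or n_episode should be set.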