Collector: move @override, removed docstrings from overridden methods

This commit is contained in:
Michael Panchenko 2024-05-05 16:01:52 +02:00
parent 26a6cca76e
commit 82f425e9fe

View File

@ -237,7 +237,10 @@ class BaseCollector(ABC):
self,
gym_reset_kwargs: dict[str, Any] | None = None,
) -> tuple[np.ndarray, np.ndarray]:
"""Reset the environments and the initial obs, info, and hidden state of the collector."""
"""Reset the environments and the initial obs, info, and hidden state of the collector.
:return: The initial observation and info from the (vectorized) environment.
"""
gym_reset_kwargs = gym_reset_kwargs or {}
obs_NO, info_N = self.env.reset(**gym_reset_kwargs)
# TODO: hack, wrap envpool envs such that they don't return a dict
@ -368,16 +371,17 @@ class Collector(BaseCollector):
self._is_closed = False
self.collect_step, self.collect_episode, self.collect_time = 0, 0, 0.0
@override
def close(self) -> None:
super().close()
self._pre_collect_obs_RO = None
self._pre_collect_info_R = None
@override
def reset_env(
self,
gym_reset_kwargs: dict[str, Any] | None = None,
) -> tuple[np.ndarray, np.ndarray]:
"""Reset the environments and the initial obs, info, and hidden state of the collector."""
obs_NO, info_N = super().reset_env(gym_reset_kwargs=gym_reset_kwargs)
# We assume that R = N when reset is called.
# TODO: there is currently no mechanism that ensures this and it's a public method!
@ -457,6 +461,8 @@ class Collector(BaseCollector):
ready_env_ids_R = np.arange(self.env_num)
elif n_episode is not None:
ready_env_ids_R = np.arange(min(self.env_num, n_episode))
else:
raise ValueError("Either n_step or n_episode should be set.")
start_time = time.time()
if self._pre_collect_obs_RO is None or self._pre_collect_info_R is None:
@ -645,8 +651,8 @@ class Collector(BaseCollector):
collect_speed=step_count / collect_time,
)
@staticmethod
def _reset_hidden_state_based_on_type(
self,
env_ind_local_D: np.ndarray,
last_hidden_state_RH: np.ndarray | torch.Tensor | Batch | None,
) -> None:
@ -700,21 +706,13 @@ class AsyncCollector(Collector):
self._current_action_in_all_envs_EA: np.ndarray = np.empty(self.env_num)
self._current_policy_in_all_envs_E: Batch | None = None
@override
def reset(
self,
reset_buffer: bool = True,
reset_stats: bool = True,
gym_reset_kwargs: dict[str, Any] | None = None,
) -> tuple[np.ndarray, np.ndarray]:
"""Reset the environment, statistics, and data needed to start the collection.
:param reset_buffer: if true, reset the replay buffer attached
to the collector.
:param reset_stats: if true, reset the statistics attached to the collector.
:param gym_reset_kwargs: extra keyword arguments to pass into the environment's
reset function. Defaults to None (extra keyword arguments)
:return: The initial observation and info from the environment.
"""
# This sets the _pre_collect attrs
result = super().reset(
reset_buffer=reset_buffer,