Collector: move @override, removed docstrings from overridden methods
This commit is contained in:
parent
26a6cca76e
commit
82f425e9fe
@ -237,7 +237,10 @@ class BaseCollector(ABC):
|
||||
self,
|
||||
gym_reset_kwargs: dict[str, Any] | None = None,
|
||||
) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""Reset the environments and the initial obs, info, and hidden state of the collector."""
|
||||
"""Reset the environments and the initial obs, info, and hidden state of the collector.
|
||||
|
||||
:return: The initial observation and info from the (vectorized) environment.
|
||||
"""
|
||||
gym_reset_kwargs = gym_reset_kwargs or {}
|
||||
obs_NO, info_N = self.env.reset(**gym_reset_kwargs)
|
||||
# TODO: hack, wrap envpool envs such that they don't return a dict
|
||||
@ -368,16 +371,17 @@ class Collector(BaseCollector):
|
||||
self._is_closed = False
|
||||
self.collect_step, self.collect_episode, self.collect_time = 0, 0, 0.0
|
||||
|
||||
@override
|
||||
def close(self) -> None:
|
||||
super().close()
|
||||
self._pre_collect_obs_RO = None
|
||||
self._pre_collect_info_R = None
|
||||
|
||||
@override
|
||||
def reset_env(
|
||||
self,
|
||||
gym_reset_kwargs: dict[str, Any] | None = None,
|
||||
) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""Reset the environments and the initial obs, info, and hidden state of the collector."""
|
||||
obs_NO, info_N = super().reset_env(gym_reset_kwargs=gym_reset_kwargs)
|
||||
# We assume that R = N when reset is called.
|
||||
# TODO: there is currently no mechanism that ensures this and it's a public method!
|
||||
@ -457,6 +461,8 @@ class Collector(BaseCollector):
|
||||
ready_env_ids_R = np.arange(self.env_num)
|
||||
elif n_episode is not None:
|
||||
ready_env_ids_R = np.arange(min(self.env_num, n_episode))
|
||||
else:
|
||||
raise ValueError("Either n_step or n_episode should be set.")
|
||||
|
||||
start_time = time.time()
|
||||
if self._pre_collect_obs_RO is None or self._pre_collect_info_R is None:
|
||||
@ -645,8 +651,8 @@ class Collector(BaseCollector):
|
||||
collect_speed=step_count / collect_time,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _reset_hidden_state_based_on_type(
|
||||
self,
|
||||
env_ind_local_D: np.ndarray,
|
||||
last_hidden_state_RH: np.ndarray | torch.Tensor | Batch | None,
|
||||
) -> None:
|
||||
@ -700,21 +706,13 @@ class AsyncCollector(Collector):
|
||||
self._current_action_in_all_envs_EA: np.ndarray = np.empty(self.env_num)
|
||||
self._current_policy_in_all_envs_E: Batch | None = None
|
||||
|
||||
@override
|
||||
def reset(
|
||||
self,
|
||||
reset_buffer: bool = True,
|
||||
reset_stats: bool = True,
|
||||
gym_reset_kwargs: dict[str, Any] | None = None,
|
||||
) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""Reset the environment, statistics, and data needed to start the collection.
|
||||
|
||||
:param reset_buffer: if true, reset the replay buffer attached
|
||||
to the collector.
|
||||
:param reset_stats: if true, reset the statistics attached to the collector.
|
||||
:param gym_reset_kwargs: extra keyword arguments to pass into the environment's
|
||||
reset function. Defaults to None (extra keyword arguments)
|
||||
:return: The initial observation and info from the environment.
|
||||
"""
|
||||
# This sets the _pre_collect attrs
|
||||
result = super().reset(
|
||||
reset_buffer=reset_buffer,
|
||||
|
Loading…
x
Reference in New Issue
Block a user