diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py
index 133667f..cfc4b3d 100644
--- a/tianshou/data/collector.py
+++ b/tianshou/data/collector.py
@@ -255,18 +255,17 @@ class BaseCollector(ABC):
         n_episode: int | None = None,
         random: bool = False,
         render: float | None = None,
-        no_grad: bool = True,
         gym_reset_kwargs: dict[str, Any] | None = None,
     ) -> CollectStats:
         pass
 
+    @torch.no_grad()
     def collect(
         self,
         n_step: int | None = None,
         n_episode: int | None = None,
         random: bool = False,
         render: float | None = None,
-        no_grad: bool = True,
         reset_before_collect: bool = False,
         gym_reset_kwargs: dict[str, Any] | None = None,
     ) -> CollectStats:
@@ -304,7 +303,6 @@ class BaseCollector(ABC):
             n_episode=n_episode,
             random=random,
             render=render,
-            no_grad=no_grad,
             gym_reset_kwargs=gym_reset_kwargs,
         )
 
@@ -398,7 +396,6 @@ class Collector(BaseCollector):
         self,
         random: bool,
         ready_env_ids_R: np.ndarray,
-        use_grad: bool,
         last_obs_RO: np.ndarray,
         last_info_R: np.ndarray,
         last_hidden_state_RH: np.ndarray | torch.Tensor | Batch | None = None,
@@ -420,11 +417,10 @@ class Collector(BaseCollector):
             info_batch = _HACKY_create_info_batch(last_info_R)
             obs_batch_R = cast(ObsBatchProtocol, Batch(obs=last_obs_RO, info=info_batch))
 
-            with torch.set_grad_enabled(use_grad):
-                act_batch_RA = self.policy(
-                    obs_batch_R,
-                    last_hidden_state_RH,
-                )
+            act_batch_RA = self.policy(
+                obs_batch_R,
+                last_hidden_state_RH,
+            )
 
             act_RA = to_numpy(act_batch_RA.act)
             if self.exploration_noise:
@@ -454,7 +450,6 @@ class Collector(BaseCollector):
         n_episode: int | None = None,
         random: bool = False,
         render: float | None = None,
-        no_grad: bool = True,
         gym_reset_kwargs: dict[str, Any] | None = None,
     ) -> CollectStats:
         # TODO: can't do it init since AsyncCollector is currently a subclass of Collector
@@ -469,8 +464,6 @@ class Collector(BaseCollector):
         elif n_episode is not None:
             ready_env_ids_R = np.arange(min(self.env_num, n_episode))
 
-        use_grad = not no_grad
-
         start_time = time.time()
         if self._pre_collect_obs_RO is None or self._pre_collect_info_R is None:
             raise ValueError(
@@ -513,7 +506,6 @@ class Collector(BaseCollector):
             ) = self._compute_action_policy_hidden(
                 random=random,
                 ready_env_ids_R=ready_env_ids_R,
-                use_grad=use_grad,
                 last_obs_RO=last_obs_RO,
                 last_info_R=last_info_R,
                 last_hidden_state_RH=last_hidden_state_RH,
@@ -762,10 +754,8 @@ class AsyncCollector(Collector):
         n_episode: int | None = None,
         random: bool = False,
         render: float | None = None,
-        no_grad: bool = True,
         gym_reset_kwargs: dict[str, Any] | None = None,
     ) -> CollectStats:
-        use_grad = not no_grad
         start_time = time.time()
 
         step_count = 0
@@ -823,7 +813,6 @@ class AsyncCollector(Collector):
             ) = self._compute_action_policy_hidden(
                 random=random,
                 ready_env_ids_R=ready_env_ids_R,
-                use_grad=use_grad,
                 last_obs_RO=last_obs_RO,
                 last_info_R=last_info_R,
                 last_hidden_state_RH=last_hidden_state_RH,
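
Note on the behavior change: the diff removes the no_grad/use_grad parameter plumbing and instead applies @torch.no_grad() directly to collect, hard-coding the former default of no_grad=True; callers that previously passed no_grad=False no longer have that switch. Below is a minimal self-contained sketch (not tianshou code; collect_like and the toy policy are hypothetical) showing that the decorator disables gradient tracking for the entire call and restores the previous grad mode on exit:

    import torch

    # Hypothetical stand-in for Collector.collect, illustrating the decorator
    # used in the diff: @torch.no_grad() turns gradient tracking off for the
    # whole call, matching the old default of no_grad=True.
    @torch.no_grad()
    def collect_like(policy: torch.nn.Module, obs: torch.Tensor) -> torch.Tensor:
        act = policy(obs)
        # No autograd graph is built inside the decorated call.
        assert not act.requires_grad
        return act

    policy = torch.nn.Linear(4, 2)   # toy policy network
    obs = torch.randn(3, 4)          # toy observation batch
    act = collect_like(policy, obs)
    # Grad mode is restored once the call returns.
    assert torch.is_grad_enabled()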