Collector: removed the unnecessary `no_grad` flag from interfaces. Breaking change.

This commit is contained in:
Michael Panchenko 2024-05-05 15:41:20 +02:00
parent f876198870
commit c5d0e169b5

View File

@ -255,18 +255,17 @@ class BaseCollector(ABC):
n_episode: int | None = None, n_episode: int | None = None,
random: bool = False, random: bool = False,
render: float | None = None, render: float | None = None,
no_grad: bool = True,
gym_reset_kwargs: dict[str, Any] | None = None, gym_reset_kwargs: dict[str, Any] | None = None,
) -> CollectStats: ) -> CollectStats:
pass pass
@torch.no_grad()
def collect( def collect(
self, self,
n_step: int | None = None, n_step: int | None = None,
n_episode: int | None = None, n_episode: int | None = None,
random: bool = False, random: bool = False,
render: float | None = None, render: float | None = None,
no_grad: bool = True,
reset_before_collect: bool = False, reset_before_collect: bool = False,
gym_reset_kwargs: dict[str, Any] | None = None, gym_reset_kwargs: dict[str, Any] | None = None,
) -> CollectStats: ) -> CollectStats:
@ -304,7 +303,6 @@ class BaseCollector(ABC):
n_episode=n_episode, n_episode=n_episode,
random=random, random=random,
render=render, render=render,
no_grad=no_grad,
gym_reset_kwargs=gym_reset_kwargs, gym_reset_kwargs=gym_reset_kwargs,
) )
@ -398,7 +396,6 @@ class Collector(BaseCollector):
self, self,
random: bool, random: bool,
ready_env_ids_R: np.ndarray, ready_env_ids_R: np.ndarray,
use_grad: bool,
last_obs_RO: np.ndarray, last_obs_RO: np.ndarray,
last_info_R: np.ndarray, last_info_R: np.ndarray,
last_hidden_state_RH: np.ndarray | torch.Tensor | Batch | None = None, last_hidden_state_RH: np.ndarray | torch.Tensor | Batch | None = None,
@ -420,11 +417,10 @@ class Collector(BaseCollector):
info_batch = _HACKY_create_info_batch(last_info_R) info_batch = _HACKY_create_info_batch(last_info_R)
obs_batch_R = cast(ObsBatchProtocol, Batch(obs=last_obs_RO, info=info_batch)) obs_batch_R = cast(ObsBatchProtocol, Batch(obs=last_obs_RO, info=info_batch))
with torch.set_grad_enabled(use_grad): act_batch_RA = self.policy(
act_batch_RA = self.policy( obs_batch_R,
obs_batch_R, last_hidden_state_RH,
last_hidden_state_RH, )
)
act_RA = to_numpy(act_batch_RA.act) act_RA = to_numpy(act_batch_RA.act)
if self.exploration_noise: if self.exploration_noise:
@ -454,7 +450,6 @@ class Collector(BaseCollector):
n_episode: int | None = None, n_episode: int | None = None,
random: bool = False, random: bool = False,
render: float | None = None, render: float | None = None,
no_grad: bool = True,
gym_reset_kwargs: dict[str, Any] | None = None, gym_reset_kwargs: dict[str, Any] | None = None,
) -> CollectStats: ) -> CollectStats:
# TODO: can't do it in init since AsyncCollector is currently a subclass of Collector # TODO: can't do it in init since AsyncCollector is currently a subclass of Collector
@ -469,8 +464,6 @@ class Collector(BaseCollector):
elif n_episode is not None: elif n_episode is not None:
ready_env_ids_R = np.arange(min(self.env_num, n_episode)) ready_env_ids_R = np.arange(min(self.env_num, n_episode))
use_grad = not no_grad
start_time = time.time() start_time = time.time()
if self._pre_collect_obs_RO is None or self._pre_collect_info_R is None: if self._pre_collect_obs_RO is None or self._pre_collect_info_R is None:
raise ValueError( raise ValueError(
@ -513,7 +506,6 @@ class Collector(BaseCollector):
) = self._compute_action_policy_hidden( ) = self._compute_action_policy_hidden(
random=random, random=random,
ready_env_ids_R=ready_env_ids_R, ready_env_ids_R=ready_env_ids_R,
use_grad=use_grad,
last_obs_RO=last_obs_RO, last_obs_RO=last_obs_RO,
last_info_R=last_info_R, last_info_R=last_info_R,
last_hidden_state_RH=last_hidden_state_RH, last_hidden_state_RH=last_hidden_state_RH,
@ -762,10 +754,8 @@ class AsyncCollector(Collector):
n_episode: int | None = None, n_episode: int | None = None,
random: bool = False, random: bool = False,
render: float | None = None, render: float | None = None,
no_grad: bool = True,
gym_reset_kwargs: dict[str, Any] | None = None, gym_reset_kwargs: dict[str, Any] | None = None,
) -> CollectStats: ) -> CollectStats:
use_grad = not no_grad
start_time = time.time() start_time = time.time()
step_count = 0 step_count = 0
@ -823,7 +813,6 @@ class AsyncCollector(Collector):
) = self._compute_action_policy_hidden( ) = self._compute_action_policy_hidden(
random=random, random=random,
ready_env_ids_R=ready_env_ids_R, ready_env_ids_R=ready_env_ids_R,
use_grad=use_grad,
last_obs_RO=last_obs_RO, last_obs_RO=last_obs_RO,
last_info_R=last_info_R, last_info_R=last_info_R,
last_hidden_state_RH=last_hidden_state_RH, last_hidden_state_RH=last_hidden_state_RH,