Collector: removed unnecessary no-grad flag from interfaces. Breaking

This commit is contained in:
Michael Panchenko 2024-05-05 15:41:20 +02:00
parent f876198870
commit c5d0e169b5

View File

@ -255,18 +255,17 @@ class BaseCollector(ABC):
n_episode: int | None = None,
random: bool = False,
render: float | None = None,
no_grad: bool = True,
gym_reset_kwargs: dict[str, Any] | None = None,
) -> CollectStats:
pass
@torch.no_grad()
def collect(
self,
n_step: int | None = None,
n_episode: int | None = None,
random: bool = False,
render: float | None = None,
no_grad: bool = True,
reset_before_collect: bool = False,
gym_reset_kwargs: dict[str, Any] | None = None,
) -> CollectStats:
@ -304,7 +303,6 @@ class BaseCollector(ABC):
n_episode=n_episode,
random=random,
render=render,
no_grad=no_grad,
gym_reset_kwargs=gym_reset_kwargs,
)
@ -398,7 +396,6 @@ class Collector(BaseCollector):
self,
random: bool,
ready_env_ids_R: np.ndarray,
use_grad: bool,
last_obs_RO: np.ndarray,
last_info_R: np.ndarray,
last_hidden_state_RH: np.ndarray | torch.Tensor | Batch | None = None,
@ -420,11 +417,10 @@ class Collector(BaseCollector):
info_batch = _HACKY_create_info_batch(last_info_R)
obs_batch_R = cast(ObsBatchProtocol, Batch(obs=last_obs_RO, info=info_batch))
with torch.set_grad_enabled(use_grad):
act_batch_RA = self.policy(
obs_batch_R,
last_hidden_state_RH,
)
act_batch_RA = self.policy(
obs_batch_R,
last_hidden_state_RH,
)
act_RA = to_numpy(act_batch_RA.act)
if self.exploration_noise:
@ -454,7 +450,6 @@ class Collector(BaseCollector):
n_episode: int | None = None,
random: bool = False,
render: float | None = None,
no_grad: bool = True,
gym_reset_kwargs: dict[str, Any] | None = None,
) -> CollectStats:
# TODO: can't do it init since AsyncCollector is currently a subclass of Collector
@ -469,8 +464,6 @@ class Collector(BaseCollector):
elif n_episode is not None:
ready_env_ids_R = np.arange(min(self.env_num, n_episode))
use_grad = not no_grad
start_time = time.time()
if self._pre_collect_obs_RO is None or self._pre_collect_info_R is None:
raise ValueError(
@ -513,7 +506,6 @@ class Collector(BaseCollector):
) = self._compute_action_policy_hidden(
random=random,
ready_env_ids_R=ready_env_ids_R,
use_grad=use_grad,
last_obs_RO=last_obs_RO,
last_info_R=last_info_R,
last_hidden_state_RH=last_hidden_state_RH,
@ -762,10 +754,8 @@ class AsyncCollector(Collector):
n_episode: int | None = None,
random: bool = False,
render: float | None = None,
no_grad: bool = True,
gym_reset_kwargs: dict[str, Any] | None = None,
) -> CollectStats:
use_grad = not no_grad
start_time = time.time()
step_count = 0
@ -823,7 +813,6 @@ class AsyncCollector(Collector):
) = self._compute_action_policy_hidden(
random=random,
ready_env_ids_R=ready_env_ids_R,
use_grad=use_grad,
last_obs_RO=last_obs_RO,
last_info_R=last_info_R,
last_hidden_state_RH=last_hidden_state_RH,