Collector: removed unnecessary no-grad flag from interfaces. Breaking
This commit is contained in:
parent
f876198870
commit
c5d0e169b5
@ -255,18 +255,17 @@ class BaseCollector(ABC):
|
||||
n_episode: int | None = None,
|
||||
random: bool = False,
|
||||
render: float | None = None,
|
||||
no_grad: bool = True,
|
||||
gym_reset_kwargs: dict[str, Any] | None = None,
|
||||
) -> CollectStats:
|
||||
pass
|
||||
|
||||
@torch.no_grad()
|
||||
def collect(
|
||||
self,
|
||||
n_step: int | None = None,
|
||||
n_episode: int | None = None,
|
||||
random: bool = False,
|
||||
render: float | None = None,
|
||||
no_grad: bool = True,
|
||||
reset_before_collect: bool = False,
|
||||
gym_reset_kwargs: dict[str, Any] | None = None,
|
||||
) -> CollectStats:
|
||||
@ -304,7 +303,6 @@ class BaseCollector(ABC):
|
||||
n_episode=n_episode,
|
||||
random=random,
|
||||
render=render,
|
||||
no_grad=no_grad,
|
||||
gym_reset_kwargs=gym_reset_kwargs,
|
||||
)
|
||||
|
||||
@ -398,7 +396,6 @@ class Collector(BaseCollector):
|
||||
self,
|
||||
random: bool,
|
||||
ready_env_ids_R: np.ndarray,
|
||||
use_grad: bool,
|
||||
last_obs_RO: np.ndarray,
|
||||
last_info_R: np.ndarray,
|
||||
last_hidden_state_RH: np.ndarray | torch.Tensor | Batch | None = None,
|
||||
@ -420,11 +417,10 @@ class Collector(BaseCollector):
|
||||
info_batch = _HACKY_create_info_batch(last_info_R)
|
||||
obs_batch_R = cast(ObsBatchProtocol, Batch(obs=last_obs_RO, info=info_batch))
|
||||
|
||||
with torch.set_grad_enabled(use_grad):
|
||||
act_batch_RA = self.policy(
|
||||
obs_batch_R,
|
||||
last_hidden_state_RH,
|
||||
)
|
||||
act_batch_RA = self.policy(
|
||||
obs_batch_R,
|
||||
last_hidden_state_RH,
|
||||
)
|
||||
|
||||
act_RA = to_numpy(act_batch_RA.act)
|
||||
if self.exploration_noise:
|
||||
@ -454,7 +450,6 @@ class Collector(BaseCollector):
|
||||
n_episode: int | None = None,
|
||||
random: bool = False,
|
||||
render: float | None = None,
|
||||
no_grad: bool = True,
|
||||
gym_reset_kwargs: dict[str, Any] | None = None,
|
||||
) -> CollectStats:
|
||||
# TODO: can't do it init since AsyncCollector is currently a subclass of Collector
|
||||
@ -469,8 +464,6 @@ class Collector(BaseCollector):
|
||||
elif n_episode is not None:
|
||||
ready_env_ids_R = np.arange(min(self.env_num, n_episode))
|
||||
|
||||
use_grad = not no_grad
|
||||
|
||||
start_time = time.time()
|
||||
if self._pre_collect_obs_RO is None or self._pre_collect_info_R is None:
|
||||
raise ValueError(
|
||||
@ -513,7 +506,6 @@ class Collector(BaseCollector):
|
||||
) = self._compute_action_policy_hidden(
|
||||
random=random,
|
||||
ready_env_ids_R=ready_env_ids_R,
|
||||
use_grad=use_grad,
|
||||
last_obs_RO=last_obs_RO,
|
||||
last_info_R=last_info_R,
|
||||
last_hidden_state_RH=last_hidden_state_RH,
|
||||
@ -762,10 +754,8 @@ class AsyncCollector(Collector):
|
||||
n_episode: int | None = None,
|
||||
random: bool = False,
|
||||
render: float | None = None,
|
||||
no_grad: bool = True,
|
||||
gym_reset_kwargs: dict[str, Any] | None = None,
|
||||
) -> CollectStats:
|
||||
use_grad = not no_grad
|
||||
start_time = time.time()
|
||||
|
||||
step_count = 0
|
||||
@ -823,7 +813,6 @@ class AsyncCollector(Collector):
|
||||
) = self._compute_action_policy_hidden(
|
||||
random=random,
|
||||
ready_env_ids_R=ready_env_ids_R,
|
||||
use_grad=use_grad,
|
||||
last_obs_RO=last_obs_RO,
|
||||
last_info_R=last_info_R,
|
||||
last_hidden_state_RH=last_hidden_state_RH,
|
||||
|
Loading…
x
Reference in New Issue
Block a user