Add a comment before SAC alpha loss (#565)

Co-authored-by: Jiayi Weng <trinkle23897@gmail.com>
This commit is contained in:
Alex Nikulkov 2022-03-08 14:38:42 -08:00 committed by GitHub
parent ad2e1eaea0
commit 74f430ea36
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 2 additions and 1 deletion

View File

@@ -364,6 +364,7 @@ class AsyncCollector(Collector):
exploration_noise: bool = False,
) -> None:
# assert env.is_async
warnings.warn("Using async setting may collect extra transitions into buffer.")
super().__init__(policy, env, buffer, preprocess_fn, exploration_noise)
def reset_env(self) -> None:
@@ -424,7 +425,6 @@ class AsyncCollector(Collector):
"Please specify at least one (either n_step or n_episode) "
"in AsyncCollector.collect()."
)
warnings.warn("Using async setting may collect extra transitions into buffer.")
ready_env_ids = self._ready_env_ids

View File

@@ -174,6 +174,7 @@ class SACPolicy(DDPGPolicy):
if self._is_auto_alpha:
log_prob = obs_result.log_prob.detach() + self._target_entropy
# please take a look at issue #258 if you'd like to change this line
alpha_loss = -(self._log_alpha * log_prob).mean()
self._alpha_optim.zero_grad()
alpha_loss.backward()