diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py
index e36b7f5..d3a7999 100644
--- a/tianshou/policy/modelfree/dqn.py
+++ b/tianshou/policy/modelfree/dqn.py
@@ -26,6 +26,9 @@ class DQNPolicy(BasePolicy):
     :param bool reward_normalization: normalize the reward to Normal(0, 1).
         Default to False.
     :param bool is_double: use double dqn. Default to True.
+    :param bool clip_loss_grad: clip the gradient of the loss in accordance
+        with nature14236; this amounts to using the Huber loss instead of
+        the MSE loss. Default to False.
     :param lr_scheduler: a learning rate scheduler that adjusts the learning
         rate in optimizer in each policy.update(). Default to None
         (no lr_scheduler).
@@ -44,6 +47,7 @@ class DQNPolicy(BasePolicy):
         target_update_freq: int = 0,
         reward_normalization: bool = False,
         is_double: bool = True,
+        clip_loss_grad: bool = False,
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
@@ -62,6 +66,7 @@ class DQNPolicy(BasePolicy):
         self.model_old.eval()
         self._rew_norm = reward_normalization
         self._is_double = is_double
+        self._clip_loss_grad = clip_loss_grad
 
     def set_eps(self, eps: float) -> None:
         """Set the eps for epsilon-greedy exploration."""
@@ -168,7 +173,14 @@ class DQNPolicy(BasePolicy):
         q = q[np.arange(len(q)), batch.act]
         returns = to_torch_as(batch.returns.flatten(), q)
         td_error = returns - q
-        loss = (td_error.pow(2) * weight).mean()
+
+        if self._clip_loss_grad:
+            y = q.reshape(-1, 1)
+            t = returns.reshape(-1, 1)
+            loss = torch.nn.functional.huber_loss(y, t, reduction="mean")
+        else:
+            loss = (td_error.pow(2) * weight).mean()
+
         batch.weight = td_error  # prio-buffer
         loss.backward()
         self.optim.step()
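
For context, a minimal usage sketch of the new flag (not part of this patch). It assumes Tianshou's Net helper from tianshou.utils.net.common and a standard PyTorch optimizer; the shapes and hyper-parameter values below are illustrative placeholders only.

import torch
from tianshou.policy import DQNPolicy
from tianshou.utils.net.common import Net

# Placeholder dimensions (e.g. a CartPole-like task: 4-dim observation, 2 actions).
state_shape, action_shape = (4,), 2
net = Net(state_shape, action_shape, hidden_sizes=[128, 128])
optim = torch.optim.Adam(net.parameters(), lr=1e-3)

policy = DQNPolicy(
    net,
    optim,
    discount_factor=0.99,
    target_update_freq=320,
    clip_loss_grad=True,  # new flag from this patch: train with Huber loss instead of MSE
)

With the default delta of 1.0, the per-sample gradient of the Huber loss with respect to the Q-value is the difference between prediction and target clamped to [-1, 1] (up to the 1/N factor from the mean reduction), which is the loss-gradient clipping from nature14236 that the docstring refers to.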