From 277138ca5b050518aacaaea367192f910fbe666d Mon Sep 17 00:00:00 2001
From: Michal Gregor
Date: Wed, 18 May 2022 13:33:37 +0200
Subject: [PATCH] Added support for clipping to DQNPolicy (#642)

* When clip_loss_grad=True is passed, the Huber loss is used instead of the MSE loss.
* Made the argument's name more descriptive;
* Replaced the smooth L1 loss with the Huber loss, which has an identical form to the default parametrization, but seems to be better known in this context;
* Added a fuller description to the docstring;
---
 tianshou/policy/modelfree/dqn.py | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py
index e36b7f5..d3a7999 100644
--- a/tianshou/policy/modelfree/dqn.py
+++ b/tianshou/policy/modelfree/dqn.py
@@ -26,6 +26,9 @@ class DQNPolicy(BasePolicy):
     :param bool reward_normalization: normalize the reward to Normal(0, 1).
         Default to False.
     :param bool is_double: use double dqn. Default to True.
+    :param bool clip_loss_grad: clip the gradient of the loss in accordance
+        with nature14236; this amounts to using the Huber loss instead of
+        the MSE loss. Default to False.
     :param lr_scheduler: a learning rate scheduler that adjusts the learning
         rate in optimizer in each policy.update(). Default to None
         (no lr_scheduler).
@@ -44,6 +47,7 @@ class DQNPolicy(BasePolicy):
         target_update_freq: int = 0,
         reward_normalization: bool = False,
         is_double: bool = True,
+        clip_loss_grad: bool = False,
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
@@ -62,6 +66,7 @@ class DQNPolicy(BasePolicy):
         self.model_old.eval()
         self._rew_norm = reward_normalization
         self._is_double = is_double
+        self._clip_loss_grad = clip_loss_grad
 
     def set_eps(self, eps: float) -> None:
         """Set the eps for epsilon-greedy exploration."""
@@ -168,7 +173,14 @@ class DQNPolicy(BasePolicy):
         q = q[np.arange(len(q)), batch.act]
         returns = to_torch_as(batch.returns.flatten(), q)
         td_error = returns - q
-        loss = (td_error.pow(2) * weight).mean()
+
+        if self._clip_loss_grad:
+            y = q.reshape(-1, 1)
+            t = returns.reshape(-1, 1)
+            loss = torch.nn.functional.huber_loss(y, t, reduction="mean")
+        else:
+            loss = (td_error.pow(2) * weight).mean()
+
         batch.weight = td_error  # prio-buffer
         loss.backward()
         self.optim.step()
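
As an illustrative sketch (not part of the patch itself): the docstring's phrase "clip the gradient of the loss" refers to the fact that, for TD errors larger than delta (1.0 by default), the gradient of the Huber loss with respect to the Q-value saturates at +/- delta, whereas the gradient of the squared error keeps growing linearly with the error. The snippet below demonstrates this with torch.nn.functional.huber_loss; the tensor values are made up purely for the demonstration.

    import torch
    import torch.nn.functional as F

    # A single Q-value with a (made-up) TD error of 10.
    q = torch.tensor([10.0], requires_grad=True)
    returns = torch.tensor([0.0])

    # Huber loss (delta = 1.0 by default): the gradient w.r.t. q saturates at delta.
    huber = F.huber_loss(q, returns, reduction="mean")
    huber.backward()
    print(q.grad)  # tensor([1.])  -- magnitude clipped to delta = 1.0

    # Plain squared error: the gradient is 2 * td_error and grows without bound.
    q.grad = None
    mse = (q - returns).pow(2).mean()
    mse.backward()
    print(q.grad)  # tensor([20.])

With the patch applied, this behaviour would be opted into by passing clip_loss_grad=True when constructing DQNPolicy; as the diff shows, it defaults to False, which keeps the original weighted MSE loss.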