Added support for clipping to DQNPolicy (#642)
* When clip_loss_grad=True is passed, the Huber loss is used instead of the MSE loss.
* Made the argument's name more descriptive.
* Replaced the smooth L1 loss with the Huber loss, which has an identical form under the default parametrization but seems to be better known in this context.
* Added a fuller description to the docstring.
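A minimal usage sketch for the first point above. The network/optimizer setup is assumed and not part of this commit; only the clip_loss_grad argument is new here, and the shapes and hyperparameters are placeholders:

    import torch
    from tianshou.policy import DQNPolicy
    from tianshou.utils.net.common import Net

    # Hypothetical network and optimizer for a small discrete-action task.
    net = Net(state_shape=4, action_shape=2, hidden_sizes=[64, 64])
    optim = torch.optim.Adam(net.parameters(), lr=1e-3)

    policy = DQNPolicy(
        net,
        optim,
        discount_factor=0.99,
        target_update_freq=320,
        clip_loss_grad=True,  # use the Huber loss instead of the MSE loss
    )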
parent c87b9f49bc
commit 277138ca5b
@@ -26,6 +26,9 @@ class DQNPolicy(BasePolicy):
     :param bool reward_normalization: normalize the reward to Normal(0, 1).
         Default to False.
     :param bool is_double: use double dqn. Default to True.
+    :param bool clip_loss_grad: clip the gradient of the loss in accordance
+        with nature14236; this amounts to using the Huber loss instead of
+        the MSE loss. Default to False.
     :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
         optimizer in each policy.update(). Default to None (no lr_scheduler).

@@ -44,6 +47,7 @@ class DQNPolicy(BasePolicy):
         target_update_freq: int = 0,
         reward_normalization: bool = False,
         is_double: bool = True,
+        clip_loss_grad: bool = False,
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
@@ -62,6 +66,7 @@ class DQNPolicy(BasePolicy):
         self.model_old.eval()
         self._rew_norm = reward_normalization
         self._is_double = is_double
+        self._clip_loss_grad = clip_loss_grad

     def set_eps(self, eps: float) -> None:
         """Set the eps for epsilon-greedy exploration."""
@@ -168,7 +173,14 @@ class DQNPolicy(BasePolicy):
         q = q[np.arange(len(q)), batch.act]
         returns = to_torch_as(batch.returns.flatten(), q)
         td_error = returns - q
-        loss = (td_error.pow(2) * weight).mean()
+
+        if self._clip_loss_grad:
+            y = q.reshape(-1, 1)
+            t = returns.reshape(-1, 1)
+            loss = torch.nn.functional.huber_loss(y, t, reduction="mean")
+        else:
+            loss = (td_error.pow(2) * weight).mean()
+
         batch.weight = td_error  # prio-buffer
         loss.backward()
         self.optim.step()
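A short sketch, not part of the diff, illustrating why the Huber loss with its default delta=1 corresponds to clipping the gradient of the squared TD error as described in nature14236. The tensor values below are made up purely for illustration:

    import torch
    import torch.nn.functional as F

    # Made-up Q-value estimates and targets.
    q = torch.tensor([0.3, 2.5, -1.7], requires_grad=True)
    returns = torch.zeros(3)

    # Huber loss with the default delta=1.0; "sum" keeps the gradient per-element.
    loss = F.huber_loss(q, returns, reduction="sum")
    loss.backward()

    # The gradient of the Huber loss is the TD error clipped to [-1, 1] ...
    print(q.grad)                                    # tensor([ 0.3000,  1.0000, -1.0000])
    # ... i.e. what clipping the gradient of 0.5 * (q - returns)^2 would give.
    print((q - returns).detach().clamp(-1.0, 1.0))   # tensor([ 0.3000,  1.0000, -1.0000])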