diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index e66a5f8..824e19a 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -116,9 +116,9 @@ class PPOPolicy(A2CPolicy): surr1 = ratio * b.adv surr2 = ratio.clamp(1.0 - self._eps_clip, 1.0 + self._eps_clip) * b.adv if self._dual_clip: - clip_loss = -torch.max( - torch.min(surr1, surr2), self._dual_clip * b.adv - ).mean() + clip1 = torch.min(surr1, surr2) + clip2 = torch.max(clip1, self._dual_clip * b.adv) + clip_loss = -torch.where(b.adv < 0, clip2, clip1).mean() else: clip_loss = -torch.min(surr1, surr2).mean() # calculate loss for critic