parent
8a5e2190f7
commit
a740496a51
@@ -116,9 +116,9 @@ class PPOPolicy(A2CPolicy):
 surr1 = ratio * b.adv
 surr2 = ratio.clamp(1.0 - self._eps_clip, 1.0 + self._eps_clip) * b.adv
 if self._dual_clip:
-    clip_loss = -torch.max(
-        torch.min(surr1, surr2), self._dual_clip * b.adv
-    ).mean()
+    clip1 = torch.min(surr1, surr2)
+    clip2 = torch.max(clip1, self._dual_clip * b.adv)
+    clip_loss = -torch.where(b.adv < 0, clip2, clip1).mean()
 else:
     clip_loss = -torch.min(surr1, surr2).mean()
 # calculate loss for critic
|
||||
|
Loading…
x
Reference in New Issue
Block a user