parent
8a5e2190f7
commit
a740496a51
@@ -116,9 +116,9 @@ class PPOPolicy(A2CPolicy):
|
|||||||
surr1 = ratio * b.adv
|
surr1 = ratio * b.adv
|
||||||
surr2 = ratio.clamp(1.0 - self._eps_clip, 1.0 + self._eps_clip) * b.adv
|
surr2 = ratio.clamp(1.0 - self._eps_clip, 1.0 + self._eps_clip) * b.adv
|
||||||
if self._dual_clip:
|
if self._dual_clip:
|
||||||
clip_loss = -torch.max(
|
clip1 = torch.min(surr1, surr2)
|
||||||
torch.min(surr1, surr2), self._dual_clip * b.adv
|
clip2 = torch.max(clip1, self._dual_clip * b.adv)
|
||||||
).mean()
|
clip_loss = -torch.where(b.adv < 0, clip2, clip1).mean()
|
||||||
else:
|
else:
|
||||||
clip_loss = -torch.min(surr1, surr2).mean()
|
clip_loss = -torch.min(surr1, surr2).mean()
|
||||||
# calculate loss for critic
|
# calculate loss for critic
|
||||||
|
Loading…
x
Reference in New Issue
Block a user