2020-05-18 16:23:35 +08:00
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
|
|
class DiagGaussian(torch.distributions.Normal):
    """Diagonal Gaussian distribution.

    A ``torch.distributions.Normal`` whose per-element results are reduced
    over the last dimension, so each leading (batch) index gets one joint
    value. Sampling, ``mean``, ``stddev``, ``batch_shape`` and
    ``event_shape`` are inherited from ``Normal`` unchanged.

    NOTE(review): presumably the last dimension indexes independent action
    components (hence "diagonal") — confirm against callers.
    """

    def log_prob(self, actions):
        """Return the joint log-density of ``actions``.

        Elementwise log-probabilities from ``Normal`` are summed over the last
        dimension. The reduced axis is retained with size 1, so the result has
        shape ``(..., 1)``.
        """
        per_component = super().log_prob(actions)
        return torch.sum(per_component, dim=-1, keepdim=True)

    def entropy(self):
        """Return the joint entropy.

        Elementwise entropies from ``Normal`` are summed over the last
        dimension. Unlike ``log_prob``, the reduced axis is dropped, so the
        result has shape ``(...)``.
        """
        per_component = super().entropy()
        return torch.sum(per_component, dim=-1)
|