2020-05-27 11:02:23 +08:00

12 lines
267 B
Python

import torch
class DiagGaussian(torch.distributions.Normal):
    """Normal distribution with independent (diagonal) dimensions.

    Per-dimension quantities from ``torch.distributions.Normal`` are summed
    over the last axis, so that axis is treated as the event dimension.
    """

    def log_prob(self, actions: torch.Tensor) -> torch.Tensor:
        """Return the joint log-density of ``actions``.

        The per-dimension log-densities are summed over the last axis.
        That axis is kept with size 1, so the result's shape is the input's
        leading shape followed by ``1``.
        """
        per_dim = super().log_prob(actions)
        # keepdim=True keeps a trailing singleton axis (entropy() below does not).
        return torch.sum(per_dim, dim=-1, keepdim=True)

    def entropy(self) -> torch.Tensor:
        """Return the joint entropy, summed over the last axis (which is dropped)."""
        per_dim = super().entropy()
        return torch.sum(per_dim, dim=-1)