12 lines
267 B
Python
Raw Normal View History

import torch
class DiagGaussian(torch.distributions.Normal):
    """Diagonal Gaussian distribution built on ``torch.distributions.Normal``.

    ``Normal`` scores every element independently. This subclass reduces the
    element-wise log-density and entropy over the trailing dimension, so that
    dimension is treated as one jointly-distributed vector.

    NOTE(review): ``batch_shape`` / ``event_shape`` are inherited unchanged
    from ``Normal`` (``event_shape == ()``), so the trailing dimension is still
    reported as a batch dimension. Anything dispatched on the ``Normal`` type
    (e.g. ``torch.distributions.kl_divergence``) is therefore not summed over
    it. Confirm callers do not rely on those for the joint quantity.
    """

    def log_prob(self, actions):
        """Return the joint log-density of ``actions``.

        The element-wise log-densities from ``Normal.log_prob`` are summed over
        the last dimension. That dimension is kept with size 1
        (``keepdim=True``).
        """
        elementwise = super().log_prob(actions)
        return torch.sum(elementwise, dim=-1, keepdim=True)

    def entropy(self):
        """Return the joint entropy, summed over the last dimension.

        Unlike ``log_prob``, the reduced dimension is dropped, not kept.
        """
        per_dim = super().entropy()
        return torch.sum(per_dim, dim=-1)