14 lines
272 B
Python
Raw Normal View History

import torch
class DiagGaussian(torch.distributions.Normal):
    """Gaussian with diagonal covariance over the last tensor axis.

    ``torch.distributions.Normal`` scores every element as an independent
    univariate normal. This subclass reduces those per-element results over
    the last axis, so each vector along that axis is treated as one sample of
    a multivariate Gaussian with diagonal covariance.

    The two reductions differ in output shape. ``log_prob`` keeps the reduced
    axis as size 1, and ``entropy`` removes it.
    """

    def log_prob(self, actions):
        """Return the joint log-density of ``actions``.

        Args:
            actions: tensor expected to broadcast against this distribution's
                ``loc``/``scale``.

        Returns:
            The per-element log-densities from ``Normal.log_prob``, summed
            over the last axis. That axis is kept with size 1
            (``keepdim=True``).
        """
        per_dim = super().log_prob(actions)
        return torch.sum(per_dim, dim=-1, keepdim=True)

    def entropy(self):
        """Return the joint entropy, summed over the last axis.

        Unlike ``log_prob``, the reduced axis is dropped here (no
        ``keepdim``), so the result has one fewer dimension than the
        per-element entropy returned by ``Normal.entropy``.
        """
        per_dim = super().entropy()
        return torch.sum(per_dim, dim=-1)