14 lines
272 B
Python
14 lines
272 B
Python
import torch
|
|
|
|
|
|
class DiagGaussian(torch.distributions.Normal):
    """Diagonal Gaussian distribution with reductions over the last dimension.

    A thin subclass of ``torch.distributions.Normal``. The element-wise
    ``log_prob`` and ``entropy`` of the base class are summed over the last
    dimension. That dimension is therefore treated as the (factorised)
    event dimension, e.g. the components of an action vector.

    Note the default output shapes differ between the two methods:
    ``log_prob`` keeps the reduced dimension (shape ``(..., 1)``) and
    ``entropy`` drops it (shape ``(...)``). Pass ``keepdim=`` explicitly
    when matching shapes are needed.
    """

    def log_prob(self, actions, *, keepdim=True):
        """Return the log-density of ``actions`` summed over the last dimension.

        Args:
            actions: Values to score; passed unchanged to
                ``Normal.log_prob``.
            keepdim: If True (the default), keep the reduced last dimension
                with size 1. If False, drop it.

        Returns:
            Tensor of summed log-probabilities.
        """
        return super().log_prob(actions).sum(-1, keepdim=keepdim)

    def entropy(self, *, keepdim=False):
        """Return the entropy summed over the last dimension.

        The total is exact for a diagonal Gaussian, because its components
        are independent and entropies of independent components add.

        Args:
            keepdim: If True, keep the reduced last dimension with size 1.
                If False (the default), drop it.

        Returns:
            Tensor of summed entropies.
        """
        return super().entropy().sum(-1, keepdim=keepdim)