import math
from numbers import Number

import torch
import torch.nn.functional as F
from torch import distributions as pyd
from torch.distributions import Distribution, constraints
from torch.distributions.independent import Independent
from torch.distributions.utils import broadcast_all


CONST_SQRT_2 = math.sqrt(2)
CONST_INV_SQRT_2PI = 1 / math.sqrt(2 * math.pi)
CONST_INV_SQRT_2 = 1 / math.sqrt(2)
CONST_LOG_INV_SQRT_2PI = math.log(CONST_INV_SQRT_2PI)
CONST_LOG_SQRT_2PI_E = 0.5 * math.log(2 * math.pi * math.e)


class TruncatedStandardNormal(Distribution):
    """
    Truncated standard normal distribution.

    Reference: https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    """

    arg_constraints = {
        'a': constraints.real,
        'b': constraints.real,
    }
    has_rsample = True

    def __init__(self, a=-1.0, b=1.0, validate_args=None):
        self.a, self.b = broadcast_all(a, b)
        if isinstance(a, Number) and isinstance(b, Number):
            batch_shape = torch.Size()
        else:
            batch_shape = self.a.size()
        super(TruncatedStandardNormal, self).__init__(batch_shape, validate_args=validate_args)
        if self.a.dtype != self.b.dtype:
            raise ValueError('Truncation bounds types are different')
        if (self.a >= self.b).any():
            raise ValueError('Incorrect truncation range')
        eps = torch.finfo(self.a.dtype).eps
        self._dtype_min_gt_0 = eps
        self._dtype_max_lt_1 = 1 - eps
        # Precompute the quantities shared by cdf/icdf/log_prob:
        # phi is the standard normal pdf, Phi its cdf, and
        # Z = Phi(b) - Phi(a) is the mass inside [a, b].
        self._little_phi_a = self._little_phi(self.a)
        self._little_phi_b = self._little_phi(self.b)
        self._big_phi_a = self._big_phi(self.a)
        self._big_phi_b = self._big_phi(self.b)
        self._Z = (self._big_phi_b - self._big_phi_a).clamp_min(eps)
        self._log_Z = self._Z.log()
        # Map infinite bounds to large finite values so that
        # phi(bound) * bound evaluates to 0 instead of 0 * inf = nan.
        little_phi_coeff_a = torch.nan_to_num(self.a, nan=math.nan)
        little_phi_coeff_b = torch.nan_to_num(self.b, nan=math.nan)
        self._lpbb_m_lpaa_d_Z = (self._little_phi_b * little_phi_coeff_b
                                 - self._little_phi_a * little_phi_coeff_a) / self._Z
        self._mean = -(self._little_phi_b - self._little_phi_a) / self._Z
        self._variance = (1 - self._lpbb_m_lpaa_d_Z
                          - ((self._little_phi_b - self._little_phi_a) / self._Z) ** 2)
        self._entropy = CONST_LOG_SQRT_2PI_E + self._log_Z - 0.5 * self._lpbb_m_lpaa_d_Z

    @constraints.dependent_property
    def support(self):
        return constraints.interval(self.a, self.b)

    @property
    def mean(self):
        return self._mean

    @property
    def variance(self):
        return self._variance

    # Note: unlike torch.distributions, where entropy() is a method,
    # this class exposes the precomputed entropy as a property.
    @property
    def entropy(self):
        return self._entropy

    @property
    def auc(self):
        # Probability mass of the untruncated normal inside [a, b].
        return self._Z

    @staticmethod
    def _little_phi(x):
        # Standard normal pdf phi(x).
        return (-(x ** 2) * 0.5).exp() * CONST_INV_SQRT_2PI

    @staticmethod
    def _big_phi(x):
        # Standard normal cdf Phi(x).
        return 0.5 * (1 + (x * CONST_INV_SQRT_2).erf())

    @staticmethod
    def _inv_big_phi(x):
        # Standard normal inverse cdf (probit function).
        return CONST_SQRT_2 * (2 * x - 1).erfinv()

    def cdf(self, value):
        if self._validate_args:
            self._validate_sample(value)
        return ((self._big_phi(value) - self._big_phi_a) / self._Z).clamp(0, 1)

    def icdf(self, value):
        return self._inv_big_phi(self._big_phi_a + value * self._Z)

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        return CONST_LOG_INV_SQRT_2PI - self._log_Z - (value ** 2) * 0.5

    def rsample(self, sample_shape=torch.Size()):
        # Inverse-transform sampling: map u ~ U(eps, 1 - eps) through the
        # truncated icdf; the epsilons keep the probit away from 0 and 1.
        shape = self._extended_shape(sample_shape)
        p = torch.empty(shape, device=self.a.device).uniform_(self._dtype_min_gt_0, self._dtype_max_lt_1)
        return self.icdf(p)
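

# Minimal usage sketch (illustrative; function name, bounds, and sample count
# are assumptions, not part of the original module): draw reparameterized
# samples and compare the empirical mean against the analytic one.
def _demo_truncated_standard_normal():
    d = TruncatedStandardNormal(a=-0.5, b=2.0)
    x = d.rsample((100_000,))
    assert ((x >= -0.5) & (x <= 2.0)).all()
    print('empirical mean:', x.mean().item(), '| analytic mean:', d.mean.item())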


class TruncatedNormal(TruncatedStandardNormal):
    """
    Truncated normal distribution with location `loc` and scale `scale`.

    Reference: https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    """

    has_rsample = True

    def __init__(self, loc, scale, a=-1.0, b=1.0, validate_args=None):
        self.loc, self.scale, a, b = broadcast_all(loc, scale, a, b)
        # Standardize the bounds, then rescale the standardized moments back.
        a = (a - self.loc) / self.scale
        b = (b - self.loc) / self.scale
        super(TruncatedNormal, self).__init__(a, b, validate_args=validate_args)
        self._log_scale = self.scale.log()
        self._mean = self._mean * self.scale + self.loc
        self._variance = self._variance * self.scale ** 2
        self._entropy += self._log_scale

    def _to_std_rv(self, value):
        return (value - self.loc) / self.scale

    def _from_std_rv(self, value):
        return value * self.scale + self.loc

    def cdf(self, value):
        return super(TruncatedNormal, self).cdf(self._to_std_rv(value))

    def icdf(self, value):
        return self._from_std_rv(super(TruncatedNormal, self).icdf(value))

    def log_prob(self, value):
        # Change of variables: subtract log(scale) from the standardized density.
        return super(TruncatedNormal, self).log_prob(self._to_std_rv(value)) - self._log_scale
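

# Minimal usage sketch (illustrative; the parameter values are arbitrary
# assumptions): a batched truncated normal over three independent dimensions.
def _demo_truncated_normal():
    d = TruncatedNormal(loc=torch.zeros(3), scale=torch.ones(3), a=-1.0, b=1.0)
    x = d.rsample((1000,))   # shape (1000, 3), values lie in [-1, 1]
    lp = d.log_prob(x)       # per-dimension log-density, shape (1000, 3)
    print(x.shape, lp.shape, d.mean)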


class TanhTransform(pyd.transforms.Transform):
    domain = pyd.constraints.real
    codomain = pyd.constraints.interval(-1.0, 1.0)
    bijective = True
    sign = +1

    def __init__(self, cache_size=1):
        super().__init__(cache_size=cache_size)

    @staticmethod
    def atanh(x):
        return 0.5 * (x.log1p() - (-x).log1p())

    def __eq__(self, other):
        return isinstance(other, TanhTransform)

    def _call(self, x):
        return x.tanh()

    def _inverse(self, y):
        # We do not clamp y to the open interval (-1, 1) here, since clamping
        # can degrade the performance of certain algorithms; with
        # `cache_size=1` the inverse returns the cached pre-tanh value instead.
        return self.atanh(y)

    def log_abs_det_jacobian(self, x, y):
        # Numerically stable formula for log(1 - tanh(x)^2); see
        # https://github.com/tensorflow/probability/commit/ef6bb176e0ebd1cf6e25c6b5cecdd2428c22963f#diff-e120f70e92e6741bca649f04fcd907b7
        return 2. * (math.log(2.) - x - F.softplus(-2. * x))
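

# Minimal sanity-check sketch (illustrative; function name and test points are
# assumptions): the stable log-det formula should agree with the naive
# log(1 - tanh(x)^2) on well-behaved inputs.
def _demo_tanh_transform():
    t = TanhTransform()
    x = torch.linspace(-2.0, 2.0, 5)
    y = t(x)
    assert torch.allclose(t.inv(y), x, atol=1e-6)
    naive = torch.log(1 - y.pow(2))
    assert torch.allclose(t.log_abs_det_jacobian(x, y), naive, atol=1e-6)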


class SquashedNormal(pyd.transformed_distribution.TransformedDistribution):
    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale

        self.base_dist = pyd.Normal(loc, scale)
        # self.base_dist = Independent(pyd.Normal(loc, scale), 1)
        transforms = [TanhTransform()]
        super().__init__(self.base_dist, transforms)

    @property
    def mean(self):
        # The image of the base mean under the transforms, i.e. tanh(loc);
        # note this is not the true mean of the squashed distribution.
        mu = self.loc
        for tr in self.transforms:
            mu = tr(mu)
        return mu
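

# Minimal usage sketch (illustrative; parameter values are assumptions): a
# tanh-squashed Gaussian as commonly used for bounded continuous actions.
def _demo_squashed_normal():
    d = SquashedNormal(torch.zeros(2), 0.5 * torch.ones(2))
    a = d.rsample()      # reparameterized sample in (-1, 1)
    lp = d.log_prob(a)   # includes the tanh log-det correction
    print(a, lp, d.mean)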


class ContDist:
    """Thin wrapper exposing mean/mode/entropy/sample/log_prob on a torch distribution."""

    def __init__(self, dist=None):
        super().__init__()
        self._dist = dist
        self.mean = dist.mean

    def __getattr__(self, name):
        # Delegate any attribute not found on the wrapper to the wrapped distribution.
        return getattr(self._dist, name)

    def entropy(self):
        return self._dist.entropy()

    def mode(self):
        # Uses the mean as the mode; exact only for symmetric, unimodal distributions.
        return self._dist.mean

    def sample(self, sample_shape=()):
        # Note: intentionally returns a reparameterized sample so gradients can flow.
        return self._dist.rsample(sample_shape)

    def log_prob(self, x):
        return self._dist.log_prob(x)
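

# Minimal usage sketch (illustrative; the Independent wrapping and names are
# assumptions): wrap a squashed Gaussian so downstream code can call
# mode()/sample()/log_prob() uniformly.
def _demo_cont_dist():
    base = Independent(SquashedNormal(torch.zeros(3), torch.ones(3)), 1)
    policy = ContDist(base)
    a = policy.sample()   # internally rsample(), so differentiable
    print(a.shape, policy.log_prob(a), policy.mode())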