import math
from numbers import Number

import torch
import torch.nn.functional as F
from torch import distributions as pyd
from torch.distributions import Distribution, constraints
from torch.distributions.independent import Independent
from torch.distributions.utils import broadcast_all

CONST_SQRT_2 = math.sqrt(2)
CONST_INV_SQRT_2PI = 1 / math.sqrt(2 * math.pi)
CONST_INV_SQRT_2 = 1 / math.sqrt(2)
CONST_LOG_INV_SQRT_2PI = math.log(CONST_INV_SQRT_2PI)
CONST_LOG_SQRT_2PI_E = 0.5 * math.log(2 * math.pi * math.e)


class TruncatedStandardNormal(Distribution):
    """
    Truncated standard normal distribution.

    Reference:
    https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    """

    arg_constraints = {
        'a': constraints.real,
        'b': constraints.real,
    }
    has_rsample = True

    def __init__(self, a=-1.0, b=1.0, validate_args=None):
        self.a, self.b = broadcast_all(a, b)
        if isinstance(a, Number) and isinstance(b, Number):
            batch_shape = torch.Size()
        else:
            batch_shape = self.a.size()
        super(TruncatedStandardNormal, self).__init__(batch_shape, validate_args=validate_args)
        if self.a.dtype != self.b.dtype:
            raise ValueError('Truncation bounds types are different')
        if (self.a >= self.b).any():
            raise ValueError('Incorrect truncation range')
        eps = torch.finfo(self.a.dtype).eps
        self._dtype_min_gt_0 = eps
        self._dtype_max_lt_1 = 1 - eps
        self._little_phi_a = self._little_phi(self.a)
        self._little_phi_b = self._little_phi(self.b)
        self._big_phi_a = self._big_phi(self.a)
        self._big_phi_b = self._big_phi(self.b)
        # Normalization constant Z = Phi(b) - Phi(a), clamped away from zero
        # so that log_prob and icdf stay finite for very narrow intervals.
        self._Z = (self._big_phi_b - self._big_phi_a).clamp_min(eps)
        self._log_Z = self._Z.log()
        # nan_to_num maps infinite bounds to large finite values, so terms of
        # the form b * phi(b) below evaluate to 0 rather than inf * 0 = nan.
        little_phi_coeff_a = torch.nan_to_num(self.a, nan=math.nan)
        little_phi_coeff_b = torch.nan_to_num(self.b, nan=math.nan)
        self._lpbb_m_lpaa_d_Z = (self._little_phi_b * little_phi_coeff_b
                                 - self._little_phi_a * little_phi_coeff_a) / self._Z
        self._mean = -(self._little_phi_b - self._little_phi_a) / self._Z
        self._variance = (1 - self._lpbb_m_lpaa_d_Z
                          - ((self._little_phi_b - self._little_phi_a) / self._Z) ** 2)
        self._entropy = CONST_LOG_SQRT_2PI_E + self._log_Z - 0.5 * self._lpbb_m_lpaa_d_Z

    @constraints.dependent_property
    def support(self):
        return constraints.interval(self.a, self.b)

    @property
    def mean(self):
        return self._mean

    @property
    def variance(self):
        return self._variance

    # Note: unlike torch.distributions, entropy is exposed as a property
    # here, so use `dist.entropy`, not `dist.entropy()`.
    @property
    def entropy(self):
        return self._entropy

    @property
    def auc(self):
        return self._Z

    @staticmethod
    def _little_phi(x):
        # Standard normal pdf phi(x).
        return (-(x ** 2) * 0.5).exp() * CONST_INV_SQRT_2PI

    @staticmethod
    def _big_phi(x):
        # Standard normal cdf Phi(x).
        return 0.5 * (1 + (x * CONST_INV_SQRT_2).erf())

    @staticmethod
    def _inv_big_phi(x):
        # Inverse standard normal cdf Phi^{-1}(x).
        return CONST_SQRT_2 * (2 * x - 1).erfinv()

    def cdf(self, value):
        if self._validate_args:
            self._validate_sample(value)
        return ((self._big_phi(value) - self._big_phi_a) / self._Z).clamp(0, 1)

    def icdf(self, value):
        return self._inv_big_phi(self._big_phi_a + value * self._Z)

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        return CONST_LOG_INV_SQRT_2PI - self._log_Z - (value ** 2) * 0.5

    def rsample(self, sample_shape=torch.Size()):
        # Inverse-cdf sampling: u ~ U(eps, 1 - eps), x = icdf(u). The open
        # interval keeps icdf away from the infinities at 0 and 1.
        shape = self._extended_shape(sample_shape)
        p = torch.empty(shape, dtype=self.a.dtype, device=self.a.device).uniform_(
            self._dtype_min_gt_0, self._dtype_max_lt_1)
        return self.icdf(p)
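

# A minimal sanity-check sketch (illustrative, not part of the original
# module): draw reparameterized samples from TruncatedStandardNormal and
# compare the empirical moments against the closed-form `mean` / `variance`
# above. The helper name, sample count, and tolerances are arbitrary
# assumptions.
def _check_truncated_standard_normal():
    torch.manual_seed(0)
    dist = TruncatedStandardNormal(a=-1.0, b=2.0)
    samples = dist.rsample(torch.Size([200000]))
    # Every sample must respect the truncation interval [a, b].
    assert samples.min() >= dist.a
    assert samples.max() <= dist.b
    # Empirical moments should agree with the analytic ones.
    assert torch.isclose(samples.mean(), dist.mean, atol=1e-2)
    assert torch.isclose(samples.var(), dist.variance, atol=1e-2)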


class TruncatedNormal(TruncatedStandardNormal):
    """
    Truncated normal distribution with location `loc` and scale `scale`.

    Reference:
    https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    """

    has_rsample = True

    def __init__(self, loc, scale, a=-1.0, b=1.0, validate_args=None):
        self.loc, self.scale, a, b = broadcast_all(loc, scale, a, b)
        # Standardize the truncation bounds and delegate to the parent class.
        a = (a - self.loc) / self.scale
        b = (b - self.loc) / self.scale
        super(TruncatedNormal, self).__init__(a, b, validate_args=validate_args)
        self._log_scale = self.scale.log()
        # Shift and rescale the standardized moments back to (loc, scale).
        self._mean = self._mean * self.scale + self.loc
        self._variance = self._variance * self.scale ** 2
        self._entropy += self._log_scale

    def _to_std_rv(self, value):
        return (value - self.loc) / self.scale

    def _from_std_rv(self, value):
        return value * self.scale + self.loc

    def cdf(self, value):
        return super(TruncatedNormal, self).cdf(self._to_std_rv(value))

    def icdf(self, value):
        return self._from_std_rv(super(TruncatedNormal, self).icdf(value))

    def log_prob(self, value):
        # Change of variables: subtract log |d value / d std_value| = log scale.
        return super(TruncatedNormal, self).log_prob(self._to_std_rv(value)) - self._log_scale


class TanhTransform(pyd.transforms.Transform):
    domain = pyd.constraints.real
    codomain = pyd.constraints.interval(-1.0, 1.0)
    bijective = True
    sign = +1

    def __init__(self, cache_size=1):
        super().__init__(cache_size=cache_size)

    @staticmethod
    def atanh(x):
        return 0.5 * (x.log1p() - (-x).log1p())

    def __eq__(self, other):
        return isinstance(other, TanhTransform)

    def _call(self, x):
        return x.tanh()

    def _inverse(self, y):
        # We do not clamp to the boundary here as it may degrade the
        # performance of certain algorithms; one should use `cache_size=1`
        # instead.
        return self.atanh(y)

    def log_abs_det_jacobian(self, x, y):
        # We use a formula that is more numerically stable; see details in
        # the following link:
        # https://github.com/tensorflow/probability/commit/ef6bb176e0ebd1cf6e25c6b5cecdd2428c22963f#diff-e120f70e92e6741bca649f04fcd907b7
        return 2. * (math.log(2.) - x - F.softplus(-2. * x))


class SquashedNormal(pyd.transformed_distribution.TransformedDistribution):
    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale
        self.base_dist = pyd.Normal(loc, scale)
        # self.base_dist = Independent(pyd.Normal(loc, scale), 1)
        transforms = [TanhTransform()]
        super().__init__(self.base_dist, transforms)

    @property
    def mean(self):
        # Push the base distribution's mean through the transforms, i.e.
        # tanh(loc). This is not the true mean of the squashed distribution,
        # but it is the usual deterministic output.
        mu = self.loc
        for tr in self.transforms:
            mu = tr(mu)
        return mu


class ContDist:
    """Thin wrapper exposing a uniform interface over continuous distributions."""

    def __init__(self, dist=None):
        super().__init__()
        self._dist = dist
        self.mean = dist.mean

    def __getattr__(self, name):
        # Fall through to the wrapped distribution for anything not defined here.
        return getattr(self._dist, name)

    def entropy(self):
        return self._dist.entropy()

    def mode(self):
        # For the distributions used here, the mean stands in for the mode.
        return self._dist.mean

    def sample(self, sample_shape=()):
        # Reparameterized sampling, so gradients flow through the sample.
        return self._dist.rsample(sample_shape)

    def log_prob(self, x):
        return self._dist.log_prob(x)
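

# A minimal usage sketch (an illustrative assumption, not part of the
# original module): typical plumbing where a tanh-squashed Gaussian policy
# head is wrapped in ContDist. `mu` and `std` stand in for network outputs.
if __name__ == '__main__':
    mu = torch.zeros(3)
    std = 0.5 * torch.ones(3)
    # Tanh-squashed Gaussian: differentiable samples in (-1, 1).
    action_dist = ContDist(SquashedNormal(mu, std))
    action = action_dist.sample()            # rsample under the hood
    log_prob = action_dist.log_prob(action)  # elementwise log-density
    print(action, log_prob)
    # Truncated normal: hard bounds [a, b] instead of a tanh squash.
    tn = TruncatedNormal(loc=mu, scale=std, a=-1.0, b=1.0)
    print(tn.rsample(), tn.mean, tn.variance)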