96 lines
3.9 KiB
Python
96 lines
3.9 KiB
Python
|
import numpy as np
|
||
|
import torch
|
||
|
from scipy import integrate
|
||
|
|
||
|
|
||
|
def prior_likelihood(z, sigma):
    """Return the log-likelihood of `z` under a zero-mean isotropic Gaussian.

    Every non-batch coordinate of `z` is treated as an independent
    N(0, sigma**2) variable. For a sample with N = prod(z.shape[1:])
    coordinates the result is

        -N / 2 * log(2 * pi * sigma**2) - ||z_i||**2 / (2 * sigma**2).

    Args:
        z: Tensor whose leading dimension indexes the batch. Any number of
            trailing dimensions (including none) is accepted.
        sigma: Standard deviation of the Gaussian, either a tensor or a plain
            Python/NumPy number.

    Returns:
        Tensor of shape ``(z.shape[0],)`` holding the per-sample log-density.
    """
    if not torch.is_tensor(sigma):
        # torch.log rejects plain Python floats; promote scalars to a tensor.
        sigma = torch.tensor(sigma, dtype=torch.float64, device=z.device)
    batch = z.shape[0]
    # np.prod(()) == 1.0, so a batch-only tensor counts as one coordinate.
    N = int(np.prod(z.shape[1:]))
    # Flatten to (batch, N) instead of summing over a fixed dim=(1, 2, 3).
    # This lets inputs of any rank work, and an empty `dim` can never reduce
    # over the batch axis.
    sq_norm = torch.sum(z.reshape(batch, N) ** 2, dim=1)
    return -N / 2. * torch.log(2 * np.pi * sigma ** 2) - sq_norm / (2 * sigma ** 2)
|
||
|
|
||
|
|
||
|
def ode_likelihood(x,
                   score_model,
                   marginal_prob_std,
                   diffusion_coeff,
                   batch_size=64,
                   device='cuda',
                   eps=1e-5):
    """Compute the likelihood of `x` with the probability flow ODE.

    The probability flow ODE is integrated from t=eps to t=1 with SciPy's
    RK45 solver. The change in log-density is integrated jointly with it:
    the divergence of the score, estimated with a single Skilling-Hutchinson
    Gaussian probe that is shared by every ODE evaluation.

    Args:
        x: Input data. The leading dimension indexes the batch.
        score_model: A PyTorch model representing the score-based model,
            called as ``score_model(sample, time_steps)``.
        marginal_prob_std: A function that gives the standard deviation of the
            perturbation kernel.
        diffusion_coeff: A function that gives the diffusion coefficient of the
            forward SDE.
        batch_size: Unused; kept for backward compatibility. The batch size is
            always taken from the leading dimension of `x`.
        device: 'cuda' for evaluation on GPUs, and 'cpu' for evaluation on CPUs.
        eps: A `float` number. The smallest time step for numerical stability.

    Returns:
        z: The latent code for `x`, with shape ``x.shape``, as a float64
            tensor on `device`.
        bpd: The per-sample negative log-likelihoods in bits/dim. The final
            ``+ 8.`` presumably accounts for 8-bit data (256 levels,
            log2(256) = 8) rescaled to [0, 1]. Confirm this against the data
            pipeline.

    Raises:
        RuntimeError: If the ODE solver does not successfully reach t=1.
    """
    shape = x.shape
    n_batch = shape[0]
    n_dims = int(np.prod(shape[1:]))

    # Draw the random Gaussian probe for Skilling-Hutchinson's estimator once.
    # It is moved to `device` so it matches the samples the wrappers build.
    epsilon = torch.randn_like(x).to(device)

    def divergence_eval(sample, time_steps, epsilon):
        """Estimate div(score) per sample with Skilling-Hutchinson: eps^T (d s / d x) eps."""
        with torch.enable_grad():
            sample.requires_grad_(True)
            score_e = torch.sum(score_model(sample, time_steps) * epsilon)
            grad_score_e = torch.autograd.grad(score_e, sample)[0]
        # Sum over every non-batch coordinate, whatever the rank of `x`.
        return torch.sum((grad_score_e * epsilon).reshape(n_batch, n_dims), dim=1)

    def to_torch(sample, time_steps):
        """Convert flat NumPy solver state and times into float32 tensors on `device`."""
        sample = torch.tensor(sample, device=device, dtype=torch.float32).reshape(shape)
        time_steps = torch.tensor(time_steps, device=device, dtype=torch.float32).reshape((n_batch,))
        return sample, time_steps

    def score_eval_wrapper(sample, time_steps):
        """Evaluate the score model for the black-box ODE solver (flat float64 output)."""
        sample, time_steps = to_torch(sample, time_steps)
        with torch.no_grad():
            score = score_model(sample, time_steps)
        return score.cpu().numpy().reshape((-1,)).astype(np.float64)

    def divergence_eval_wrapper(sample, time_steps):
        """Evaluate the score divergence for the black-box ODE solver (flat float64 output)."""
        sample, time_steps = to_torch(sample, time_steps)
        with torch.no_grad():
            # divergence_eval re-enables grad internally for the autograd call.
            div = divergence_eval(sample, time_steps, epsilon)
            return div.cpu().numpy().reshape((-1,)).astype(np.float64)

    def ode_func(t, state):
        """Right-hand side of the joint [sample, log p] ODE for solve_ivp."""
        time_steps = np.ones((n_batch,)) * t
        # The state is the flattened sample followed by one log-p entry per item.
        sample = state[:-n_batch]
        g = diffusion_coeff(torch.tensor(t)).cpu().numpy()
        sample_grad = -0.5 * g ** 2 * score_eval_wrapper(sample, time_steps)
        logp_grad = -0.5 * g ** 2 * divergence_eval_wrapper(sample, time_steps)
        return np.concatenate([sample_grad, logp_grad], axis=0)

    # Detach so inputs that require grad can still be converted to NumPy.
    init = np.concatenate([x.detach().cpu().numpy().reshape((-1,)), np.zeros((n_batch,))], axis=0)
    # Black-box ODE solver
    res = integrate.solve_ivp(ode_func, (eps, 1.), init, rtol=1e-5, atol=1e-5, method='RK45')
    if not res.success:
        # Otherwise res.y[:, -1] would be an intermediate state short of t=1,
        # and the likelihood computed from it would be silently wrong.
        raise RuntimeError(f"Probability flow ODE solver failed: {res.message}")
    zp = torch.tensor(res.y[:, -1], device=device)
    z = zp[:-n_batch].reshape(shape)
    delta_logp = zp[-n_batch:].reshape(n_batch)
    sigma_max = marginal_prob_std(1.)
    prior_logp = prior_likelihood(z, sigma_max)
    bpd = -(prior_logp + delta_logp) / np.log(2)
    bpd = bpd / n_dims + 8.
    return z, bpd
|
||
|
|
||
|
|
||
|
|