import numpy as np
import torch
from scipy import integrate


def prior_likelihood(z, sigma):
    """The likelihood of a Gaussian distribution with mean zero and
    standard deviation sigma."""
    shape = z.shape
    N = np.prod(shape[1:])
    return -N / 2. * torch.log(2 * np.pi * sigma ** 2) - torch.sum(z ** 2, dim=(1, 2, 3)) / (2 * sigma ** 2)


def ode_likelihood(x,
                   score_model,
                   marginal_prob_std,
                   diffusion_coeff,
                   batch_size=64,
                   device='cuda',
                   eps=1e-5):
    """Compute the likelihood with the probability flow ODE.

    Args:
        x: Input data.
        score_model: A PyTorch model representing the score-based model.
        marginal_prob_std: A function that gives the standard deviation of the
            perturbation kernel.
        diffusion_coeff: A function that gives the diffusion coefficient of the
            forward SDE.
        batch_size: The batch size. Equals the leading dimension of `x`.
        device: 'cuda' for evaluation on GPUs, and 'cpu' for evaluation on CPUs.
        eps: A `float` number. The smallest time step for numerical stability.

    Returns:
        z: The latent code for `x`.
        bpd: The log-likelihoods in bits/dim.
    """
    # Draw the random Gaussian sample for the Skilling-Hutchinson trace estimator.
    epsilon = torch.randn_like(x)
    def divergence_eval(sample, time_steps, epsilon):
        """Compute the divergence of the score-based model with Skilling-Hutchinson."""
        with torch.enable_grad():
            sample.requires_grad_(True)
            score_e = torch.sum(score_model(sample, time_steps) * epsilon)
            grad_score_e = torch.autograd.grad(score_e, sample)[0]
        return torch.sum(grad_score_e * epsilon, dim=(1, 2, 3))
    shape = x.shape

    def score_eval_wrapper(sample, time_steps):
        """A wrapper for evaluating the score-based model for the black-box ODE solver."""
        sample = torch.tensor(sample, device=device, dtype=torch.float32).reshape(shape)
        time_steps = torch.tensor(time_steps, device=device, dtype=torch.float32).reshape((sample.shape[0],))
        with torch.no_grad():
            score = score_model(sample, time_steps)
        return score.cpu().numpy().reshape((-1,)).astype(np.float64)
    def divergence_eval_wrapper(sample, time_steps):
        """A wrapper for evaluating the divergence of the score for the black-box ODE solver."""
        with torch.no_grad():
            # Convert the flattened numpy state back into tensors of the right shape.
            sample = torch.tensor(sample, device=device, dtype=torch.float32).reshape(shape)
            time_steps = torch.tensor(time_steps, device=device, dtype=torch.float32).reshape((sample.shape[0],))
            # Estimate the divergence with the Skilling-Hutchinson estimator.
            div = divergence_eval(sample, time_steps, epsilon)
            return div.cpu().numpy().reshape((-1,)).astype(np.float64)
    def ode_func(t, x):
        """The ODE function for the black-box solver.

        The state `x` concatenates the flattened sample with one running
        log-probability entry per batch element.
        """
        time_steps = np.ones((shape[0],)) * t
        sample = x[:-shape[0]]
        logp = x[-shape[0]:]
        g = diffusion_coeff(torch.tensor(t)).cpu().numpy()
        sample_grad = -0.5 * g ** 2 * score_eval_wrapper(sample, time_steps)
        logp_grad = -0.5 * g ** 2 * divergence_eval_wrapper(sample, time_steps)
        return np.concatenate([sample_grad, logp_grad], axis=0)
    # Initial state: the flattened data plus zero-initialized log-probability entries.
    init = np.concatenate([x.cpu().numpy().reshape((-1,)), np.zeros((shape[0],))], axis=0)
    # Black-box ODE solver, integrating the probability flow ODE from t=eps to t=1.
    res = integrate.solve_ivp(ode_func, (eps, 1.), init, rtol=1e-5, atol=1e-5, method='RK45')
    zp = torch.tensor(res.y[:, -1], device=device)
    z = zp[:-shape[0]].reshape(shape)
    delta_logp = zp[-shape[0]:].reshape(shape[0])
    sigma_max = marginal_prob_std(1.)
    prior_logp = prior_likelihood(z, sigma_max)
    # Convert nats to bits and normalize per dimension; the +8 accounts for data
    # rescaled from 256 discrete levels to [0, 1] (log2(256) = 8 bits per dimension).
    bpd = -(prior_logp + delta_logp) / np.log(2)
    N = np.prod(shape[1:])
    bpd = bpd / N + 8.
    return z, bpd
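

# --- Usage sketch (an illustrative addition, not part of the original file) ---
# A minimal helper showing how `ode_likelihood` could be used to report an
# average bits/dim over a data loader. It assumes a trained `score_model` and
# partially-applied `marginal_prob_std_fn` / `diffusion_coeff_fn` helpers
# defined elsewhere in this codebase; the dequantization step follows common
# practice for likelihood evaluation on 8-bit image data.
def evaluate_bpd(score_model, marginal_prob_std_fn, diffusion_coeff_fn,
                 data_loader, device='cuda'):
    """Return the average bits/dim of `score_model` over `data_loader`."""
    total_bpd = 0.
    total_items = 0
    for x, _ in data_loader:
        x = x.to(device)
        # Uniform dequantization: spread 8-bit pixel values over continuous bins in [0, 1).
        x = (x * 255. + torch.rand_like(x)) / 256.
        _, bpd = ode_likelihood(x, score_model, marginal_prob_std_fn,
                                diffusion_coeff_fn, batch_size=x.shape[0],
                                device=device, eps=1e-5)
        total_bpd += bpd.sum().item()
        total_items += bpd.shape[0]
    return total_bpd / total_items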