84 lines
2.1 KiB
Python
84 lines
2.1 KiB
Python
import torch
|
|
from matplotlib import pyplot as plt
|
|
|
|
|
|
def apply_hard_conditioning(x, conditions):
    """Overwrite selected time indices of ``x`` with fixed values, in place.

    For every ``(t, val)`` pair in ``conditions`` the slice ``x[:, t, :]`` is
    replaced by a copy of ``val``.

    Args:
        x: Tensor with at least three dimensions (it is indexed as
            ``x[:, t, :]``); modified in place. Presumably laid out as
            (batch, horizon, state_dim) -- TODO confirm with callers.
        conditions: Mapping from index ``t`` to the tensor written at that
            index, or ``None`` / an empty mapping to leave ``x`` untouched.

    Returns:
        The same tensor object ``x``, after mutation.
    """
    # Callers such as guide_gradient_steps default hard_conds to None;
    # treat that (and an empty mapping) as "no conditioning" instead of crashing.
    if not conditions:
        return x
    for t, val in conditions.items():
        # clone() gives the written data its own buffer, guarding against an
        # overlapping in-place copy if ``val`` happens to be a view of ``x``.
        x[:, t, :] = val.clone()
    return x
|
|
|
|
|
|
def extract(a, t, x_shape):
    """Look up ``a`` at indices ``t`` and reshape for broadcasting.

    Args:
        a: Tensor of per-step values; gathered along its last dimension.
        t: Integer index tensor whose first dimension is the batch size.
        x_shape: Shape of the tensor the result will be broadcast against;
            only its length is used.

    Returns:
        Tensor of shape ``(batch, 1, ..., 1)`` carrying
        ``len(x_shape) - 1`` trailing singleton dimensions.
    """
    batch, *_rest = t.shape
    picked = a.gather(-1, t)
    singleton_dims = [1] * (len(x_shape) - 1)
    return picked.reshape(batch, *singleton_dims)
|
|
|
|
|
|
@torch.no_grad()
def ddpm_sample_fn(
    model, x, hard_conds, context, t,
    guide=None,
    n_guide_steps=1,
    scale_grad_by_std=False,
    t_start_guide=torch.inf,
    noise_std_extra_schedule_fn=None,
    debug=False,
    **kwargs
):
    """Perform one reverse-diffusion sampling step, optionally with guidance.

    Runs without autograd tracking (``torch.no_grad``). Any gradient the guide
    needs must therefore be enabled inside the guide itself -- presumably it
    does so; confirm against the guide implementation.

    Args:
        model: Object providing ``p_mean_variance(x=..., hard_conds=...,
            context=..., t=...)`` (first of three return values is used as the
            mean) and a ``posterior_log_variance_clipped`` tensor.
        x: Current noisy sample.
        hard_conds: Passed to the model and, when guiding, to the guide steps.
        context: Passed to the model unchanged.
        t: Step-index tensor, one entry per batch element. Only ``t[0]`` drives
            the guidance/schedule decisions (presumably all entries are equal
            -- TODO confirm). Negative steps are queried as step 0.
        guide: Optional callable; when given and ``t[0] < t_start_guide``,
            ``guide_gradient_steps`` is applied to the predicted mean.
        n_guide_steps: Forwarded to ``guide_gradient_steps``.
        scale_grad_by_std: Forwarded to ``guide_gradient_steps``.
        t_start_guide: Guidance is only applied while ``t[0]`` is below this.
        noise_std_extra_schedule_fn: Optional callable receiving the unclamped
            ``t[0]`` and returning an extra multiplier on the noise; ``None``
            means a multiplier of 1.0.
        debug: Accepted but unused; guidance is always run with ``debug=False``.
        **kwargs: Ignored.

    Returns:
        Tuple ``(next_sample, None)``.
    """
    first_step = t[0]
    # Negative step indices are clamped to 0 for every model/table lookup;
    # the unclamped ``first_step`` still gates guidance and the noise schedule.
    t_query = torch.zeros_like(t) if first_step < 0 else t

    mean, _unused, _model_log_var = model.p_mean_variance(
        x=x, hard_conds=hard_conds, context=context, t=t_query
    )

    # The model's own log-variance is ignored; the clipped posterior
    # log-variance is looked up per batch element instead.
    log_var = extract(model.posterior_log_variance_clipped, t_query, mean.shape)
    std = torch.exp(0.5 * log_var)
    var = torch.exp(log_var)

    sample = mean
    if guide is not None and first_step < t_start_guide:
        sample = guide_gradient_steps(
            sample,
            hard_conds=hard_conds,
            guide=guide,
            n_guide_steps=n_guide_steps,
            scale_grad_by_std=scale_grad_by_std,
            model_var=var,
            debug=False,
        )

    # Fresh Gaussian noise, zeroed for every batch element at step 0
    # (which includes steps that were clamped up from negative values).
    eps = torch.randn_like(sample)
    eps[t_query == 0] = 0

    # Optional extra decay of the noise level across the diffusion; this acts
    # roughly like a temperature on the prior and gives smoother samples.
    if noise_std_extra_schedule_fn is not None:
        noise_scale = noise_std_extra_schedule_fn(first_step)
    else:
        noise_scale = 1.0

    return sample + std * eps * noise_scale, None
|
|
|
|
|
|
def guide_gradient_steps(
    x,
    hard_conds=None,
    guide=None,
    n_guide_steps=1,
    scale_grad_by_std=False,
    model_var=None,
    debug=False,
    **kwargs
):
    """Nudge ``x`` by a guide's output for a fixed number of steps.

    Each step adds ``guide(x)`` to ``x``, optionally multiplying the update by
    ``model_var`` first. When ``hard_conds`` is given, the hard conditions are
    then re-imposed. ``x`` is replaced rather than mutated in place
    (``x = x + ...``) before any conditioning is written into the new tensor.

    Args:
        x: Current sample tensor.
        hard_conds: Mapping of index -> value re-applied after every step via
            ``apply_hard_conditioning``; ``None`` (the default) skips it.
        guide: Callable returning an update broadcastable to ``x``. Required
            whenever ``n_guide_steps > 0``, despite the ``None`` default.
        n_guide_steps: Number of update steps; ``0`` returns ``x`` unchanged.
        scale_grad_by_std: If true, multiply the guide output by ``model_var``
            before adding it. Note this scales by the variance, not the std.
        model_var: Scaling factor used when ``scale_grad_by_std`` is set; must
            not be ``None`` in that case.
        debug: Accepted for interface compatibility; unused.
        **kwargs: Ignored.

    Returns:
        The updated tensor.
    """
    for _ in range(n_guide_steps):
        update = guide(x)
        if scale_grad_by_std:
            update = model_var * update
        x = x + update
        # hard_conds defaults to None: skip conditioning instead of crashing
        # on ``None.items()`` inside apply_hard_conditioning.
        if hard_conds is not None:
            x = apply_hard_conditioning(x, hard_conds)
    return x
|