"""
|
|
Adapted from https://github.com/jannerm/diffuser
|
|
"""

import abc
import time
from abc import ABC
from collections import namedtuple
from copy import copy

import einops
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.nn import DataParallel

from mpd.models.diffusion_models.helpers import cosine_beta_schedule, Losses, exponential_beta_schedule
from mpd.models.diffusion_models.sample_functions import extract, apply_hard_conditioning, guide_gradient_steps, \
    ddpm_sample_fn
from torch_robotics.torch_utils.torch_timer import TimerCUDA
from torch_robotics.torch_utils.torch_utils import to_numpy


def make_timesteps(batch_size, i, device):
    t = torch.full((batch_size,), i, device=device, dtype=torch.long)
    return t
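# For example, make_timesteps(4, 9, device='cpu') returns tensor([9, 9, 9, 9]).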


def build_context(model, dataset, input_dict):
    # input_dict is already normalized
    context = None
    if model.context_model is not None:
        context = dict()
        # (normalized) features of variable environments
        if dataset.variable_environment:
            env_normalized = input_dict[f'{dataset.field_key_env}_normalized']
            context['env'] = env_normalized

        # tasks
        task_normalized = input_dict[f'{dataset.field_key_task}_normalized']
        context['tasks'] = task_normalized
    return context
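# Hypothetical usage sketch (editorial addition; the key names depend on the
# dataset, e.g. assuming field_key_task == 'task'):
#   input_dict = {'task_normalized': torch.randn(32, 4)}
#   context = build_context(model, dataset, input_dict)  # -> {'tasks': ...}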


class GaussianDiffusionModel(nn.Module, ABC):

    def __init__(self,
                 model=None,
                 variance_schedule='exponential',
                 n_diffusion_steps=100,
                 clip_denoised=True,
                 predict_epsilon=False,
                 loss_type='l2',
                 context_model=None,
                 **kwargs):
        super().__init__()

        self.model = model
        self.context_model = context_model
        self.n_diffusion_steps = n_diffusion_steps
        self.state_dim = self.model.state_dim

        if variance_schedule == 'cosine':
            betas = cosine_beta_schedule(n_diffusion_steps, s=0.008, a_min=0, a_max=0.999)
        elif variance_schedule == 'exponential':
            betas = exponential_beta_schedule(n_diffusion_steps, beta_start=1e-4, beta_end=1.0)
        else:
            raise NotImplementedError(f'unknown variance schedule: {variance_schedule}')

        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])

        self.clip_denoised = clip_denoised
        self.predict_epsilon = predict_epsilon

        self.register_buffer('betas', betas)
        self.register_buffer('alphas_cumprod', alphas_cumprod)
        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
        self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
        self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
        self.register_buffer('posterior_variance', posterior_variance)

        # log calculation clipped because the posterior variance
        # is 0 at the beginning of the diffusion chain
        self.register_buffer('posterior_log_variance_clipped',
                             torch.log(torch.clamp(posterior_variance, min=1e-20)))
        self.register_buffer('posterior_mean_coef1',
                             betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        self.register_buffer('posterior_mean_coef2',
                             (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
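        # These buffers implement the standard DDPM identities (Ho et al., 2020):
        #   q(x_t | x_0) = N(sqrt(alphas_cumprod_t) * x_0, (1 - alphas_cumprod_t) * I)
        #   q(x_{t-1} | x_t, x_0) = N(coef1 * x_0 + coef2 * x_t, posterior_variance_t)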

        # get loss coefficients and initialize objective
        self.loss_fn = Losses[loss_type]()

    # ------------------------------------------ sampling ------------------------------------------#

    def predict_noise_from_start(self, x_t, t, x0):
        """
        If self.predict_epsilon, the model already outputs (scaled) noise, so x0
        here is that noise prediction and is returned unchanged; otherwise, the
        predicted x0 is inverted to recover the noise implied by x_t.
        """
        if self.predict_epsilon:
            return x0
        else:
            return (
                extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0
            ) / extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)

    def predict_start_from_noise(self, x_t, t, noise):
        """
        If self.predict_epsilon, the model output is (scaled) noise and x0 is
        reconstructed from it; otherwise, the model predicts x0 directly and the
        output is passed through unchanged.
        """
        if self.predict_epsilon:
            return (
                extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
                extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
            )
        else:
            return noise

    def q_posterior(self, x_start, x_t, t):
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped
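    # The two methods above invert the forward process: with a_bar_t = alphas_cumprod[t],
    #   x_0 ~= sqrt(1 / a_bar_t) * x_t - sqrt(1 / a_bar_t - 1) * eps,
    # and q(x_{t-1} | x_t, x_0) is the Gaussian whose mean and variance q_posterior returns.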

    def p_mean_variance(self, x, hard_conds, context, t):
        if context is not None:
            context = self.context_model(context)

        x_recon = self.predict_start_from_noise(x, t=t, noise=self.model(x, t, context))

        if self.clip_denoised:
            x_recon.clamp_(-1., 1.)
        else:
            raise RuntimeError('clip_denoised=False is not supported')

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
        return model_mean, posterior_variance, posterior_log_variance

    @torch.no_grad()
    def p_sample_loop(self, shape, hard_conds, context=None, return_chain=False,
                      sample_fn=ddpm_sample_fn,
                      n_diffusion_steps_without_noise=0,
                      **sample_kwargs):
        device = self.betas.device

        batch_size = shape[0]
        x = torch.randn(shape, device=device)
        x = apply_hard_conditioning(x, hard_conds)

        chain = [x] if return_chain else None

        # timesteps below 0 run extra denoising-only refinement steps
        for i in reversed(range(-n_diffusion_steps_without_noise, self.n_diffusion_steps)):
            t = make_timesteps(batch_size, i, device)
            x, values = sample_fn(self, x, hard_conds, context, t, **sample_kwargs)
            x = apply_hard_conditioning(x, hard_conds)

            if return_chain:
                chain.append(x)

        if return_chain:
            chain = torch.stack(chain, dim=1)
            return x, chain

        return x

    @torch.no_grad()
    def ddim_sample(
            self, shape, hard_conds,
            context=None, return_chain=False,
            t_start_guide=torch.inf,
            guide=None,
            n_guide_steps=1,
            **sample_kwargs,
    ):
        # Adapted from https://github.com/ezhang7423/language-control-diffusion/blob/63cdafb63d166221549968c662562753f6ac5394/src/lcd/models/diffusion.py#L226
        device = self.betas.device
        batch_size = shape[0]
        total_timesteps = self.n_diffusion_steps
        sampling_timesteps = self.n_diffusion_steps // 5
        eta = 0.  # eta == 0 adds no noise, i.e. deterministic DDIM sampling

        # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
        times = torch.linspace(0, total_timesteps - 1, steps=sampling_timesteps + 1, device=device)
        times = torch.cat((torch.tensor([-1], device=device), times))
        times = list(reversed(times.int().tolist()))
        time_pairs = list(zip(times[:-1], times[1:]))  # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]

        x = torch.randn(shape, device=device)
        x = apply_hard_conditioning(x, hard_conds)

        chain = [x] if return_chain else None

        for time, time_next in time_pairs:
            t = make_timesteps(batch_size, time, device)
            t_next = make_timesteps(batch_size, time_next, device)

            model_out = self.model(x, t, context)

            x_start = self.predict_start_from_noise(x, t=t, noise=model_out)
            pred_noise = self.predict_noise_from_start(x, t=t, x0=model_out)

            if time_next < 0:
                x = x_start
                x = apply_hard_conditioning(x, hard_conds)
                if return_chain:
                    chain.append(x)
                break

            alpha = extract(self.alphas_cumprod, t, x.shape)
            alpha_next = extract(self.alphas_cumprod, t_next, x.shape)

            sigma = (
                eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
            )
            c = (1 - alpha_next - sigma**2).sqrt()

            # DDIM update: x_{t_next} = sqrt(alpha_next) * x_0 + c * eps (+ sigma * z below)
            x = x_start * alpha_next.sqrt() + c * pred_noise

            # guide gradient steps before adding noise
            if guide is not None:
                if torch.all(t_next < t_start_guide):
                    x = guide_gradient_steps(
                        x,
                        hard_conds=hard_conds,
                        guide=guide,
                        **sample_kwargs
                    )

            # add noise
            noise = torch.randn_like(x)
            x = x + sigma * noise
            x = apply_hard_conditioning(x, hard_conds)

            if return_chain:
                chain.append(x)

        if return_chain:
            chain = torch.stack(chain, dim=1)
            return x, chain

        return x

    @torch.no_grad()
    def conditional_sample(self, hard_conds, horizon=None, batch_size=1, ddim=False, **sample_kwargs):
        """
        hard_conds : {time_index: state, ...} -- states fixed at the given trajectory indices
        """
        horizon = horizon or self.horizon
        shape = (batch_size, horizon, self.state_dim)

        if ddim:
            return self.ddim_sample(shape, hard_conds, **sample_kwargs)

        return self.p_sample_loop(shape, hard_conds, **sample_kwargs)
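    # A hypothetical call (editorial sketch; names and shapes are illustrative):
    #   hard_conds = {0: start_state, horizon - 1: goal_state}  # pin the endpoints
    #   trajs = diffusion_model.conditional_sample(hard_conds, horizon=64, batch_size=32)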

    def forward(self, cond, *args, **kwargs):
        raise NotImplementedError

    @torch.no_grad()
    def warmup(self, horizon=64, device='cuda'):
        shape = (2, horizon, self.state_dim)
        x = torch.randn(shape, device=device)
        t = make_timesteps(2, 1, device)
        self.model(x, t, context=None)

    @torch.no_grad()
    def run_inference(self, context=None, hard_conds=None, n_samples=1, return_chain=False, **diffusion_kwargs):
        # context and hard_conds must be normalized
        hard_conds = copy(hard_conds)
        context = copy(context)

        # repeat hard conditions and contexts for n_samples
        for k, v in hard_conds.items():
            new_state = einops.repeat(v, 'd -> b d', b=n_samples)
            hard_conds[k] = new_state

        if context is not None:
            for k, v in context.items():
                context[k] = einops.repeat(v, 'd -> b d', b=n_samples)

        # Sample from the diffusion model
        samples, chain = self.conditional_sample(
            hard_conds, context=context, batch_size=n_samples, return_chain=True, **diffusion_kwargs
        )

        # chain: [n_samples x (n_diffusion_steps + 1) x horizon x state_dim]
        # extract normalized trajectories
        trajs_chain_normalized = chain

        # trajs: [(n_diffusion_steps + 1) x n_samples x horizon x state_dim]
        trajs_chain_normalized = einops.rearrange(trajs_chain_normalized, 'b diffsteps h d -> diffsteps b h d')

        if return_chain:
            return trajs_chain_normalized

        # return the last denoising step
        return trajs_chain_normalized[-1]

    # ------------------------------------------ training ------------------------------------------#

    def q_sample(self, x_start, t, noise=None):
        if noise is None:
            noise = torch.randn_like(x_start)

        # closed-form forward process: x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * noise
        sample = (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

        return sample

    def p_losses(self, x_start, context, t, hard_conds):
        noise = torch.randn_like(x_start)

        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        x_noisy = apply_hard_conditioning(x_noisy, hard_conds)

        # context model
        if context is not None:
            context = self.context_model(context)

        # diffusion model
        x_recon = self.model(x_noisy, t, context)
        x_recon = apply_hard_conditioning(x_recon, hard_conds)

        assert noise.shape == x_recon.shape

        if self.predict_epsilon:
            loss, info = self.loss_fn(x_recon, noise)
        else:
            loss, info = self.loss_fn(x_recon, x_start)

        return loss, info

    def loss(self, x, context, *args):
        batch_size = x.shape[0]
        t = torch.randint(0, self.n_diffusion_steps, (batch_size,), device=x.device).long()
        return self.p_losses(x, context, t, *args)
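

if __name__ == "__main__":
    # Editorial smoke-test sketch, not part of the original module. It assumes a
    # toy denoiser exposing `state_dim` and the `model(x, t, context)` call
    # signature used above; the real project plugs in a temporal U-Net instead.
    class _ToyDenoiser(nn.Module):
        def __init__(self, state_dim=2):
            super().__init__()
            self.state_dim = state_dim
            self.net = nn.Linear(state_dim, state_dim)

        def forward(self, x, t, context=None):
            return self.net(x)

    diffusion = GaussianDiffusionModel(model=_ToyDenoiser(), n_diffusion_steps=10)
    x_start = torch.randn(4, 8, diffusion.state_dim)  # batch x horizon x state_dim
    t = torch.randint(0, diffusion.n_diffusion_steps, (4,))
    loss, info = diffusion.p_losses(x_start, context=None, t=t, hard_conds={})
    print(loss)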