import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch_robotics.torch_utils.torch_utils import to_torch
#-----------------------------------------------------------------------------#
#---------------------------- variance schedules -----------------------------#
#-----------------------------------------------------------------------------#
def linear_beta_schedule(n_diffusion_steps, beta_start=0.0001, beta_end=0.02):
    return torch.linspace(beta_start, beta_end, n_diffusion_steps)


def quadratic_beta_schedule(n_diffusion_steps, beta_start=0.0001, beta_end=0.02):
    return torch.linspace(beta_start**0.5, beta_end**0.5, n_diffusion_steps) ** 2


def sigmoid_beta_schedule(n_diffusion_steps, beta_start=0.0001, beta_end=0.02):
    betas = torch.linspace(-6, 6, n_diffusion_steps)
    return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start

def cosine_beta_schedule(n_diffusion_steps, s=0.008, a_min=0, a_max=0.999, dtype=torch.float32):
    """
    Cosine schedule, as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
    """
    steps = n_diffusion_steps + 1
    x = np.linspace(0, steps, steps)
    alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    betas_clipped = np.clip(betas, a_min=a_min, a_max=a_max)
    return to_torch(betas_clipped, dtype=dtype)

def exponential_beta_schedule(n_diffusion_steps, beta_start=1e-4, beta_end=1.0):
    # Exponentially increasing noise from t=0 to t=T
    x = torch.linspace(0, n_diffusion_steps, n_diffusion_steps)
    beta_start = to_torch(beta_start)
    beta_end = to_torch(beta_end)
    a = 1 / n_diffusion_steps * torch.log(beta_end / beta_start)
    return beta_start * torch.exp(a * x)

def constant_fraction_beta_schedule(n_diffusion_steps):
    # beta_t = 1 / (T - t + 1): the noise fraction grows from 1/(T+1) at t=0 to 1 at t=T
    x = torch.linspace(0, n_diffusion_steps, n_diffusion_steps)
    return 1 / (n_diffusion_steps - x + 1)

def variance_preserving_beta_schedule(n_diffusion_steps, beta_start=1e-4, beta_end=1.0):
    # Works only with a small number of diffusion steps
    # https://arxiv.org/abs/2112.07804
    # https://openreview.net/pdf?id=AHvFDPi-FA
    x = torch.linspace(0, n_diffusion_steps, n_diffusion_steps)
    alphas = torch.exp(
        -beta_start * (1 / n_diffusion_steps)
        - 0.5 * (beta_end - beta_start) * (2 * x - 1) / (n_diffusion_steps ** 2)
    )
    betas = 1 - alphas
    return betas

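
# Usage sketch (added for illustration; the step count and downstream quantities are
# assumptions, not values taken from this repository). Every schedule above returns a
# 1-D tensor of betas; DDPM-style samplers typically consume the cumulative products
# alpha_bar_t derived from them:
#
#     betas = cosine_beta_schedule(n_diffusion_steps=100)
#     alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)  # alpha_bar_t, shape [100]
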
#-----------------------------------------------------------------------------#
#---------------------------------- losses -----------------------------------#
#-----------------------------------------------------------------------------#
class WeightedLoss(nn.Module):

    def __init__(self, weights=None):
        super().__init__()
        self.register_buffer('weights', weights)

    def forward(self, pred, targ):
        '''
        pred, targ : tensor
            [ batch_size x horizon x transition_dim ]
        '''
        loss = self._loss(pred, targ)
        if self.weights is not None:
            weighted_loss = (loss * self.weights).mean()
        else:
            weighted_loss = loss.mean()
        return weighted_loss, {}

class WeightedL1(WeightedLoss):

    def _loss(self, pred, targ):
        return torch.abs(pred - targ)


class WeightedL2(WeightedLoss):

    def _loss(self, pred, targ):
        return F.mse_loss(pred, targ, reduction='none')


Losses = {
    'l1': WeightedL1,
    'l2': WeightedL2,
}
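

# Minimal smoke test (a hedged sketch added for illustration; the shapes and constants
# below are assumptions, not part of the original module):
if __name__ == '__main__':
    betas = linear_beta_schedule(n_diffusion_steps=50)
    assert betas.shape == (50,)

    # Weighted L2 over a dummy batch shaped [batch_size x horizon x transition_dim];
    # per-element losses are broadcast-multiplied by the weights before averaging
    pred = torch.randn(4, 16, 7)
    targ = torch.randn(4, 16, 7)
    loss_fn = Losses['l2'](weights=torch.ones(16, 7))
    loss, info = loss_fn(pred, targ)
    print(f'betas[0]={betas[0].item():.4f}, betas[-1]={betas[-1].item():.4f}, loss={loss.item():.4f}')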