complete the lpips loss used for the video tokenizer

lucidrains 2025-10-04 07:47:23 -07:00
parent 92e55a90b4
commit 0a26e0f92f
2 changed files with 85 additions and 2 deletions


@@ -9,6 +9,9 @@ import torch.nn.functional as F
from torch.nn import Module, ModuleList, Parameter, Sequential, Linear, RMSNorm, Identity
from torch import cat, stack, arange, tensor, Tensor, is_tensor
import torchvision
from torchvision.models import VGG16_Weights
from x_mlps_pytorch import create_mlp
from accelerate import Accelerator
@@ -76,6 +79,55 @@ def softclamp(t, value = 50.):

# loss related

class LPIPSLoss(Module):
    def __init__(
        self,
        vgg: Module | None = None,
        vgg_weights: VGG16_Weights = VGG16_Weights.DEFAULT,
        sampled_frames = 1
    ):
        super().__init__()

        if not exists(vgg):
            vgg = torchvision.models.vgg16(weights = vgg_weights)
            vgg.classifier = Sequential(*vgg.classifier[:-2])

        # hold vgg in a plain list so it is not registered as a submodule (kept out of the state dict and optimizer)

        self.vgg = [vgg]
        self.sampled_frames = sampled_frames

    def forward(
        self,
        pred,
        data,
    ):
        batch, device, is_video = pred.shape[0], pred.device, pred.ndim == 5

        vgg, = self.vgg
        vgg = vgg.to(data.device)

        # take care of sampling random frames of the video

        if is_video:
            pred, data = tuple(rearrange(t, 'b c t ... -> b t c ...') for t in (pred, data))

            # batch randperm - argsort of noise gives an independent permutation of frame indices per batch element

            batch_randperm = torch.randn(pred.shape[:2], device = pred.device).argsort(dim = -1)
            rand_frames = batch_randperm[..., :self.sampled_frames]

            batch_arange = arange(batch, device = device)
            batch_arange = rearrange(batch_arange, '... -> ... 1')

            pred, data = tuple(t[batch_arange, rand_frames] for t in (pred, data))

            # fold sampled frames into batch

            pred, data = tuple(rearrange(t, 'b t c ... -> (b t) c ...') for t in (pred, data))

        pred_embed, embed = tuple(vgg(t) for t in (pred, data))

        return F.mse_loss(embed, pred_embed)
def ramp_weight(times, slope = 0.9, intercept = 0.1):
    # equation (8) of the paper - their "ramp" loss weighting
    return slope * times + intercept
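A minimal usage sketch of the loss above (tensor shapes and values here are illustrative assumptions, not part of this commit): for 5-dim video input, LPIPSLoss samples `sampled_frames` random frames per batch element, folds them into the batch, and takes the mse between truncated-VGG16 features of prediction and target.

# hypothetical usage sketch - shapes assumed, not from this commit

import torch

lpips = LPIPSLoss(sampled_frames = 2)

pred = torch.randn(4, 3, 16, 64, 64, requires_grad = True) # (b c t h w) reconstruction
data = torch.randn(4, 3, 16, 64, 64)                       # (b c t h w) ground truth

loss = lpips(pred, data) # scalar feature-space mse over 2 random frames per video
loss.backward()

# ramp_weight worked example: ramp_weight(0.5) -> 0.9 * 0.5 + 0.1 = 0.55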
@@ -508,6 +560,8 @@ class VideoTokenizer(Module):
        decoder_pos_mlp_depth = 2,
        channels = 3,
        per_image_patch_mask_prob = (0., 0.9), # the patch masking probability appears to be drawn per image, uniformly between 0. and 0.9 - if you are a phd student and think i'm mistaken, please open an issue
        lpips_loss_network: Module | None = None,
        lpips_loss_weight = 0.2
    ):
        super().__init__()
@@ -580,6 +634,16 @@ class VideoTokenizer(Module):
        self.decoder_layers = ModuleList(decoder_layers)
        self.decoder_norm = RMSNorm(dim)

        # loss related

        self.register_buffer('zero', tensor(0.), persistent = False)

        self.has_lpips_loss = lpips_loss_weight > 0.
        self.lpips_loss_weight = lpips_loss_weight

        if self.has_lpips_loss:
            self.lpips = LPIPSLoss(lpips_loss_network)

    @torch.no_grad()
    def tokenize(
        self,
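An aside on the `zero` buffer registered above: a non-persistent scalar buffer gives a default loss value that always lives on the module's device while staying out of the checkpoint. A standalone sketch of the same pattern (module and names hypothetical, not from this commit):

import torch
from torch import nn, tensor

class MaybeAuxLoss(nn.Module):
    # hypothetical module illustrating the non-persistent zero buffer pattern
    def __init__(self, has_aux_loss = False):
        super().__init__()
        self.has_aux_loss = has_aux_loss
        # persistent = False -> excluded from state_dict, but still moved by .to(device)
        self.register_buffer('zero', tensor(0.), persistent = False)

    def forward(self, x):
        main_loss = x.pow(2).mean()
        aux_loss = x.abs().mean() if self.has_aux_loss else self.zero
        return main_loss + aux_loss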
@@ -592,7 +656,8 @@
        self,
        video, # (b c t h w)
        return_latents = False,
        mask_patches = None,
        return_all_losses = False
    ):
        batch, time = video.shape[0], video.shape[2]
        patch_size, device = self.patch_size, video.device
@@ -697,7 +762,24 @@
        recon_loss = F.mse_loss(video, recon_video)

        lpips_loss = self.zero

        if self.has_lpips_loss:
            lpips_loss = self.lpips(video, recon_video)

        # losses

        total_loss = (
            recon_loss +
            lpips_loss * self.lpips_loss_weight
        )

        if not return_all_losses:
            return total_loss

        losses = (recon_loss, lpips_loss)

        return total_loss, losses

# dynamics model, axial space-time transformer
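A sketch of how the new return_all_losses flag might be consumed in a training step (tokenizer construction and hyperparameters are assumed, not shown in this commit): with the default lpips_loss_weight = 0.2, total_loss = recon_loss + 0.2 * lpips_loss.

# hypothetical training step - assumes a tokenizer was already constructed
# as a VideoTokenizer with lpips_loss_weight > 0.

import torch

video = torch.randn(2, 3, 16, 64, 64) # (b c t h w)

total_loss, (recon_loss, lpips_loss) = tokenizer(video, return_all_losses = True)
total_loss.backward()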


@@ -32,6 +32,7 @@ dependencies = [
    "einops>=0.8.1",
    "hl-gauss-pytorch",
    "torch>=2.4",
    "torchvision",
    "x-mlps-pytorch"
]