From 0a26e0f92fb965009edbea8702c156445621d2c6 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Sat, 4 Oct 2025 07:47:23 -0700
Subject: [PATCH] complete the lpips loss used for the video tokenizer

---
 dreamer4/dreamer4.py | 86 ++++++++++++++++++++++++++++++++++++++++++--
 pyproject.toml       |  1 +
 2 files changed, 85 insertions(+), 2 deletions(-)

diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index c938db8..56f7fed 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -9,6 +9,9 @@ import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Parameter, Sequential, Linear, RMSNorm, Identity
 from torch import cat, stack, arange, tensor, Tensor, is_tensor
 
+import torchvision
+from torchvision.models import VGG16_Weights
+
 from x_mlps_pytorch import create_mlp
 
 from accelerate import Accelerator
@@ -76,6 +79,55 @@ def softclamp(t, value = 50.):
 
 # loss related
 
+class LPIPSLoss(Module):
+    def __init__(
+        self,
+        vgg: Module | None = None,
+        vgg_weights: VGG16_Weights = VGG16_Weights.DEFAULT,
+        sampled_frames = 1
+    ):
+        super().__init__()
+
+        if not exists(vgg):
+            vgg = torchvision.models.vgg16(weights = vgg_weights)
+            vgg.classifier = Sequential(*vgg.classifier[:-2])
+
+        self.vgg = [vgg]
+        self.sampled_frames = sampled_frames
+
+    def forward(
+        self,
+        pred,
+        data,
+    ):
+        batch, device, is_video = pred.shape[0], pred.device, pred.ndim == 5
+
+        vgg, = self.vgg
+        vgg = vgg.to(data.device)
+
+        # take care of sampling random frames of the video
+
+        if is_video:
+            pred, data = tuple(rearrange(t, 'b c t ... -> b t c ...') for t in (pred, data))
+
+            # batch randperm
+
+            batch_randperm = torch.randn(pred.shape[:2], device = pred.device).argsort(dim = -1)
+            rand_frames = batch_randperm[..., :self.sampled_frames]
+
+            batch_arange = arange(batch, device = device)
+            batch_arange = rearrange(batch_arange, '... -> ... 1')
+
+            pred, data = tuple(t[batch_arange, rand_frames] for t in (pred, data))
+
+            # fold sampled frames into batch
+
+            pred, data = tuple(rearrange(t, 'b t c ... -> (b t) c ...') for t in (pred, data))
+
+        pred_embed, embed = tuple(vgg(t) for t in (pred, data))
+
+        return F.mse_loss(embed, pred_embed)
+
 def ramp_weight(times, slope = 0.9, intercept = 0.1):
     # equation (8) of the paper, their "ramp" loss weighting
     return slope * times + intercept
@@ -508,6 +560,8 @@ class VideoTokenizer(Module):
         decoder_pos_mlp_depth = 2,
         channels = 3,
         per_image_patch_mask_prob = (0., 0.9), # probability of patch masking appears to be per image probabilities drawn uniformly between 0. and 0.9 - if you are a phd student and think i'm mistaken, please open an issue
+        lpips_loss_network: Module | None = None,
+        lpips_loss_weight = 0.2
     ):
         super().__init__()
 
@@ -580,6 +634,16 @@ class VideoTokenizer(Module):
         self.decoder_layers = ModuleList(decoder_layers)
         self.decoder_norm = RMSNorm(dim)
 
+        # loss related
+
+        self.register_buffer('zero', tensor(0.), persistent = False)
+
+        self.has_lpips_loss = lpips_loss_weight > 0.
+        self.lpips_loss_weight = lpips_loss_weight
+
+        if self.has_lpips_loss:
+            self.lpips = LPIPSLoss(lpips_loss_network)
+
     @torch.no_grad()
     def tokenize(
         self,
@@ -592,7 +656,8 @@ class VideoTokenizer(Module):
         self,
         video, # (b c t h w)
         return_latents = False,
-        mask_patches = None
+        mask_patches = None,
+        return_all_losses = False
     ):
         batch, time = video.shape[0], video.shape[2]
         patch_size, device = self.patch_size, video.device
@@ -697,7 +762,24 @@ class VideoTokenizer(Module):
 
         recon_loss = F.mse_loss(video, recon_video)
 
-        return recon_loss
+        lpips_loss = self.zero
+
+        if self.has_lpips_loss:
+            lpips_loss = self.lpips(recon_video, video)
+
+        # losses
+
+        total_loss = (
+            recon_loss +
+            lpips_loss * self.lpips_loss_weight
+        )
+
+        if not return_all_losses:
+            return total_loss
+
+        losses = (recon_loss, lpips_loss)
+
+        return total_loss, losses
 
 # dynamics model, axial space-time transformer
 
diff --git a/pyproject.toml b/pyproject.toml
index 215d378..f489271 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -32,6 +32,7 @@ dependencies = [
     "einops>=0.8.1",
     "hl-gauss-pytorch",
     "torch>=2.4",
+    "torchvision",
     "x-mlps-pytorch"
 ]
 
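A minimal usage sketch for the LPIPSLoss class added above. The tensor sizes and the sampled_frames value are illustrative; constructing the class without a vgg argument downloads the pretrained VGG16 weights on first use.

    import torch
    from dreamer4.dreamer4 import LPIPSLoss

    lpips = LPIPSLoss(sampled_frames = 2)

    recon = torch.rand(2, 3, 4, 64, 64) # reconstructed video (b c t h w)
    video = torch.rand(2, 3, 4, 64, 64) # target video (b c t h w)

    # for 5d inputs, 2 random frames per video are sampled and folded into the
    # batch, then both sides are compared as truncated-VGG16 embeddings under mse

    loss = lpips(recon, video)

Design note: the VGG network is held in a plain Python list rather than registered as a submodule, which keeps its weights out of the tokenizer's state dict and optimizer. Since torchvision constructs the model in train mode and the truncated classifier retains a Dropout layer, callers may want to put it in eval mode first.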
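And a sketch of the new loss interface on VideoTokenizer. Only lpips_loss_weight and return_all_losses come from this patch; dim = 512 stands in for the constructor arguments not shown here.

    from dreamer4.dreamer4 import VideoTokenizer

    tokenizer = VideoTokenizer(
        dim = 512,                # placeholder for the args outside this patch
        lpips_loss_weight = 0.2   # the default; set to 0. to skip building LPIPSLoss
    )

    video = torch.rand(1, 3, 4, 64, 64) # (b c t h w)

    total_loss, (recon_loss, lpips_loss) = tokenizer(video, return_all_losses = True)
    total_loss.backward()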