complete the lpips loss used for the video tokenizer
This commit is contained in:
parent
92e55a90b4
commit
0a26e0f92f
@ -9,6 +9,9 @@ import torch.nn.functional as F
|
||||
from torch.nn import Module, ModuleList, Parameter, Sequential, Linear, RMSNorm, Identity
|
||||
from torch import cat, stack, arange, tensor, Tensor, is_tensor
|
||||
|
||||
import torchvision
|
||||
from torchvision.models import VGG16_Weights
|
||||
|
||||
from x_mlps_pytorch import create_mlp
|
||||
|
||||
from accelerate import Accelerator
|
||||
@ -76,6 +79,55 @@ def softclamp(t, value = 50.):
|
||||
|
||||
# loss related
|
||||
|
||||
class LPIPSLoss(Module):
|
||||
def __init__(
|
||||
self,
|
||||
vgg: Module | None = None,
|
||||
vgg_weights: VGG16_Weights = VGG16_Weights.DEFAULT,
|
||||
sampled_frames = 1
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if not exists(vgg):
|
||||
vgg = torchvision.models.vgg16(weights = vgg_weights)
|
||||
vgg.classifier = Sequential(*vgg.classifier[:-2])
|
||||
|
||||
self.vgg = [vgg]
|
||||
self.sampled_frames = sampled_frames
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pred,
|
||||
data,
|
||||
):
|
||||
batch, device, is_video = pred.shape[0], pred.device, pred.ndim == 5
|
||||
|
||||
vgg, = self.vgg
|
||||
vgg = vgg.to(data.device)
|
||||
|
||||
# take care of sampling random frames of the video
|
||||
|
||||
if is_video:
|
||||
pred, data = tuple(rearrange(t, 'b c t ... -> b t c ...') for t in (pred, data))
|
||||
|
||||
# batch randperm
|
||||
|
||||
batch_randperm = torch.randn(pred.shape[:2], device = pred.device).argsort(dim = -1)
|
||||
rand_frames = batch_randperm[..., :self.sampled_frames]
|
||||
|
||||
batch_arange = arange(batch, device = device)
|
||||
batch_arange = rearrange(batch_arange, '... -> ... 1')
|
||||
|
||||
pred, data = tuple(t[batch_arange, rand_frames] for t in (pred, data))
|
||||
|
||||
# fold sampled frames into batch
|
||||
|
||||
pred, data = tuple(rearrange(t, 'b t c ... -> (b t) c ...') for t in (pred, data))
|
||||
|
||||
pred_embed, embed = tuple(vgg(t) for t in (pred, data))
|
||||
|
||||
return F.mse_loss(embed, pred_embed)
|
||||
|
||||
def ramp_weight(times, slope = 0.9, intercept = 0.1):
|
||||
# equation (8) paper, their "ramp" loss weighting
|
||||
return slope * times + intercept
|
||||
@ -508,6 +560,8 @@ class VideoTokenizer(Module):
|
||||
decoder_pos_mlp_depth = 2,
|
||||
channels = 3,
|
||||
per_image_patch_mask_prob = (0., 0.9), # probability of patch masking appears to be per image probabilities drawn uniformly between 0. and 0.9 - if you are a phd student and think i'm mistakened, please open an issue
|
||||
lpips_loss_network: Module | None = None,
|
||||
lpips_loss_weight = 0.2
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@ -580,6 +634,16 @@ class VideoTokenizer(Module):
|
||||
self.decoder_layers = ModuleList(decoder_layers)
|
||||
self.decoder_norm = RMSNorm(dim)
|
||||
|
||||
# loss related
|
||||
|
||||
self.register_buffer('zero', tensor(0.), persistent = False)
|
||||
|
||||
self.has_lpips_loss = lpips_loss_weight > 0.
|
||||
self.lpips_loss_weight = lpips_loss_weight
|
||||
|
||||
if self.has_lpips_loss:
|
||||
self.lpips = LPIPSLoss(lpips_loss_network)
|
||||
|
||||
@torch.no_grad()
|
||||
def tokenize(
|
||||
self,
|
||||
@ -592,7 +656,8 @@ class VideoTokenizer(Module):
|
||||
self,
|
||||
video, # (b c t h w)
|
||||
return_latents = False,
|
||||
mask_patches = None
|
||||
mask_patches = None,
|
||||
return_all_losses = False
|
||||
):
|
||||
batch, time = video.shape[0], video.shape[2]
|
||||
patch_size, device = self.patch_size, video.device
|
||||
@ -697,7 +762,24 @@ class VideoTokenizer(Module):
|
||||
|
||||
recon_loss = F.mse_loss(video, recon_video)
|
||||
|
||||
return recon_loss
|
||||
lpips_loss = self.zero
|
||||
|
||||
if self.has_lpips_loss:
|
||||
lpips_loss = self.lpips(video, recon_video)
|
||||
|
||||
# losses
|
||||
|
||||
total_loss = (
|
||||
recon_loss +
|
||||
lpips_loss * self.lpips_loss_weight
|
||||
)
|
||||
|
||||
if not return_all_losses:
|
||||
return total_loss
|
||||
|
||||
losses = (recon_loss, lpips_loss)
|
||||
|
||||
return total_loss, losses
|
||||
|
||||
# dynamics model, axial space-time transformer
|
||||
|
||||
|
||||
@ -32,6 +32,7 @@ dependencies = [
|
||||
"einops>=0.8.1",
|
||||
"hl-gauss-pytorch",
|
||||
"torch>=2.4",
|
||||
"torchvision",
|
||||
"x-mlps-pytorch"
|
||||
]
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user