This commit is contained in:
lucidrains 2025-10-04 08:04:42 -07:00
parent 0f4783f23c
commit 7cac3d28c5

View File

@ -1,6 +1,7 @@
from __future__ import annotations
import math
from collections import namedtuple
from functools import partial
import torch
@ -49,6 +50,8 @@ except ImportError:
LinearNoBias = partial(Linear, bias = False)
TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips'))
# helpers
def exists(v):
@ -780,7 +783,7 @@ class VideoTokenizer(Module):
losses = (recon_loss, lpips_loss)
return total_loss, losses
return total_loss, TokenizerLosses(losses)
# dynamics model, axial space-time transformer