add a decorrelation loss for temporal attention in the encoder of the video tokenizer

lucidrains 2025-11-09 09:47:33 -08:00
parent 256a81f658
commit 0c4224da18
3 changed files with 54 additions and 18 deletions
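For orientation before the diff itself: the new auxiliary objective penalizes correlations between feature channels of the normed inputs that feed each temporal attention layer in the tokenizer's encoder. The actual loss is DecorrelationLoss from vit_pytorch.vit_with_decorr (the paper linked in the code below, https://arxiv.org/abs/2510.14657); the snippet here is only a rough, self-contained sketch of the idea, with decorrelation_penalty and its handling of sample_frac being illustrative stand-ins rather than the library's implementation.

import torch

def decorrelation_penalty(feats, sample_frac = 0.25):
    # feats: (..., n, d) - e.g. the normed inputs feeding a temporal attention layer
    feats = feats.reshape(-1, feats.shape[-1])                # flatten to (num tokens, dim)

    num_sampled = max(int(feats.shape[0] * sample_frac), 2)   # subsample tokens to keep the penalty cheap
    index = torch.randperm(feats.shape[0], device = feats.device)[:num_sampled]
    feats = feats[index]

    feats = feats - feats.mean(dim = 0)                       # center each feature channel
    cov = (feats.t() @ feats) / (feats.shape[0] - 1)          # (d, d) covariance across tokens

    off_diag = cov - torch.diag_embed(cov.diagonal())         # keep only cross-channel terms
    return off_diag.pow(2).mean()                             # mean squared off-diagonal covariance

loss = decorrelation_penalty(torch.randn(2, 16, 64), sample_frac = 1.)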

View File

@@ -27,6 +27,8 @@ from x_mlps_pytorch.normed_mlp import create_mlp
from hyper_connections import get_init_and_expand_reduce_stream_functions
+from vit_pytorch.vit_with_decorr import DecorrelationLoss
from assoc_scan import AssocScan
# ein related
@@ -68,10 +70,12 @@ except ImportError:
LinearNoBias = partial(Linear, bias = False)
-TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips'))
+TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips', 'time_decorr'))
WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'rewards', 'discrete_actions', 'continuous_actions'))
+AttentionIntermediates = namedtuple('AttentionIntermediates', ('next_kv_cache', 'normed_inputs'))
MaybeTensor = Tensor | None
@dataclass
@@ -1321,7 +1325,7 @@ class Attention(Module):
self,
tokens, # (b n d)
kv_cache = None,
-return_kv_cache = False,
+return_intermediates = False,
rotary_pos_emb = None,
attend_fn: Callable | None = None
):
@@ -1375,10 +1379,10 @@
out = inverse_packed_batch(out)
-if not return_kv_cache:
+if not return_intermediates:
return out
-return out, stack((k, v))
+return out, AttentionIntermediates(stack((k, v)), tokens)
# feedforward
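For call sites, the change above swaps return_kv_cache for return_intermediates and returns the AttentionIntermediates named tuple defined earlier. A minimal usage sketch, assuming attn is an Attention instance and tokens is a (b, n, d) tensor:

out, intermediates = attn(tokens, return_intermediates = True)

kv_cache = intermediates.next_kv_cache        # stack((k, v)), the tensor that used to be returned on its own
normed_inputs = intermediates.normed_inputs   # the normed tokens that entered attention, reused for the decorrelation loss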
@@ -1492,7 +1496,7 @@ class AxialSpaceTimeTransformer(Module):
self,
tokens, # (b t s d)
kv_cache: Tensor | None = None, # (y 2 b h t d)
-return_kv_cache = False
+return_intermediates = False
): # (b t s d) | (y 2 b h t d)
@@ -1533,6 +1537,10 @@
rotary_pos_emb = self.time_rotary(rotary_seq_len, offset = rotary_pos_offset)
+# normed attention inputs
+normed_time_attn_inputs = []
# attention
tokens = self.expand_streams(tokens)
@@ -1553,12 +1561,12 @@
# attention layer
-tokens, next_kv_cache = attn(
+tokens, attn_intermediates = attn(
tokens,
rotary_pos_emb = layer_rotary_pos_emb,
attend_fn = attend_fn,
kv_cache = maybe_kv_cache,
-return_kv_cache = True
+return_intermediates = True
)
tokens = post_attn_rearrange(tokens)
@@ -1570,7 +1578,12 @@
# save kv cache if is time layer
if layer_is_time:
-time_attn_kv_caches.append(next_kv_cache)
+time_attn_kv_caches.append(attn_intermediates.next_kv_cache)
+# save time attention inputs for decorr
+if layer_is_time:
+normed_time_attn_inputs.append(attn_intermediates.normed_inputs)
tokens = self.reduce_streams(tokens)
@@ -1580,10 +1593,15 @@
# just concat the past tokens back on for now, todo - clean up the logic
out = cat((past_tokens, out), dim = 1)
-if not return_kv_cache:
+if not return_intermediates:
return out
-return out, stack(time_attn_kv_caches)
+intermediates = (
+stack(time_attn_kv_caches),
+stack(normed_time_attn_inputs)
+)
+return out, intermediates
# video tokenizer
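The axial transformer mirrors that contract one level up: when return_intermediates is set it hands back a pair of stacked tensors, one entry per temporal attention layer. A minimal sketch, assuming transformer is an AxialSpaceTimeTransformer instance and tokens is a (b, t, s, d) tensor:

out, (time_attn_kv_caches, normed_time_attn_inputs) = transformer(tokens, return_intermediates = True)

# time_attn_kv_caches     - stacked per-time-layer kv caches, unchanged from the old return_kv_cache path
# normed_time_attn_inputs - stacked normed inputs of each temporal attention layer,
#                           consumed by the encoder's decorrelation loss below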
@@ -1609,12 +1627,15 @@
per_image_patch_mask_prob = (0., 0.9), # probability of patch masking appears to be per image probabilities drawn uniformly between 0. and 0.9 - if you are a phd student and think i'm mistaken, please open an issue
lpips_loss_network: Module | None = None,
lpips_loss_weight = 0.2,
+encoder_add_decor_aux_loss = False,
+decorr_aux_loss_weight = 0.1,
+decorr_sample_frac = 0.25,
nd_rotary_kwargs: dict = dict(
rope_min_freq = 1.,
rope_max_freq = 10000.,
rope_p_zero_freqs = 0.
),
-num_residual_streams = 1
+num_residual_streams = 1,
):
super().__init__()
@@ -1709,6 +1730,14 @@
if self.has_lpips_loss:
self.lpips = LPIPSLoss(lpips_loss_network)
+# decorr aux loss
+# https://arxiv.org/abs/2510.14657
+self.encoder_add_decor_aux_loss = encoder_add_decor_aux_loss
+self.decorr_aux_loss_weight = decorr_aux_loss_weight
+self.decorr_loss = DecorrelationLoss(decorr_sample_frac) if encoder_add_decor_aux_loss else None
@property
def device(self):
return self.zero.device
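The loss module is only constructed when the flag is on. A small standalone sketch of the same wiring; the shape fed to the loss is an assumption here (a stack of per-layer normed inputs), since only the call site is visible in this diff:

import torch
from vit_pytorch.vit_with_decorr import DecorrelationLoss

decorr_loss_fn = DecorrelationLoss(0.25)             # 0.25 mirrors the new decorr_sample_frac default

# stand-in for stack(normed_time_attn_inputs): (time layers, batch, tokens, dim) is assumed
normed_time_attn_inputs = torch.randn(2, 4, 16, 64)

aux_loss = decorr_loss_fn(normed_time_attn_inputs)   # scalar auxiliary loss, weighted into the total below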
@@ -1822,7 +1851,7 @@
# encoder attention
-tokens = self.encoder_transformer(tokens)
+tokens, (_, time_attn_normed_inputs) = self.encoder_transformer(tokens, return_intermediates = True)
# latent bottleneck
@@ -1844,17 +1873,21 @@
if self.has_lpips_loss:
lpips_loss = self.lpips(video, recon_video)
+decorr_loss = self.zero
+if self.encoder_add_decor_aux_loss:
+decorr_loss = self.decorr_loss(time_attn_normed_inputs)
# losses
total_loss = (
recon_loss +
-lpips_loss * self.lpips_loss_weight
+lpips_loss * self.lpips_loss_weight +
+decorr_loss * self.decorr_aux_loss_weight
)
if not return_all_losses:
return total_loss
-losses = (recon_loss, lpips_loss)
+losses = (recon_loss, lpips_loss, decorr_loss)
return total_loss, TokenizerLosses(*losses)
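Downstream, a training step that asks for the breakdown now unpacks three terms. A hedged usage sketch, assuming tokenizer is a VideoTokenizer built with encoder_add_decor_aux_loss = True and a video shaped like the one in the test at the bottom of this diff:

video = torch.randn(2, 3, 4, 256, 256)   # presumably (batch, channels, frames, height, width), per the test below

total_loss, (recon_loss, lpips_loss, time_decorr_loss) = tokenizer(video, return_all_losses = True)

# total_loss = recon + lpips_weight * lpips + decorr_weight * time_decorr, as assembled above
total_loss.backward()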
@@ -3335,7 +3368,7 @@ class DynamicsWorldModel(Module):
# attention
-tokens, next_time_kv_cache = self.transformer(tokens, kv_cache = time_kv_cache, return_kv_cache = True)
+tokens, (next_time_kv_cache, _) = self.transformer(tokens, kv_cache = time_kv_cache, return_intermediates = True)
# unpack

View File

@@ -1,6 +1,6 @@
[project]
name = "dreamer4"
version = "0.1.5"
version = "0.1.6"
description = "Dreamer 4"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -36,7 +36,8 @@ dependencies = [
"hyper-connections>=0.2.1",
"torch>=2.4",
"torchvision",
"x-mlps-pytorch>=0.0.29"
"x-mlps-pytorch>=0.0.29",
"vit-pytorch>=1.15.2"
]
[project.urls]
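With DecorrelationLoss now imported from vit-pytorch, it becomes a hard dependency, hence the new floor above. For an existing environment the practical upgrade is just the usual pip form (nothing repo-specific):

pip install -U "vit-pytorch>=1.15.2"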

View File

@@ -41,7 +41,9 @@ def test_e2e(
patch_size = 32,
attn_dim_head = 16,
num_latent_tokens = 4,
-num_residual_streams = num_residual_streams
+num_residual_streams = num_residual_streams,
+encoder_add_decor_aux_loss = True,
+decorr_sample_frac = 1.
)
video = torch.randn(2, 3, 4, 256, 256)
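The test switches the new flag on and sets decorr_sample_frac = 1., presumably so every token participates rather than a random subset on this tiny input. A hedged sketch of the kind of end-to-end check that follows (the test's actual body is not shown in this diff):

loss = tokenizer(video)        # the default path returns just the weighted total loss
assert torch.isfinite(loss)
loss.backward()                # gradients now also flow through the decorrelation term into the encoder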