From 0c4224da181c53812ce4a253cff4a5a5d6a97da8 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Sun, 9 Nov 2025 09:47:33 -0800
Subject: [PATCH] add a decorrelation loss for temporal attention in encoder
 of video tokenizer

---
 dreamer4/dreamer4.py  | 63 ++++++++++++++++++++++++++++++++-----------
 pyproject.toml        |  5 ++--
 tests/test_dreamer.py |  4 ++-
 3 files changed, 54 insertions(+), 18 deletions(-)

diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index b017ea7..d857871 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -27,6 +27,8 @@ from x_mlps_pytorch.normed_mlp import create_mlp
 
 from hyper_connections import get_init_and_expand_reduce_stream_functions
 
+from vit_pytorch.vit_with_decorr import DecorrelationLoss
+
 from assoc_scan import AssocScan
 
 # ein related
@@ -68,10 +70,12 @@ except ImportError:
 
 LinearNoBias = partial(Linear, bias = False)
 
-TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips'))
+TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips', 'time_decorr'))
 
 WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'rewards', 'discrete_actions', 'continuous_actions'))
 
+AttentionIntermediates = namedtuple('AttentionIntermediates', ('next_kv_cache', 'normed_inputs'))
+
 MaybeTensor = Tensor | None
 
 @dataclass
@@ -1321,7 +1325,7 @@ class Attention(Module):
         self,
         tokens, # (b n d)
         kv_cache = None,
-        return_kv_cache = False,
+        return_intermediates = False,
         rotary_pos_emb = None,
         attend_fn: Callable | None = None
     ):
@@ -1375,10 +1379,10 @@ class Attention(Module):
 
         out = inverse_packed_batch(out)
 
-        if not return_kv_cache:
+        if not return_intermediates:
             return out
 
-        return out, stack((k, v))
+        return out, AttentionIntermediates(stack((k, v)), tokens)
 
 # feedforward
 
@@ -1492,7 +1496,7 @@ class AxialSpaceTimeTransformer(Module):
         self,
         tokens, # (b t s d)
         kv_cache: Tensor | None = None, # (y 2 b h t d)
-        return_kv_cache = False
+        return_intermediates = False
     ):
 
         # (b t s d) | (y 2 b h t d)
@@ -1533,6 +1537,10 @@ class AxialSpaceTimeTransformer(Module):
 
         rotary_pos_emb = self.time_rotary(rotary_seq_len, offset = rotary_pos_offset)
 
+        # normed attention inputs
+
+        normed_time_attn_inputs = []
+
         # attention
 
         tokens = self.expand_streams(tokens)
@@ -1553,12 +1561,12 @@ class AxialSpaceTimeTransformer(Module):
 
             # attention layer
 
-            tokens, next_kv_cache = attn(
+            tokens, attn_intermediates = attn(
                 tokens,
                 rotary_pos_emb = layer_rotary_pos_emb,
                 attend_fn = attend_fn,
                 kv_cache = maybe_kv_cache,
-                return_kv_cache = True
+                return_intermediates = True
             )
 
             tokens = post_attn_rearrange(tokens)
@@ -1570,7 +1578,12 @@ class AxialSpaceTimeTransformer(Module):
             # save kv cache if is time layer
 
             if layer_is_time:
-                time_attn_kv_caches.append(next_kv_cache)
+                time_attn_kv_caches.append(attn_intermediates.next_kv_cache)
+
+            # save time attention inputs for decorr
+
+            if layer_is_time:
+                normed_time_attn_inputs.append(attn_intermediates.normed_inputs)
 
         tokens = self.reduce_streams(tokens)
 
@@ -1580,10 +1593,15 @@ class AxialSpaceTimeTransformer(Module):
             # just concat the past tokens back on for now, todo - clean up the logic
             out = cat((past_tokens, out), dim = 1)
 
-        if not return_kv_cache:
+        if not return_intermediates:
             return out
 
-        return out, stack(time_attn_kv_caches)
+        intermediates = (
+            stack(time_attn_kv_caches),
+            stack(normed_time_attn_inputs)
+        )
+
+        return out, intermediates
 
 # video tokenizer
 
@@ -1609,12 +1627,15 @@ class VideoTokenizer(Module):
         per_image_patch_mask_prob = (0., 0.9), # probability of patch masking appears to be per image probabilities drawn uniformly between 0. and 0.9 - if you are a phd student and think i'm mistakened, please open an issue
         lpips_loss_network: Module | None = None,
         lpips_loss_weight = 0.2,
+        encoder_add_decorr_aux_loss = False,
+        decorr_aux_loss_weight = 0.1,
+        decorr_sample_frac = 0.25,
         nd_rotary_kwargs: dict = dict(
             rope_min_freq = 1.,
             rope_max_freq = 10000.,
             rope_p_zero_freqs = 0.
         ),
-        num_residual_streams = 1
+        num_residual_streams = 1,
     ):
         super().__init__()
 
@@ -1709,6 +1730,14 @@ class VideoTokenizer(Module):
         if self.has_lpips_loss:
             self.lpips = LPIPSLoss(lpips_loss_network)
 
+        # decorr aux loss
+        # https://arxiv.org/abs/2510.14657
+
+        self.encoder_add_decorr_aux_loss = encoder_add_decorr_aux_loss
+        self.decorr_aux_loss_weight = decorr_aux_loss_weight
+
+        self.decorr_loss = DecorrelationLoss(decorr_sample_frac) if encoder_add_decorr_aux_loss else None
+
     @property
     def device(self):
         return self.zero.device
@@ -1822,7 +1851,7 @@ class VideoTokenizer(Module):
 
         # encoder attention
 
-        tokens = self.encoder_transformer(tokens)
+        tokens, (_, time_attn_normed_inputs) = self.encoder_transformer(tokens, return_intermediates = True)
 
         # latent bottleneck
 
@@ -1844,17 +1873,21 @@ class VideoTokenizer(Module):
 
         if self.has_lpips_loss:
             lpips_loss = self.lpips(video, recon_video)
 
+        decorr_loss = self.zero
+        if self.encoder_add_decorr_aux_loss:
+            decorr_loss = self.decorr_loss(time_attn_normed_inputs)
         # losses
 
         total_loss = (
             recon_loss +
-            lpips_loss * self.lpips_loss_weight
+            lpips_loss * self.lpips_loss_weight +
+            decorr_loss * self.decorr_aux_loss_weight
         )
 
         if not return_all_losses:
             return total_loss
 
-        losses = (recon_loss, lpips_loss)
+        losses = (recon_loss, lpips_loss, decorr_loss)
 
         return total_loss, TokenizerLosses(*losses)
@@ -3335,7 +3368,7 @@ class DynamicsWorldModel(Module):
 
         # attention
 
-        tokens, next_time_kv_cache = self.transformer(tokens, kv_cache = time_kv_cache, return_kv_cache = True)
+        tokens, (next_time_kv_cache, _) = self.transformer(tokens, kv_cache = time_kv_cache, return_intermediates = True)
 
         # unpack
 
diff --git a/pyproject.toml b/pyproject.toml
index 2746a1f..b308aa4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.1.5"
+version = "0.1.6"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -36,7 +36,8 @@ dependencies = [
     "hyper-connections>=0.2.1",
     "torch>=2.4",
     "torchvision",
-    "x-mlps-pytorch>=0.0.29"
+    "x-mlps-pytorch>=0.0.29",
+    "vit-pytorch>=1.15.2"
 ]
 
 [project.urls]
diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py
index 15156fe..4e67818 100644
--- a/tests/test_dreamer.py
+++ b/tests/test_dreamer.py
@@ -41,7 +41,9 @@ def test_e2e(
         patch_size = 32,
         attn_dim_head = 16,
         num_latent_tokens = 4,
-        num_residual_streams = num_residual_streams
+        num_residual_streams = num_residual_streams,
+        encoder_add_decorr_aux_loss = True,
+        decorr_sample_frac = 1.
     )
 
     video = torch.randn(2, 3, 4, 256, 256)
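
Usage sketch for the new flag, mirroring the values exercised in tests/test_dreamer.py. This is a minimal sketch only: the `dim` argument and the `return_all_losses` keyword are assumed from the surrounding code and are not shown in this diff.

    import torch
    from dreamer4.dreamer4 import VideoTokenizer

    tokenizer = VideoTokenizer(
        dim = 32,                            # assumed constructor arg, not shown in this patch
        patch_size = 32,
        attn_dim_head = 16,
        num_latent_tokens = 4,
        encoder_add_decorr_aux_loss = True,  # new flag from this patch
        decorr_aux_loss_weight = 0.1,
        decorr_sample_frac = 1.
    )

    video = torch.randn(2, 3, 4, 256, 256)   # (batch, channels, time, height, width)

    # TokenizerLosses is now (recon, lpips, time_decorr)
    total_loss, (recon_loss, lpips_loss, time_decorr_loss) = tokenizer(video, return_all_losses = True)
    total_loss.backward()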
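
For intuition only: decorrelation losses of this kind typically penalize off-diagonal covariance between feature channels of the collected (normed) attention inputs. The sketch below illustrates that general idea under that assumption; it is not the DecorrelationLoss implementation from vit-pytorch or the referenced paper, and the token subsampling merely mirrors the role of the decorr_sample_frac hyperparameter.

    import torch

    def decorrelation_penalty(feats, sample_frac = 0.25):
        # feats: (..., d) normed attention inputs; flatten all leading dims into a token axis
        feats = feats.reshape(-1, feats.shape[-1])

        # randomly subsample tokens, analogous to sample_frac
        if sample_frac < 1.:
            num = max(1, int(feats.shape[0] * sample_frac))
            feats = feats[torch.randperm(feats.shape[0], device = feats.device)[:num]]

        # covariance of feature channels across the sampled tokens
        feats = feats - feats.mean(dim = 0)
        cov = (feats.t() @ feats) / max(feats.shape[0] - 1, 1)

        # penalize off-diagonal entries, encouraging decorrelated channels
        off_diag = cov - torch.diag(torch.diag(cov))
        return off_diag.pow(2).mean()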