From 2ccb290e26f45d653e100723d6b9ed0351b1b1d6 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Thu, 16 Oct 2025 08:33:26 -0700 Subject: [PATCH] pass the attend kwargs for the block causal masking in tokenizer --- dreamer4/dreamer4.py | 32 ++++++++++++++++++++++---------- pyproject.toml | 2 +- 2 files changed, 23 insertions(+), 11 deletions(-) diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index 3727f64..062af52 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -1278,6 +1278,8 @@ class VideoTokenizer(Module): tokens, packed_latent_shape = pack((decoder_pos_emb, latent_tokens), 'b t * d') + space_seq_len = tokens.shape[-2] + # pack time tokens, inverse_pack_time = pack_one(tokens, 'b * d') @@ -1286,7 +1288,16 @@ class VideoTokenizer(Module): # decoder attend - decoder_attend_fn = get_attend_fn(use_flex, seq_len, seq_len, causal = True, num_special_tokens = self.num_latent_tokens, special_attend_only_itself = True) + decoder_attend_fn = get_attend_fn( + use_flex, + seq_len, seq_len, + causal = True, + causal_block_size = space_seq_len, + softclamp_value = self.attn_softclamp_value, + block_size_per_special = space_seq_len, + num_special_tokens = self.num_latent_tokens, + special_attend_only_itself = True # different than encoder + ) # decoder attention @@ -1373,14 +1384,6 @@ class VideoTokenizer(Module): # attend hyper parameters - attend_kwargs = dict( - causal = True, - causal_block_size = space_seq_len, - softclamp_value = self.attn_softclamp_value, - block_size_per_special = space_seq_len, - num_special_tokens = 1 - ) - use_flex = tokens.is_cuda and exists(flex_attention) # encoder attend @@ -1388,7 +1391,16 @@ class VideoTokenizer(Module): # modality can only attend to itself while latents can attend to everything # similar to agent token in dynamics model - encoder_attend_fn = get_attend_fn(use_flex, seq_len, seq_len, causal = True, num_special_tokens = self.num_latent_tokens, special_attend_only_itself = False) + encoder_attend_fn = get_attend_fn( + use_flex, + seq_len, seq_len, + causal = True, + causal_block_size = space_seq_len, + softclamp_value = self.attn_softclamp_value, + block_size_per_special = space_seq_len, + num_special_tokens = self.num_latent_tokens, + special_attend_only_itself = False # different than decoder + ) # encoder diff --git a/pyproject.toml b/pyproject.toml index 299d863..32e4200 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dreamer4" -version = "0.0.22" +version = "0.0.23" description = "Dreamer 4" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" }