From c056835aea9d4587b7c2c958a87e398da205e9f0 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Wed, 8 Oct 2025 05:55:22 -0700 Subject: [PATCH] address https://github.com/lucidrains/dreamer4/issues/2 --- dreamer4/dreamer4.py | 4 ++-- pyproject.toml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index 802942e..2f1c20f 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -911,7 +911,7 @@ class VideoTokenizer(Module): # decoder attend - decoder_attend_fn = get_attend_fn(use_flex, seq_len, seq_len, special_attend_only_itself = True) + decoder_attend_fn = get_attend_fn(use_flex, seq_len, seq_len, causal = True, num_special_tokens = self.num_latent_tokens, special_attend_only_itself = True) # decoder attention @@ -1009,7 +1009,7 @@ class VideoTokenizer(Module): # modality can only attend to itself while latents can attend to everything # similar to agent token in dynamics model - encoder_attend_fn = get_attend_fn(use_flex, seq_len, seq_len, special_attend_only_itself = False) + encoder_attend_fn = get_attend_fn(use_flex, seq_len, seq_len, causal = True, num_special_tokens = self.num_latent_tokens, special_attend_only_itself = False) # encoder diff --git a/pyproject.toml b/pyproject.toml index 8e92774..58b8a7c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dreamer4" -version = "0.0.4" +version = "0.0.5" description = "Dreamer 4" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" }