From d5b70e2b866bff639c4e8383138b530246d4863a Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Mon, 10 Nov 2025 11:42:20 -0800
Subject: [PATCH] allow for adding an RNN before time attention, but need to
 handle caching still

---
 dreamer4/dreamer4.py | 30 ++++++++++++++++++++++++++++--
 pyproject.toml       |  2 +-
 2 files changed, 29 insertions(+), 3 deletions(-)

diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index e04ac5d..d6583f6 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -1470,7 +1470,8 @@ class AxialSpaceTimeTransformer(Module):
         num_special_spatial_tokens = 1,
         special_attend_only_itself = False, # this is set to True for the video tokenizer decoder (latents can only attend to itself while spatial modalities attend to the latents and everything)
         final_norm = True,
-        value_residual = True # https://arxiv.org/abs/2410.17897 - but with learned mixing from OSS
+        value_residual = True, # https://arxiv.org/abs/2410.17897 - but with learned mixing from OSS
+        rnn_time = False
     ):
         super().__init__()
         assert depth >= time_block_every, f'depth must be at least {time_block_every}'
@@ -1504,6 +1505,11 @@ class AxialSpaceTimeTransformer(Module):
             Rearrange('... (h d) -> ... h d', h = attn_heads)
         )
 
+        # a gru layer across time
+
+        self.rnn_time = rnn_time
+        rnn_layers = []
+
         # transformer
 
         layers = []
@@ -1525,7 +1531,14 @@ class AxialSpaceTimeTransformer(Module):
                 hyper_conn(branch = SwiGLUFeedforward(dim = dim, **ff_kwargs))
             ]))
 
+            rnn_layers.append(ModuleList([
+                nn.RMSNorm(dim),
+                nn.GRU(dim, dim, batch_first = True)
+            ]) if is_time_block and rnn_time else None)
+
         self.layers = ModuleList(layers)
+        self.rnn_layers = ModuleList(rnn_layers)
+
         self.is_time = is_time
 
         # final norm
@@ -1605,10 +1618,23 @@ class AxialSpaceTimeTransformer(Module):
 
         tokens = self.expand_streams(tokens)
 
-        for (pre_attn_rearrange, post_attn_rearrange, attn, ff), layer_is_time in zip(self.layers, self.is_time):
+        for (pre_attn_rearrange, post_attn_rearrange, attn, ff), maybe_rnn_modules, layer_is_time in zip(self.layers, self.rnn_layers, self.is_time):
 
             tokens = pre_attn_rearrange(tokens)
 
+            # maybe rnn for time
+
+            if layer_is_time and exists(maybe_rnn_modules):
+                rnn_prenorm, rnn = maybe_rnn_modules
+
+                rnn_input, inverse_pack_time = pack_one(tokens, '* t d')
+
+                rnn_out, rnn_hiddens = rnn(rnn_prenorm(rnn_input)) # todo, handle rnn cache
+
+                rnn_out = inverse_pack_time(rnn_out)
+
+                tokens = rnn_out + tokens
+
             # when is a axial time attention block, should be causal
 
             attend_fn = time_attend if layer_is_time else space_attend

diff --git a/pyproject.toml b/pyproject.toml
index 3267566..477e76b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.1.15"
+version = "0.1.16"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
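
For readers following along outside the diff, below is a minimal standalone sketch of the pattern the patch adds before time attention: a pre-norm `nn.GRU` run along the time axis with a residual add, with all leading (batch/spatial) dimensions folded into the batch so the GRU recurs over time only. The `TimeGRUBlock` name and the plain `reshape` are illustrative stand-ins (the patch uses einops-style `pack_one`/inverse over the leading dims), and, like the patch, this sketch does not handle the recurrent hidden-state cache yet. It assumes a PyTorch version that provides `nn.RMSNorm` (>= 2.4).

```python
import torch
from torch import nn

class TimeGRUBlock(nn.Module):
    """Pre-norm GRU over the time axis with a residual connection,
    mirroring the block the patch inserts before time attention."""

    def __init__(self, dim):
        super().__init__()
        self.norm = nn.RMSNorm(dim)
        self.gru = nn.GRU(dim, dim, batch_first = True)

    def forward(self, tokens):
        # tokens: (batch, space, time, dim) - fold space into batch so the
        # GRU runs independently per spatial position, recurring across time
        b, s, t, d = tokens.shape
        rnn_input = tokens.reshape(b * s, t, d)

        # prenorm GRU; hidden state is discarded for now (caching still todo, as in the patch)
        rnn_out, _ = self.gru(self.norm(rnn_input))

        # unfold back to (batch, space, time, dim) and add the residual
        return rnn_out.reshape(b, s, t, d) + tokens

# usage
block = TimeGRUBlock(dim = 64)
x = torch.randn(2, 16, 10, 64)   # (batch, spatial tokens, time steps, dim)
out = block(x)
assert out.shape == x.shape
```

In the actual layer stack, this block is only instantiated for the axial time-attention layers (`is_time_block`) and only when `rnn_time = True`; spatial layers get a `None` placeholder, which is why the forward pass guards on `exists(maybe_rnn_modules)`.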