rnn layer needs to be hyper connected too

lucidrains 2025-11-10 15:51:33 -08:00
parent d5b70e2b86
commit 3c84b404a8
2 changed files with 8 additions and 14 deletions

View File

@@ -1471,7 +1471,7 @@ class AxialSpaceTimeTransformer(Module):
         special_attend_only_itself = False, # this is set to True for the video tokenizer decoder (latents can only attend to itself while spatial modalities attend to the latents and everything)
         final_norm = True,
         value_residual = True, # https://arxiv.org/abs/2410.17897 - but with learned mixing from OSS
-        rnn_time = False
+        rnn_time = True
     ):
         super().__init__()
         assert depth >= time_block_every, f'depth must be at least {time_block_every}'
@@ -1531,10 +1531,7 @@ class AxialSpaceTimeTransformer(Module):
                 hyper_conn(branch = SwiGLUFeedforward(dim = dim, **ff_kwargs))
             ]))
 
-            rnn_layers.append(ModuleList([
-                nn.RMSNorm(dim),
-                nn.GRU(dim, dim, batch_first = True)
-            ]) if is_time_block and rnn_time else None)
+            rnn_layers.append(hyper_conn(branch = nn.Sequential(nn.RMSNorm(dim), nn.GRU(dim, dim, batch_first = True))) if is_time_block and rnn_time else None)
 
         self.layers = ModuleList(layers)
         self.rnn_layers = ModuleList(rnn_layers)
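
Note: the constructor change above swaps the hand-rolled ModuleList([RMSNorm, GRU]) pair for a branch wrapped by hyper_conn, the same hyper-connections wrapper already used for the attention and feedforward branches. Below is a minimal, hedged sketch of that registration pattern, assuming the hyper-connections package's get_init_and_expand_reduce_stream_functions API as shown in its README; the stream count, dim, and tensor sizes are illustrative and not taken from this repo.

import torch
from torch import nn
from hyper_connections import get_init_and_expand_reduce_stream_functions

dim = 64

# assumption: 4 residual streams, the default shown in the hyper-connections README
init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(4)

# the (pre-norm -> GRU) pair becomes a branch; the residual mixing is learned by the wrapper
rnn_layer = init_hyper_conn(
    dim = dim,
    branch = nn.Sequential(nn.RMSNorm(dim), nn.GRU(dim, dim, batch_first = True))
)

tokens = torch.randn(2, 16, dim)          # (batch, time, dim)

tokens = expand_stream(tokens)            # replicate tokens across the residual streams
tokens, rnn_hiddens = rnn_layer(tokens)   # extra branch outputs (the GRU hidden state) are passed through, as in the diff above
tokens = reduce_stream(tokens)            # collapse the streams back to (batch, time, dim)
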
@@ -1618,22 +1615,19 @@ class AxialSpaceTimeTransformer(Module):
         tokens = self.expand_streams(tokens)
 
-        for (pre_attn_rearrange, post_attn_rearrange, attn, ff), maybe_rnn_modules, layer_is_time in zip(self.layers, self.rnn_layers, self.is_time):
+        for (pre_attn_rearrange, post_attn_rearrange, attn, ff), maybe_rnn, layer_is_time in zip(self.layers, self.rnn_layers, self.is_time):
 
             tokens = pre_attn_rearrange(tokens)
 
             # maybe rnn for time
 
-            if layer_is_time and exists(maybe_rnn_modules):
-                rnn_prenorm, rnn = maybe_rnn_modules
+            if layer_is_time and exists(maybe_rnn):
 
-                rnn_input, inverse_pack_time = pack_one(tokens, '* t d')
+                tokens, inverse_pack_batch = pack_one(tokens, '* t d')
 
-                rnn_out, rnn_hiddens = rnn(rnn_prenorm(rnn_input)) # todo, handle rnn cache
+                tokens, rnn_hiddens = maybe_rnn(tokens) # todo, handle rnn cache
 
-                rnn_out = inverse_pack_time(rnn_out)
-                tokens = rnn_out + tokens
+                tokens = inverse_pack_batch(tokens)
 
             # when is a axial time attention block, should be causal
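
The forward path above now leans on the wrapper for the pre-norm and the residual add, so the loop only has to fold the leading batch/space dims into one batch dim for the GRU and restore them afterwards. Here is a self-contained sketch of that flow under stated assumptions: ResidualBranch is a simplified single-stream stand-in for hyper_conn (not the repo's wrapper), and einops pack/unpack stand in for the repo's pack_one helper.

import torch
from torch import nn
from einops import pack, unpack

class ResidualBranch(nn.Module):
    # simplified single-stream stand-in for the hyper connection wrapper
    def __init__(self, branch):
        super().__init__()
        self.branch = branch

    def forward(self, x):
        out = self.branch(x)

        # nn.GRU returns (output, hidden): add the residual to the output,
        # pass any extra returns (the hidden state) through untouched
        if isinstance(out, tuple):
            main, *rest = out
            return (x + main, *rest)

        return x + out

dim = 64

rnn_layer = ResidualBranch(nn.Sequential(nn.RMSNorm(dim), nn.GRU(dim, dim, batch_first = True)))

tokens = torch.randn(2, 8, 16, dim)              # (batch, space, time, dim)

packed, packed_shape = pack([tokens], '* t d')   # fold batch and space dims together for the GRU

packed, rnn_hiddens = rnn_layer(packed)          # pre-norm, GRU, and residual add all happen inside the wrapper

[tokens] = unpack(packed, packed_shape, '* t d') # restore (batch, space, time, dim)

assert tokens.shape == (2, 8, 16, dim)
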

View File

@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.1.16"
+version = "0.1.17"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }