allow for adding an RNN before time attention, but need to handle caching still

parent c3532fa797
commit d5b70e2b86
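For orientation: the change wires an optional pre-norm GRU over the time axis in front of each axial time-attention layer, with its output added back residually. Below is a minimal standalone sketch of that pattern under assumed names (TimeGRUBlock is hypothetical, and einops pack/unpack stand in for the repo's pack_one helper); it is a sketch of the idea, not the library's API.

    # minimal sketch of the pattern this commit adds: a pre-norm GRU over time with a
    # residual add, meant to run only on the layers that also do axial time attention.
    # names are hypothetical, not the repo's API

    import torch
    from torch import nn
    from einops import pack, unpack

    class TimeGRUBlock(nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.norm = nn.RMSNorm(dim)
            self.gru = nn.GRU(dim, dim, batch_first = True)

        def forward(self, tokens, hiddens = None):
            # tokens: (..., time, dim) - fold all leading dims into the batch for the GRU
            rnn_input, packed_shape = pack([tokens], '* t d')

            rnn_out, next_hiddens = self.gru(self.norm(rnn_input), hiddens)

            rnn_out, = unpack(rnn_out, packed_shape, '* t d')

            # residual add as in the diff; next_hiddens is what cache handling would later store
            return rnn_out + tokens, next_hiddens

    block = TimeGRUBlock(512)
    out, h = block(torch.randn(2, 64, 16, 512)) # (batch, space, time, dim)

Returning the hidden state alongside the tokens is the piece the "handle caching" part of the commit title still needs.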
@@ -1470,7 +1470,8 @@ class AxialSpaceTimeTransformer(Module):
         num_special_spatial_tokens = 1,
         special_attend_only_itself = False, # this is set to True for the video tokenizer decoder (latents can only attend to itself while spatial modalities attend to the latents and everything)
         final_norm = True,
-        value_residual = True # https://arxiv.org/abs/2410.17897 - but with learned mixing from OSS
+        value_residual = True, # https://arxiv.org/abs/2410.17897 - but with learned mixing from OSS
+        rnn_time = False
     ):
         super().__init__()
         assert depth >= time_block_every, f'depth must be at least {time_block_every}'
@@ -1504,6 +1505,11 @@
             Rearrange('... (h d) -> ... h d', h = attn_heads)
         )

+        # a gru layer across time
+
+        self.rnn_time = rnn_time
+        rnn_layers = []
+
         # transformer

         layers = []
@@ -1525,7 +1531,14 @@
                 hyper_conn(branch = SwiGLUFeedforward(dim = dim, **ff_kwargs))
             ]))
+
+            rnn_layers.append(ModuleList([
+                nn.RMSNorm(dim),
+                nn.GRU(dim, dim, batch_first = True)
+            ]) if is_time_block and rnn_time else None)
+
         self.layers = ModuleList(layers)
+        self.rnn_layers = ModuleList(rnn_layers)

         self.is_time = is_time

         # final norm
@@ -1605,10 +1618,23 @@

         tokens = self.expand_streams(tokens)

-        for (pre_attn_rearrange, post_attn_rearrange, attn, ff), layer_is_time in zip(self.layers, self.is_time):
+        for (pre_attn_rearrange, post_attn_rearrange, attn, ff), maybe_rnn_modules, layer_is_time in zip(self.layers, self.rnn_layers, self.is_time):

             tokens = pre_attn_rearrange(tokens)

+            # maybe rnn for time
+
+            if layer_is_time and exists(maybe_rnn_modules):
+                rnn_prenorm, rnn = maybe_rnn_modules
+
+                rnn_input, inverse_pack_time = pack_one(tokens, '* t d')
+
+                rnn_out, rnn_hiddens = rnn(rnn_prenorm(rnn_input)) # todo, handle rnn cache
+
+                rnn_out = inverse_pack_time(rnn_out)
+
+                tokens = rnn_out + tokens
+
             # when is a axial time attention block, should be causal

             attend_fn = time_attend if layer_is_time else space_attend
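The "# todo, handle rnn cache" comment above is the open item from the commit title. A hedged sketch of one way it could go, assuming nothing about the repo's eventual cache API: keep the GRU's final hidden state per rnn layer and feed it back in when decoding one new time step, much like an attention kv cache.

    # sketch of the pending cache handling (assumed approach, not the repo's):
    # reuse nn.GRU's hidden state so an incremental step only sees the newest time step

    import torch
    from torch import nn

    dim = 512
    rnn = nn.GRU(dim, dim, batch_first = True)

    # prefill: run the full sequence once, keep the final hidden state as the cache
    tokens = torch.randn(2, 16, dim)              # (batch * space, time, dim) after pack_one
    out, hiddens = rnn(tokens)                    # hiddens: (1, batch * space, dim)

    # incremental step: feed only the newest time step along with the cached state
    next_token = torch.randn(2, 1, dim)
    step_out, hiddens = rnn(next_token, hiddens)  # matches rerunning the full prefix

    assert step_out.shape == (2, 1, dim)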

@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.1.15"
+version = "0.1.16"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }