diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index ac73ecf..10d953b 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -189,6 +189,13 @@ def sample_prob(prob):
 def is_power_two(num):
     return log2(num).is_integer()
 
+def maybe(fn):
+    def inner(t, *args, **kwargs):
+        if not exists(t) or not exists(fn):
+            return None
+        return fn(t, *args, **kwargs)
+    return inner
+
 # tensor helpers
 
 def is_empty(t):
@@ -1284,7 +1291,8 @@ class Attention(Module):
         pre_rmsnorm = True,
         gate_values = True,
         rmsnorm_query = False, # a paper claims that it is better to just norm only the keys https://openreview.net/forum?id=HkztQWZfl2
-        rmsnorm_key = True
+        rmsnorm_key = True,
+        value_residual = True
     ):
         super().__init__()
         self.norm = RMSNorm(dim) if pre_rmsnorm else Identity()
@@ -1323,6 +1331,14 @@ class Attention(Module):
 
         self.q_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = query_heads) if rmsnorm_query else nn.Identity()
         self.k_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = heads) if rmsnorm_key else nn.Identity()
 
+        # value residual
+
+        self.to_learned_value_residual_mix = nn.Sequential(
+            nn.Linear(dim, heads),
+            Rearrange('b n h -> b h n 1'),
+            nn.Sigmoid()
+        ) if value_residual else None
+
     def muon_parameters(self):
         # omit the queries and keys for now given what we learned from kimi 2 paper
@@ -1337,6 +1353,7 @@ class Attention(Module):
         kv_cache = None,
         return_intermediates = False,
         rotary_pos_emb = None,
+        residual_values = None, # (b n h d)
         attend_fn: Callable | None = None
     ):
         tokens, inverse_packed_batch = pack_one(tokens, '* n d')
@@ -1349,6 +1366,17 @@ class Attention(Module):
 
         q, k, v = map(self.split_heads, (q, k, v))
 
+        # handle maybe value residual
+
+        if exists(residual_values):
+            residual_values = rearrange(residual_values, '... n h d -> (...) h n d')
+
+            assert exists(self.to_learned_value_residual_mix)
+
+            learned_mix = self.to_learned_value_residual_mix(tokens)
+
+            v = v.lerp(residual_values, learned_mix)
+
         # qk rmsnorm
 
         q = self.q_heads_rmsnorm(q)
@@ -1432,6 +1460,7 @@ class AxialSpaceTimeTransformer(Module):
         self,
         dim,
         depth,
+        attn_heads = 8,
         attn_dim_head = 64,
         attn_softclamp_value = 50.,
         time_block_every = 4,
@@ -1440,7 +1469,8 @@ class AxialSpaceTimeTransformer(Module):
         num_residual_streams = 1,
         num_special_spatial_tokens = 1,
         special_attend_only_itself = False, # this is set to True for the video tokenizer decoder (latents can only attend to itself while spatial modalities attend to the latents and everything)
-        final_norm = True
+        final_norm = True,
+        value_residual = True # https://arxiv.org/abs/2410.17897 - but with learned mixing from OSS
     ):
         super().__init__()
         assert depth >= time_block_every, f'depth must be at least {time_block_every}'
@@ -1461,6 +1491,19 @@ class AxialSpaceTimeTransformer(Module):
 
         self.time_rotary = Rotary1D(attn_dim_head)
 
+        # project initial tokens for value residuals
+
+        self.value_residual = value_residual
+
+        if value_residual:
+            dim_inner = attn_dim_head * attn_heads
+
+            self.to_value_residual = nn.Sequential(
+                nn.RMSNorm(dim),
+                nn.Linear(dim, dim_inner, bias = False),
+                Rearrange('... (h d) -> ... h d', h = attn_heads)
+            )
+
         # transformer
 
         layers = []
@@ -1472,13 +1515,13 @@ class AxialSpaceTimeTransformer(Module):
             is_time_block = divisible_by(layer_index, time_block_every)
             is_time.append(is_time_block)
 
-            rearrange_to_attend = Rearrange('b t s d -> b s t d') if is_time_block else Identity()
-            rearrange_from_attend = Rearrange('b s t d -> b t s d') if is_time_block else Identity()
+            rearrange_to_attend = Rearrange('b t s ... -> b s t ...') if is_time_block else Identity()
+            rearrange_from_attend = Rearrange('b s t ... -> b t s ...') if is_time_block else Identity()
 
             layers.append(ModuleList([
                 rearrange_to_attend,
                 rearrange_from_attend,
-                hyper_conn(branch = Attention(dim = dim, dim_head = attn_dim_head, **attn_kwargs)),
+                hyper_conn(branch = Attention(dim = dim, heads = attn_heads, dim_head = attn_dim_head, value_residual = value_residual, **attn_kwargs)),
                 hyper_conn(branch = SwiGLUFeedforward(dim = dim, **ff_kwargs))
             ]))
 
@@ -1529,7 +1572,6 @@ class AxialSpaceTimeTransformer(Module):
 
         time_attn_kv_caches = []
 
-
         if has_kv_cache:
             past_tokens, tokens = tokens[:, :-1], tokens[:, -1:]
 
@@ -1547,6 +1589,13 @@ class AxialSpaceTimeTransformer(Module):
 
         rotary_pos_emb = self.time_rotary(rotary_seq_len, offset = rotary_pos_offset)
 
+        # value residual
+
+        residual_values = None
+
+        if self.value_residual:
+            residual_values = self.to_value_residual(tokens)
+
         # normed attention inputs
 
         normed_time_attn_inputs = []
@@ -1570,6 +1619,10 @@ class AxialSpaceTimeTransformer(Module):
 
             maybe_kv_cache = next(iter_kv_cache, None) if layer_is_time else None
 
+            # residual values
+
+            layer_residual_values = maybe(pre_attn_rearrange)(residual_values)
+
             # attention layer
 
             tokens, attn_intermediates = attn(
@@ -1577,6 +1630,7 @@ class AxialSpaceTimeTransformer(Module):
                 tokens,
                 rotary_pos_emb = layer_rotary_pos_emb,
                 attend_fn = attend_fn,
                 kv_cache = maybe_kv_cache,
+                residual_values = layer_residual_values,
                 return_intermediates = True
             )
diff --git a/pyproject.toml b/pyproject.toml
index b4ed20b..96c5e24 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.1.12"
+version = "0.1.14"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
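
For reference, a minimal standalone sketch of the learned value residual mixing the diff wires in (value residual learning, https://arxiv.org/abs/2410.17897, here with a learned mix): values are projected once from the initial token embeddings, and each attention layer lerps its own values toward that shared projection with a per-head, per-position sigmoid gate. The shapes and module names below are illustrative assumptions mirroring the diff's logic with plain torch ops in place of the repo's einops layers; nn.RMSNorm requires PyTorch >= 2.4.

import torch
from torch import nn

batch, seq_len, dim, heads, dim_head = 2, 16, 512, 8, 64

tokens = torch.randn(batch, seq_len, dim)

# computed once from the initial tokens, then shared by every layer
to_value_residual = nn.Sequential(
    nn.RMSNorm(dim),
    nn.Linear(dim, heads * dim_head, bias = False)
)

residual_values = to_value_residual(tokens)
residual_values = residual_values.reshape(batch, seq_len, heads, dim_head).transpose(1, 2) # (b h n d)

# inside each attention layer: a per-head, per-position gate in [0, 1]
to_mix = nn.Sequential(nn.Linear(dim, heads), nn.Sigmoid())

mix = to_mix(tokens).transpose(1, 2).unsqueeze(-1) # (b n h) -> (b h n 1)

v = torch.randn(batch, heads, seq_len, dim_head)   # stand-in for this layer's values

# v.lerp(end, weight) = v * (1 - weight) + end * weight
v = v.lerp(residual_values, mix)

assert v.shape == (batch, heads, seq_len, dim_head)

A gate of 0 keeps the layer's own values; a gate of 1 substitutes the first projection entirely, which is why the diff can enable value_residual by default without changing what the network is able to express.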