add learned value residual

lucidrains 2025-11-10 09:16:29 -08:00
parent 73029635fe
commit 5e75c4029d
2 changed files with 61 additions and 7 deletions


@@ -189,6 +189,13 @@ def sample_prob(prob):
 def is_power_two(num):
     return log2(num).is_integer()
 
+def maybe(fn):
+    def inner(t, *args, **kwargs):
+        if not exists(t) or not exists(fn):
+            return None
+        return fn(t)
+    return inner
+
 # tensor helpers
 
 def is_empty(t):
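
For context, `maybe` wraps a function so that a `None` input passes straight through as `None` instead of being called on; the transformer's forward pass below uses it to apply a rearrange to the residual values only when they exist. A minimal standalone sketch of the behavior (`exists` is stubbed in here, and `double` is purely illustrative):

def exists(v):
    return v is not None

def maybe(fn):
    def inner(t, *args, **kwargs):
        if not exists(t) or not exists(fn):
            return None
        return fn(t)
    return inner

double = maybe(lambda x: x * 2)
assert double(3) == 6
assert double(None) is None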
@@ -1284,7 +1291,8 @@ class Attention(Module):
         pre_rmsnorm = True,
         gate_values = True,
         rmsnorm_query = False, # a paper claims that it is better to just norm only the keys https://openreview.net/forum?id=HkztQWZfl2
-        rmsnorm_key = True
+        rmsnorm_key = True,
+        value_residual = True
     ):
         super().__init__()
         self.norm = RMSNorm(dim) if pre_rmsnorm else Identity()
@@ -1323,6 +1331,14 @@ class Attention(Module):
         self.q_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = query_heads) if rmsnorm_query else nn.Identity()
         self.k_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = heads) if rmsnorm_key else nn.Identity()
 
+        # value residual
+
+        self.to_learned_value_residual_mix = nn.Sequential(
+            nn.Linear(dim, heads),
+            Rearrange('b n h -> b h n 1'),
+            nn.Sigmoid()
+        ) if value_residual else None
+
     def muon_parameters(self):
         # omit the queries and keys for now given what we learned from kimi 2 paper
@@ -1337,6 +1353,7 @@ class Attention(Module):
         kv_cache = None,
         return_intermediates = False,
         rotary_pos_emb = None,
+        residual_values = None, # (b n h d)
         attend_fn: Callable | None = None
     ):
         tokens, inverse_packed_batch = pack_one(tokens, '* n d')
@@ -1349,6 +1366,17 @@ class Attention(Module):
         q, k, v = map(self.split_heads, (q, k, v))
 
+        # handle maybe value residual
+
+        if exists(residual_values):
+            residual_values = rearrange(residual_values, '... n h d -> (...) h n d')
+
+            assert exists(self.to_learned_value_residual_mix)
+
+            learned_mix = self.to_learned_value_residual_mix(tokens)
+            v = v.lerp(residual_values, learned_mix)
+
         # qk rmsnorm
 
         q = self.q_heads_rmsnorm(q)
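
Taken together with the module defined in __init__, the mixing works as in the following self-contained sketch; the shapes follow the hunk above, but the dimensions and tensors are made up for illustration. Each head at each position predicts a gate in (0, 1) from the pre-attention tokens and linearly interpolates between its own values and the residual values:

import torch
from torch import nn
from einops.layers.torch import Rearrange

b, n, heads, dim, dim_head = 2, 16, 8, 512, 64

to_learned_value_residual_mix = nn.Sequential(
    nn.Linear(dim, heads),
    Rearrange('b n h -> b h n 1'),  # one gate per head and position
    nn.Sigmoid()
)

tokens = torch.randn(b, n, dim)                       # pre-attention tokens
v = torch.randn(b, heads, n, dim_head)                # this layer's values
residual_values = torch.randn(b, heads, n, dim_head)  # values projected from the input

mix = to_learned_value_residual_mix(tokens)  # (b, h, n, 1), in (0, 1)
v = v.lerp(residual_values, mix)             # v * (1 - mix) + residual_values * mix

At mix = 0 a head keeps its own values entirely, so the network can learn to ignore the residual wherever it does not help.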
@@ -1432,6 +1460,7 @@ class AxialSpaceTimeTransformer(Module):
         self,
         dim,
         depth,
+        attn_heads = 8,
         attn_dim_head = 64,
         attn_softclamp_value = 50.,
         time_block_every = 4,
@@ -1440,7 +1469,8 @@ class AxialSpaceTimeTransformer(Module):
         num_residual_streams = 1,
         num_special_spatial_tokens = 1,
         special_attend_only_itself = False, # this is set to True for the video tokenizer decoder (latents can only attend to itself while spatial modalities attend to the latents and everything)
-        final_norm = True
+        final_norm = True,
+        value_residual = True # https://arxiv.org/abs/2410.17897 - but with learned mixing from OSS
     ):
         super().__init__()
         assert depth >= time_block_every, f'depth must be at least {time_block_every}'
@@ -1461,6 +1491,19 @@ class AxialSpaceTimeTransformer(Module):
         self.time_rotary = Rotary1D(attn_dim_head)
 
+        # project initial for value residuals
+
+        self.value_residual = value_residual
+
+        if value_residual:
+            dim_inner = attn_dim_head * attn_heads
+
+            self.to_value_residual = nn.Sequential(
+                nn.RMSNorm(dim),
+                nn.Linear(dim, dim_inner, bias = False),
+                Rearrange('... (h d) -> ... h d', h = attn_heads)
+            )
+
         # transformer
 
         layers = []
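
So the transformer projects the input tokens once into the same per-head shape the attention values use, and every layer draws on that single projection. A rough shape check under the defaults above (attn_heads = 8, attn_dim_head = 64; dim and the token shape here are made up, and nn.RMSNorm assumes a recent torch):

import torch
from torch import nn
from einops.layers.torch import Rearrange

dim, attn_heads, attn_dim_head = 512, 8, 64
dim_inner = attn_dim_head * attn_heads

to_value_residual = nn.Sequential(
    nn.RMSNorm(dim),
    nn.Linear(dim, dim_inner, bias = False),
    Rearrange('... (h d) -> ... h d', h = attn_heads)
)

tokens = torch.randn(2, 4, 16, dim)          # (batch, time, space, dim)
residual_values = to_value_residual(tokens)  # (batch, time, space, heads, dim_head)
assert residual_values.shape == (2, 4, 16, attn_heads, attn_dim_head)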
@@ -1472,13 +1515,13 @@ class AxialSpaceTimeTransformer(Module):
             is_time_block = divisible_by(layer_index, time_block_every)
             is_time.append(is_time_block)
 
-            rearrange_to_attend = Rearrange('b t s d -> b s t d') if is_time_block else Identity()
-            rearrange_from_attend = Rearrange('b s t d -> b t s d') if is_time_block else Identity()
+            rearrange_to_attend = Rearrange('b t s ... -> b s t ...') if is_time_block else Identity()
+            rearrange_from_attend = Rearrange('b s t ... -> b t s ...') if is_time_block else Identity()
 
             layers.append(ModuleList([
                 rearrange_to_attend,
                 rearrange_from_attend,
-                hyper_conn(branch = Attention(dim = dim, dim_head = attn_dim_head, **attn_kwargs)),
+                hyper_conn(branch = Attention(dim = dim, heads = attn_heads, dim_head = attn_dim_head, value_residual = value_residual, **attn_kwargs)),
                 hyper_conn(branch = SwiGLUFeedforward(dim = dim, **ff_kwargs))
             ]))
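
The rearranges switch from a fixed trailing d to an ellipsis because they now have to transpose two kinds of tensors: the token stream (b t s d) and the residual values, which carry an extra heads axis (b t s h d). A quick demonstration that one pattern covers both (shapes are illustrative):

import torch
from einops.layers.torch import Rearrange

rearrange_to_attend = Rearrange('b t s ... -> b s t ...')

tokens = torch.randn(2, 4, 16, 512)             # (batch, time, space, dim)
residual_values = torch.randn(2, 4, 16, 8, 64)  # (batch, time, space, heads, dim_head)

assert rearrange_to_attend(tokens).shape == (2, 16, 4, 512)
assert rearrange_to_attend(residual_values).shape == (2, 16, 4, 8, 64)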
@@ -1529,7 +1572,6 @@ class AxialSpaceTimeTransformer(Module):
         time_attn_kv_caches = []
 
-
         if has_kv_cache:
             past_tokens, tokens = tokens[:, :-1], tokens[:, -1:]
@@ -1547,6 +1589,13 @@ class AxialSpaceTimeTransformer(Module):
         rotary_pos_emb = self.time_rotary(rotary_seq_len, offset = rotary_pos_offset)
 
+        # value residual
+
+        residual_values = None
+
+        if self.value_residual:
+            residual_values = self.to_value_residual(tokens)
+
         # normed attention inputs
 
         normed_time_attn_inputs = []
@@ -1570,6 +1619,10 @@ class AxialSpaceTimeTransformer(Module):
             maybe_kv_cache = next(iter_kv_cache, None) if layer_is_time else None
 
+            # residual values
+
+            layer_residual_values = maybe(pre_attn_rearrange)(residual_values)
+
             # attention layer
 
             tokens, attn_intermediates = attn(
@@ -1577,6 +1630,7 @@ class AxialSpaceTimeTransformer(Module):
                 rotary_pos_emb = layer_rotary_pos_emb,
                 attend_fn = attend_fn,
                 kv_cache = maybe_kv_cache,
+                residual_values = layer_residual_values,
                 return_intermediates = True
             )
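
End to end: the forward pass computes the residual values once from the input, rearranges them per layer to match whether the block attends over time or space, and each attention layer gates them into its own values. The toy below compresses the whole mechanism into a single-head, non-axial version so it stays runnable; none of these class or variable names exist in the repo:

import torch
from torch import nn

dim = 32

class ToyValueResidualAttention(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
        self.to_mix = nn.Sequential(nn.Linear(dim, 1), nn.Sigmoid())

    def forward(self, tokens, residual_values = None):
        q, k, v = self.to_qkv(tokens).chunk(3, dim = -1)

        if residual_values is not None:
            v = v.lerp(residual_values, self.to_mix(tokens))  # learned per-position mix

        attn = (q @ k.transpose(-2, -1) * q.shape[-1] ** -0.5).softmax(dim = -1)
        return attn @ v

to_value_residual = nn.Sequential(nn.RMSNorm(dim), nn.Linear(dim, dim, bias = False))
layers = nn.ModuleList([ToyValueResidualAttention(dim) for _ in range(3)])

tokens = torch.randn(2, 16, dim)
residual_values = to_value_residual(tokens)  # computed once, from the input

for layer in layers:
    tokens = tokens + layer(tokens, residual_values = residual_values)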

pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.1.12"
+version = "0.1.14"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }