add learned value residual

parent 73029635fe
commit 5e75c4029d
@@ -189,6 +189,13 @@ def sample_prob(prob):
 def is_power_two(num):
     return log2(num).is_integer()
 
+def maybe(fn):
+    def inner(t, *args, **kwargs):
+        if not exists(t) or not exists(fn):
+            return None
+        return fn(t)
+    return inner
+
 # tensor helpers
 
 def is_empty(t):
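A note on the new helper: `maybe(fn)` returns a guarded version of `fn` that short-circuits to `None` whenever the input (or the function itself) is `None`. It is used further down as `maybe(pre_attn_rearrange)(residual_values)`, so a disabled value residual flows through untouched. A minimal behavior sketch, assuming the repository's usual `exists` helper (`v is not None`):

    # standalone sketch, assuming exists() as defined in the repository
    def exists(v):
        return v is not None

    def maybe(fn):
        def inner(t, *args, **kwargs):
            if not exists(t) or not exists(fn):
                return None
            return fn(t)
        return inner

    double = maybe(lambda x: x * 2)
    assert double(3) == 6
    assert double(None) is None  # None short-circuits instead of raising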
@@ -1284,7 +1291,8 @@ class Attention(Module):
         pre_rmsnorm = True,
         gate_values = True,
         rmsnorm_query = False, # a paper claims that it is better to just norm only the keys https://openreview.net/forum?id=HkztQWZfl2
-        rmsnorm_key = True
+        rmsnorm_key = True,
+        value_residual = True
     ):
         super().__init__()
         self.norm = RMSNorm(dim) if pre_rmsnorm else Identity()
@@ -1323,6 +1331,14 @@ class Attention(Module):
         self.q_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = query_heads) if rmsnorm_query else nn.Identity()
         self.k_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = heads) if rmsnorm_key else nn.Identity()
 
+        # value residual
+
+        self.to_learned_value_residual_mix = nn.Sequential(
+            nn.Linear(dim, heads),
+            Rearrange('b n h -> b h n 1'),
+            nn.Sigmoid()
+        ) if value_residual else None
+
     def muon_parameters(self):
         # omit the queries and keys for now given what we learned from kimi 2 paper
@@ -1337,6 +1353,7 @@ class Attention(Module):
         kv_cache = None,
         return_intermediates = False,
         rotary_pos_emb = None,
+        residual_values = None, # (b n h d)
         attend_fn: Callable | None = None
     ):
         tokens, inverse_packed_batch = pack_one(tokens, '* n d')
@@ -1349,6 +1366,17 @@ class Attention(Module):
 
         q, k, v = map(self.split_heads, (q, k, v))
 
+        # handle maybe value residual
+
+        if exists(residual_values):
+            residual_values = rearrange(residual_values, '... n h d -> (...) h n d')
+
+            assert exists(self.to_learned_value_residual_mix)
+
+            learned_mix = self.to_learned_value_residual_mix(tokens)
+
+            v = v.lerp(residual_values, learned_mix)
+
         # qk rmsnorm
 
         q = self.q_heads_rmsnorm(q)
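For clarity on the mixing step: `to_learned_value_residual_mix` produces a per-head, per-token gate in [0, 1] of shape (b, h, n, 1), and `Tensor.lerp` interpolates the layer's values toward the residual values elementwise. A shape-level sketch under assumed dimensions, not the repository's code:

    import torch

    b, h, n, d = 2, 8, 16, 64

    v = torch.randn(b, h, n, d)                # values after split_heads
    residual_values = torch.randn(b, h, n, d)  # projected from the initial tokens
    mix = torch.rand(b, h, n, 1)               # sigmoid gate, broadcasts over d

    out = v.lerp(residual_values, mix)
    assert torch.allclose(out, v + mix * (residual_values - v))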
@@ -1432,6 +1460,7 @@ class AxialSpaceTimeTransformer(Module):
         self,
         dim,
         depth,
+        attn_heads = 8,
         attn_dim_head = 64,
         attn_softclamp_value = 50.,
         time_block_every = 4,
@@ -1440,7 +1469,8 @@ class AxialSpaceTimeTransformer(Module):
         num_residual_streams = 1,
         num_special_spatial_tokens = 1,
         special_attend_only_itself = False, # this is set to True for the video tokenizer decoder (latents can only attend to itself while spatial modalities attend to the latents and everything)
-        final_norm = True
+        final_norm = True,
+        value_residual = True # https://arxiv.org/abs/2410.17897 - but with learned mixing from OSS
     ):
         super().__init__()
         assert depth >= time_block_every, f'depth must be at least {time_block_every}'
@@ -1461,6 +1491,19 @@ class AxialSpaceTimeTransformer(Module):
 
         self.time_rotary = Rotary1D(attn_dim_head)
 
+        # project initial for value residuals
+
+        self.value_residual = value_residual
+
+        if value_residual:
+            dim_inner = attn_dim_head * attn_heads
+
+            self.to_value_residual = nn.Sequential(
+                nn.RMSNorm(dim),
+                nn.Linear(dim, dim_inner, bias = False),
+                Rearrange('... (h d) -> ... h d', h = attn_heads)
+            )
+
         # transformer
 
         layers = []
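The projection above turns the initial token stream into one set of per-head values, shared by every attention layer as the mixing target. A quick shape check under assumed dimensions, mirroring the Sequential added above:

    import torch
    from torch import nn
    from einops.layers.torch import Rearrange

    dim, attn_heads, attn_dim_head = 512, 8, 64

    to_value_residual = nn.Sequential(
        nn.RMSNorm(dim),
        nn.Linear(dim, attn_dim_head * attn_heads, bias = False),
        Rearrange('... (h d) -> ... h d', h = attn_heads)
    )

    tokens = torch.randn(2, 4, 16, dim)     # (batch, time, space, dim)
    print(to_value_residual(tokens).shape)  # torch.Size([2, 4, 16, 8, 64])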
@@ -1472,13 +1515,13 @@ class AxialSpaceTimeTransformer(Module):
             is_time_block = divisible_by(layer_index, time_block_every)
             is_time.append(is_time_block)
 
-            rearrange_to_attend = Rearrange('b t s d -> b s t d') if is_time_block else Identity()
-            rearrange_from_attend = Rearrange('b s t d -> b t s d') if is_time_block else Identity()
+            rearrange_to_attend = Rearrange('b t s ... -> b s t ...') if is_time_block else Identity()
+            rearrange_from_attend = Rearrange('b s t ... -> b t s ...') if is_time_block else Identity()
 
             layers.append(ModuleList([
                 rearrange_to_attend,
                 rearrange_from_attend,
-                hyper_conn(branch = Attention(dim = dim, dim_head = attn_dim_head, **attn_kwargs)),
+                hyper_conn(branch = Attention(dim = dim, heads = attn_heads, dim_head = attn_dim_head, value_residual = value_residual, **attn_kwargs)),
                 hyper_conn(branch = SwiGLUFeedforward(dim = dim, **ff_kwargs))
             ]))
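The pattern change from `b t s d` to `b t s ...` appears to be what lets the same pre-attention rearrange be reused on the residual values, which carry extra head dims. A quick sketch:

    # '...' matches both the token and residual-value trailing dims
    import torch
    from einops.layers.torch import Rearrange

    to_attend = Rearrange('b t s ... -> b s t ...')

    tokens = torch.randn(2, 4, 16, 512)      # (b t s d)   - as before
    residual = torch.randn(2, 4, 16, 8, 64)  # (b t s h d) - new in this commit
    print(to_attend(tokens).shape)           # torch.Size([2, 16, 4, 512])
    print(to_attend(residual).shape)         # torch.Size([2, 16, 4, 8, 64])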
@@ -1529,7 +1572,6 @@ class AxialSpaceTimeTransformer(Module):
 
         time_attn_kv_caches = []
 
-
         if has_kv_cache:
             past_tokens, tokens = tokens[:, :-1], tokens[:, -1:]
 
@@ -1547,6 +1589,13 @@ class AxialSpaceTimeTransformer(Module):
 
         rotary_pos_emb = self.time_rotary(rotary_seq_len, offset = rotary_pos_offset)
 
+        # value residual
+
+        residual_values = None
+
+        if self.value_residual:
+            residual_values = self.to_value_residual(tokens)
+
         # normed attention inputs
 
         normed_time_attn_inputs = []
@@ -1570,6 +1619,10 @@ class AxialSpaceTimeTransformer(Module):
 
             maybe_kv_cache = next(iter_kv_cache, None) if layer_is_time else None
 
+            # residual values
+
+            layer_residual_values = maybe(pre_attn_rearrange)(residual_values)
+
             # attention layer
 
             tokens, attn_intermediates = attn(
@@ -1577,6 +1630,7 @@ class AxialSpaceTimeTransformer(Module):
                 rotary_pos_emb = layer_rotary_pos_emb,
                 attend_fn = attend_fn,
                 kv_cache = maybe_kv_cache,
+                residual_values = layer_residual_values,
                 return_intermediates = True
             )
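Putting the pieces together: the commit follows value residual learning (https://arxiv.org/abs/2410.17897), but instead of a fixed mix it learns a per-head, per-token sigmoid gate, and it projects the initial tokens once rather than reusing the first layer's values. A minimal end-to-end sketch of that scheme; ToyValueResidualAttention is a hypothetical simplified module, not the repository's Attention class:

    import torch
    from torch import nn
    from einops import rearrange
    from einops.layers.torch import Rearrange

    class ToyValueResidualAttention(nn.Module):
        def __init__(self, dim, heads = 8, dim_head = 64):
            super().__init__()
            self.heads = heads
            dim_inner = heads * dim_head
            self.to_qkv = nn.Linear(dim, dim_inner * 3, bias = False)
            self.to_out = nn.Linear(dim_inner, dim, bias = False)

            # learned mix in [0, 1], per head and per token, as in the commit
            self.to_mix = nn.Sequential(
                nn.Linear(dim, heads),
                Rearrange('b n h -> b h n 1'),
                nn.Sigmoid()
            )

        def forward(self, x, residual_values = None):
            q, k, v = self.to_qkv(x).chunk(3, dim = -1)
            q, k, v = (rearrange(t, 'b n (h d) -> b h n d', h = self.heads) for t in (q, k, v))

            if residual_values is not None:
                # interpolate this layer's values toward the shared projection
                v = v.lerp(residual_values, self.to_mix(x))

            scores = q @ k.transpose(-2, -1) * (q.shape[-1] ** -0.5)
            out = scores.softmax(dim = -1) @ v
            return self.to_out(rearrange(out, 'b h n d -> b n (h d)'))

    dim, heads, dim_head = 512, 8, 64

    # one shared projection of the initial tokens, mirroring to_value_residual
    to_value_residual = nn.Sequential(
        nn.RMSNorm(dim),
        nn.Linear(dim, heads * dim_head, bias = False),
        Rearrange('b n (h d) -> b h n d', h = heads)
    )

    x = torch.randn(2, 16, dim)
    residual_values = to_value_residual(x)  # (b h n d), computed once

    layers = nn.ModuleList([ToyValueResidualAttention(dim, heads, dim_head) for _ in range(4)])

    for layer in layers:
        x = x + layer(x, residual_values = residual_values)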
pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.1.12"
+version = "0.1.14"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }