given the special attention patterns, the attend function needs to be constructed before traversing the transformer layers
This commit is contained in:
parent 7cac3d28c5
commit 93f6738c9c
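The diff below threads an attend_fn / attend_kwargs pair through Attention.forward, so the module that owns the transformer layers decides the attention pattern (soft-clamping, causality, masking) up front instead of baking it into each Attention at construction time. Here is a minimal, self-contained sketch of that calling convention. The naive_attend below is a simplified stand-in that only shares the signature shown in the diff; it is not the repository's implementation, and the soft-clamp form is an assumption.

```python
import torch

def naive_attend(q, k, v, softclamp_value = None, scale = None, causal = False, mask = None):
    # simplified stand-in mirroring the signature in the diff below
    scale = scale if scale is not None else q.shape[-1] ** -0.5
    sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) * scale

    if softclamp_value is not None:
        # one common form of logit soft-clamping (assumed here, not taken from the repo)
        sim = torch.tanh(sim / softclamp_value) * softclamp_value

    if causal:
        i, j = sim.shape[-2:]
        causal_mask = torch.ones(i, j, dtype = torch.bool, device = sim.device).triu(j - i + 1)
        sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

    if mask is not None:
        # mask is a boolean keep-mask, broadcast against (batch, heads, i, j)
        sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

    attn = sim.softmax(dim = -1)
    return torch.einsum('b h i j, b h j d -> b h i d', attn, v)

# the owning module fixes the attention pattern once ...
attend_kwargs = dict(softclamp_value = 50., causal = True)

# ... and every layer's attention call receives it unchanged
q = k = v = torch.randn(1, 8, 16, 64)   # (batch, heads, seq, dim_head)
out = naive_attend(q, k, v, **attend_kwargs)
print(out.shape)   # torch.Size([1, 8, 16, 64])
```

In the diff itself, attend_fn falls back to naive_attend via default(attend_fn, naive_attend), and the soft-clamp value moves from an Attention constructor argument into the attend_kwargs supplied by VideoTokenizer and DynamicsModel.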
@@ -372,7 +372,7 @@ def nonflex_block_causal_mask(seq_len, block_size, device = None):

 def naive_attend(
     q, k, v,
-    softclamp_value = 50.,
+    softclamp_value = None,
     scale = None,
     causal = False,
     mask = None
@@ -431,9 +431,7 @@ class Attention(Module):
         dim_head = 64,
         query_heads = None,
         heads = 8,
-        softclamp_value = 50.,
         pre_rmsnorm = True,
-        causal = False
     ):
         super().__init__()
         self.norm = RMSNorm(dim) if pre_rmsnorm else Identity()
@@ -456,23 +454,23 @@ class Attention(Module):
         self.to_kv = LinearNoBias(dim, dim_kv_inner * 2)
         self.to_out = LinearNoBias(dim_kv_inner, dim)

-        # masking related
-
-        self.causal = causal
-
         # stability related

         self.q_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = query_heads)
         self.k_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = heads)

-        self.softclamp_value = softclamp_value
-
     def forward(
         self,
         tokens, # (b n d)
         kv_cache = None,
         return_kv_cache = False,
-        mask = None
+        attend_fn: Callable | None = None,
+        attend_kwargs: dict = dict(
+            softclamp_value = None,
+            causal = False,
+            mask = None,
+            scale = None
+        )
     ):
         tokens, inverse_packed_batch = pack_one(tokens, '* n d')

@@ -498,13 +496,9 @@ class Attention(Module):

         # attention

-        out = naive_attend(
-            q, k, v,
-            softclamp_value = self.softclamp_value,
-            scale = self.scale,
-            causal = self.causal,
-            mask = mask
-        )
+        attend_fn = default(attend_fn, naive_attend)
+
+        out = attend_fn(q, k, v, **attend_kwargs)

         # merge heads

@@ -560,6 +554,7 @@ class VideoTokenizer(Module):
             dim_head = 64,
             heads = 8,
         ),
+        attn_softclamp_value = 50.,
         ff_kwargs: dict = dict(),
         decoder_pos_mlp_depth = 2,
         channels = 3,
@@ -594,6 +589,10 @@ class VideoTokenizer(Module):
             Rearrange('b t h w (p1 p2 c) -> b c t (h p1) (w p2)', p1 = patch_size, p2 = patch_size),
         )

+        # attention related
+
+        self.attn_softclamp_value = attn_softclamp_value
+
         # encoder

         encoder_layers = []
@@ -706,10 +705,14 @@ class VideoTokenizer(Module):

         tokens, inverse_pack_time = pack_one(tokens, 'b * d')

+        # attend hyper parameters
+
+        attend_kwargs = dict(softclamp_value = self.attn_softclamp_value)
+
         # encoder

         for attn, ff in self.encoder_layers:
-            tokens = attn(tokens) + tokens
+            tokens = attn(tokens, attend_kwargs = attend_kwargs) + tokens
             tokens = ff(tokens) + tokens

         tokens = self.encoder_norm(tokens)
@@ -741,7 +744,7 @@ class VideoTokenizer(Module):
         # decoder attention

         for attn, ff in self.decoder_layers:
-            tokens = attn(tokens) + tokens
+            tokens = attn(tokens, attend_kwargs = attend_kwargs) + tokens
             tokens = ff(tokens) + tokens

         tokens = self.decoder_norm(tokens)
@@ -804,6 +807,7 @@ class DynamicsModel(Module):
             dim_head = 64,
             heads = 8,
         ),
+        attn_softclamp_value = 50.,
         ff_kwargs: dict = dict(),
         loss_weight_fn: Callable = ramp_weight,
         num_future_predictions = 8 # they do multi-token prediction of 8 steps forward
@@ -842,13 +846,20 @@ class DynamicsModel(Module):

         self.action_learned_embed = Parameter(torch.randn(dim) * 1e-2)

+        # attention
+
+        self.attn_softclamp_value = attn_softclamp_value
+
         # transformer

         layers = []
+        is_time = []

         for i in range(depth):
             layer_index = i + 1

             is_time_block = divisible_by(layer_index, time_block_every)
+            is_time.append(is_time_block)
+
             rearrange_to_attend = Rearrange('b t s d -> b s t d') if is_time_block else Identity()
             rearrange_from_attend = Rearrange('b s t d -> b t s d') if is_time_block else Identity()
@@ -856,11 +867,12 @@ class DynamicsModel(Module):
             layers.append(ModuleList([
                 rearrange_to_attend,
                 rearrange_from_attend,
-                Attention(dim = dim, causal = is_time_block, **attn_kwargs),
+                Attention(dim = dim, **attn_kwargs),
                 SwiGLUFeedforward(dim = dim, **ff_kwargs)
             ]))

         self.layers = ModuleList(layers)
+        self.is_time = is_time

         # to prediction

@@ -949,14 +961,28 @@ class DynamicsModel(Module):

         # attention

-        for pre_attn_rearrange, post_attn_rearrange, attn, ff in self.layers:
+        for (pre_attn_rearrange, post_attn_rearrange, attn, ff), layer_is_time in zip(self.layers, self.is_time):

             tokens = pre_attn_rearrange(tokens)

-            tokens = attn(tokens) + tokens
+            # when is a axial time attention block, should be causal
+
+            attend_kwargs = dict()
+
+            if layer_is_time:
+                attend_kwargs.update(
+                    softclamp_value = self.attn_softclamp_value,
+                    causal = True
+                )
+
+            # attention layer
+
+            tokens = attn(tokens, attend_kwargs = attend_kwargs) + tokens

             tokens = post_attn_rearrange(tokens)

+            # feedforward layer
+
             tokens = ff(tokens) + tokens

         # unpack
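The point of taking attend_fn at forward time is that a special attention pattern can be constructed once, before traversing the transformer layers, and then shared by every block, which is what the commit message refers to. The sketch below is illustrative only: block_causal_mask is a hypothetical helper written for this example, an assumed reading of what nonflex_block_causal_mask(seq_len, block_size, device = None) from the first hunk's context might produce, and it reuses the stand-in naive_attend defined in the earlier sketch rather than the repository's own attention.

```python
import torch
from functools import partial

def block_causal_mask(seq_len, block_size, device = None):
    # hypothetical keep-mask: each token attends within its own block and to all earlier blocks
    blocks = torch.arange(seq_len, device = device) // block_size
    return blocks[:, None] >= blocks[None, :]   # (seq_len, seq_len) bool

seq_len, block_size = 16, 4
mask = block_causal_mask(seq_len, block_size)

# construct the attend function once, before traversing the transformer layers ...
# (naive_attend is the stand-in defined in the earlier sketch)
attend_fn = partial(naive_attend, mask = mask, softclamp_value = 50.)

# ... then hand the same function to every layer (a stand-in for the real layer loop)
tokens = torch.randn(1, 8, seq_len, 64)   # (batch, heads, seq, dim_head)
for _ in range(3):
    tokens = attend_fn(tokens, tokens, tokens) + tokens

print(tokens.shape)   # torch.Size([1, 8, 16, 64])
```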