diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index 290c409..50f5e90 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -372,7 +372,7 @@ def nonflex_block_causal_mask(seq_len, block_size, device = None):
 
 def naive_attend(
     q, k, v,
-    softclamp_value = 50.,
+    softclamp_value = None,
     scale = None,
     causal = False,
     mask = None
@@ -431,9 +431,7 @@ class Attention(Module):
         dim_head = 64,
         query_heads = None,
         heads = 8,
-        softclamp_value = 50.,
         pre_rmsnorm = True,
-        causal = False
     ):
         super().__init__()
         self.norm = RMSNorm(dim) if pre_rmsnorm else Identity()
@@ -456,23 +454,23 @@ class Attention(Module):
         self.to_kv = LinearNoBias(dim, dim_kv_inner * 2)
         self.to_out = LinearNoBias(dim_kv_inner, dim)
 
-        # masking related
-
-        self.causal = causal
-
         # stability related
 
         self.q_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = query_heads)
         self.k_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = heads)
 
-        self.softclamp_value = softclamp_value
-
     def forward(
         self,
         tokens, # (b n d)
         kv_cache = None,
         return_kv_cache = False,
-        mask = None
+        attend_fn: Callable | None = None,
+        attend_kwargs: dict = dict(
+            softclamp_value = None,
+            causal = False,
+            mask = None,
+            scale = None
+        )
     ):
         tokens, inverse_packed_batch = pack_one(tokens, '* n d')
 
@@ -498,13 +496,9 @@ class Attention(Module):
 
         # attention
 
-        out = naive_attend(
-            q, k, v,
-            softclamp_value = self.softclamp_value,
-            scale = self.scale,
-            causal = self.causal,
-            mask = mask
-        )
+        attend_fn = default(attend_fn, naive_attend)
+
+        out = attend_fn(q, k, v, **attend_kwargs)
 
         # merge heads
 
@@ -560,6 +554,7 @@ class VideoTokenizer(Module):
             dim_head = 64,
             heads = 8,
         ),
+        attn_softclamp_value = 50.,
         ff_kwargs: dict = dict(),
         decoder_pos_mlp_depth = 2,
         channels = 3,
@@ -594,6 +589,10 @@ class VideoTokenizer(Module):
             Rearrange('b t h w (p1 p2 c) -> b c t (h p1) (w p2)', p1 = patch_size, p2 = patch_size),
         )
 
+        # attention related
+
+        self.attn_softclamp_value = attn_softclamp_value
+
         # encoder
 
         encoder_layers = []
@@ -706,10 +705,14 @@ class VideoTokenizer(Module):
 
         tokens, inverse_pack_time = pack_one(tokens, 'b * d')
 
+        # attend hyper parameters
+
+        attend_kwargs = dict(softclamp_value = self.attn_softclamp_value)
+
         # encoder
 
         for attn, ff in self.encoder_layers:
-            tokens = attn(tokens) + tokens
+            tokens = attn(tokens, attend_kwargs = attend_kwargs) + tokens
             tokens = ff(tokens) + tokens
 
         tokens = self.encoder_norm(tokens)
@@ -741,7 +744,7 @@ class VideoTokenizer(Module):
         # decoder attention
 
         for attn, ff in self.decoder_layers:
-            tokens = attn(tokens) + tokens
+            tokens = attn(tokens, attend_kwargs = attend_kwargs) + tokens
             tokens = ff(tokens) + tokens
 
         tokens = self.decoder_norm(tokens)
@@ -804,6 +807,7 @@ class DynamicsModel(Module):
             dim_head = 64,
             heads = 8,
         ),
+        attn_softclamp_value = 50.,
         ff_kwargs: dict = dict(),
         loss_weight_fn: Callable = ramp_weight,
         num_future_predictions = 8 # they do multi-token prediction of 8 steps forward
@@ -842,13 +846,20 @@ class DynamicsModel(Module):
 
         self.action_learned_embed = Parameter(torch.randn(dim) * 1e-2)
 
+        # attention
+
+        self.attn_softclamp_value = attn_softclamp_value
+
         # transformer
 
         layers = []
+        is_time = []
 
         for i in range(depth):
             layer_index = i + 1
+
             is_time_block = divisible_by(layer_index, time_block_every)
+            is_time.append(is_time_block)
 
             rearrange_to_attend = Rearrange('b t s d -> b s t d') if is_time_block else Identity()
             rearrange_from_attend = Rearrange('b s t d -> b t s d') if is_time_block else Identity()
@@ -856,11 +867,12 @@ class DynamicsModel(Module):
             layers.append(ModuleList([
                 rearrange_to_attend,
                 rearrange_from_attend,
-                Attention(dim = dim, causal = is_time_block, **attn_kwargs),
+                Attention(dim = dim, **attn_kwargs),
                 SwiGLUFeedforward(dim = dim, **ff_kwargs)
             ]))
 
         self.layers = ModuleList(layers)
+        self.is_time = is_time
 
         # to prediction
 
@@ -949,14 +961,28 @@ class DynamicsModel(Module):
 
         # attention
 
-        for pre_attn_rearrange, post_attn_rearrange, attn, ff in self.layers:
+        for (pre_attn_rearrange, post_attn_rearrange, attn, ff), layer_is_time in zip(self.layers, self.is_time):
 
             tokens = pre_attn_rearrange(tokens)
 
-            tokens = attn(tokens) + tokens
+            # when it is an axial time attention block, attention should be causal
+
+            attend_kwargs = dict()
+
+            if layer_is_time:
+                attend_kwargs.update(
+                    softclamp_value = self.attn_softclamp_value,
+                    causal = True
+                )
+
+            # attention layer
+
+            tokens = attn(tokens, attend_kwargs = attend_kwargs) + tokens
 
             tokens = post_attn_rearrange(tokens)
 
+            # feedforward layer
+
+            tokens = ff(tokens) + tokens
 
         # unpack
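Usage note: the diff removes `causal` and `softclamp_value` from `Attention`'s constructor and instead lets each call site inject an `attend_fn` plus per-call `attend_kwargs` (the `softclamp_value` / `scale` / `causal` / `mask` keywords of `naive_attend`). Below is a minimal sketch of a drop-in attend function for that hook; `custom_attend` is a hypothetical name, and it assumes (batch, heads, seq, dim_head) tensors with equal query and key head counts (a grouped-query setup would need an extra key/value head expansion) and a boolean mask where True marks attendable positions.

```python
import torch
from torch import einsum

def softclamp(t, value):
    # tanh soft-clamping of attention logits for stability
    return (t / value).tanh() * value

def custom_attend(
    q, k, v,
    softclamp_value = None,
    scale = None,
    causal = False,
    mask = None
):
    # q, k, v assumed to be (batch, heads, seq, dim_head) with matching head counts
    scale = scale if scale is not None else q.shape[-1] ** -0.5

    sim = einsum('b h i d, b h j d -> b h i j', q, k) * scale

    if softclamp_value is not None:
        sim = softclamp(sim, softclamp_value)

    if mask is not None:
        # mask assumed boolean, True = may attend
        sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

    if causal:
        i, j = sim.shape[-2:]
        causal_mask = torch.ones((i, j), dtype = torch.bool, device = sim.device).triu(j - i + 1)
        sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

    attn = sim.softmax(dim = -1)
    return einsum('b h i j, b h j d -> b h i d', attn, v)

# hypothetical call site, mirroring how the tokenizer / dynamics model pass attend_kwargs:
# attn_layer(tokens, attend_fn = custom_attend, attend_kwargs = dict(causal = True, softclamp_value = 50.))
```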