diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index e2e6e37..b058962 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -281,34 +281,49 @@ class GoldenGateRoPENd(Module):
 
         theta = reduce(freqs * positions, 'h n f p -> h n f', 'sum')
 
-        return theta
+        return cat((theta, theta), dim = -1)
+
+class Rotary1D(Module):
+    def __init__(
+        self,
+        dim_head,
+        theta = 10000.
+    ):
+        super().__init__()
+        inv_freq = 1.0 / (theta ** (arange(0, dim_head, 2).float() / dim_head))
+        self.register_buffer('inv_freq', inv_freq)
+
+    def forward(
+        self,
+        seq_len
+    ):
+        device, dtype = self.inv_freq.device, self.inv_freq.dtype
+
+        t = torch.arange(seq_len, device = device).type(dtype)
+        freqs = einsum(t, self.inv_freq, 'i, j -> i j')
+
+        return cat((freqs, freqs), dim = -1)
 
 def apply_rotations(
-    theta, # (h n f)
-    qk
+    rotations, # (h n d) | (n d)
+    t # (b h n d)
 ):
-    rotary_heads = theta.shape[0]
-    heads, dtype = qk.shape[1], qk
+    heads, dtype = t.shape[1], t.dtype
+    t = t.float()
 
     # handle gqa for rotary
 
-    if heads < rotary_heads:
+    if rotations.ndim > 2 and (rotary_heads := rotations.shape[0]) < heads:
         assert divisible_by(heads, rotary_heads)
         groups = heads // rotary_heads
-        theta = repeat(theta, 'h ... -> (h g) ...', g = groups)
+        rotations = repeat(rotations, 'h ... -> (h g) ...', g = groups)
 
-    x, y = rearrange(qk.float(), '... (split d) -> split ... d', split = 2) # (b, h, n, f)
+    x1, x2 = t.chunk(2, dim = -1)
+    rotated_half_t = cat((-x2, x1), dim = -1)
 
-    # apply rotations
-
-    cos_theta = torch.cos(theta)
-    sin_theta = torch.sin(theta)
-
-    x_out = x * cos_theta - y * sin_theta
-    y_out = x * sin_theta + y * cos_theta
-
-    out = rearrange([x_out, y_out], 'split ... d -> ... (split d)')
-    return out.type_as(dtype)
+    rotated = t * rotations.cos() + rotated_half_t * rotations.sin()
+    return rotated.type(dtype)
 
 # multi-head rmsnorm
 
@@ -862,9 +877,9 @@ class DynamicsModel(Module):
         pred_orig_latent = True, # directly predicting the original x0 data yield better results, rather than velocity (x-space vs v-space)
         time_block_every = 4, # every 4th block is time
         attn_kwargs: dict = dict(
-            dim_head = 64,
             heads = 8,
         ),
+        attn_dim_head = 64,
         attn_softclamp_value = 50.,
         ff_kwargs: dict = dict(),
         loss_weight_fn: Callable = ramp_weight,
@@ -908,6 +923,10 @@ class DynamicsModel(Module):
 
         self.attn_softclamp_value = attn_softclamp_value
 
+        # time rotary embedding
+
+        self.time_rotary = Rotary1D(attn_dim_head)
+
         # transformer
 
         layers = []
@@ -925,7 +944,7 @@ class DynamicsModel(Module):
             layers.append(ModuleList([
                 rearrange_to_attend,
                 rearrange_from_attend,
-                Attention(dim = dim, **attn_kwargs),
+                Attention(dim = dim, dim_head = attn_dim_head, **attn_kwargs),
                 SwiGLUFeedforward(dim = dim, **ff_kwargs)
             ]))
 
@@ -964,6 +983,8 @@ class DynamicsModel(Module):
 
         latents = self.video_tokenizer.tokenize(video)
 
+        time = latents.shape[1]
+
         # flow related
 
         assert not (exists(signal_levels) ^ exists(step_sizes))
@@ -1025,6 +1046,10 @@ class DynamicsModel(Module):
 
         time_attend = partial(naive_attend, causal = True, **attend_kwargs)
 
+        # rotary
+
+        rotary_pos_emb = self.time_rotary(time)
+
         # attention
 
         for (pre_attn_rearrange, post_attn_rearrange, attn, ff), layer_is_time in zip(self.layers, self.is_time):
@@ -1035,9 +1060,11 @@ class DynamicsModel(Module):
 
             attend_fn = time_attend if layer_is_time else space_attend
 
+            layer_rotary_pos_emb = rotary_pos_emb if layer_is_time else None
+
             # attention layer
 
-            tokens = attn(tokens, attend_fn = attend_fn) + tokens
+            tokens = attn(tokens, rotary_pos_emb = layer_rotary_pos_emb, attend_fn = attend_fn) + tokens
 
             tokens = post_attn_rearrange(tokens)
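
Review note: below is a minimal standalone sketch of the rotate-half rotary scheme this diff switches to (cosine/sine applied across the full head dimension, with the second half rotated into the first), assuming plain PyTorch and the (b h n d) layout used above. The helper name `rotate_half` and the sample shapes are illustrative, not taken from the repo.

    import torch

    def rotate_half(t):
        # split the head dim in half and rotate: (x1, x2) -> (-x2, x1)
        x1, x2 = t.chunk(2, dim = -1)
        return torch.cat((-x2, x1), dim = -1)

    dim_head, seq_len, theta = 64, 16, 10000.

    # same inverse frequencies Rotary1D registers as a buffer
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim_head, 2).float() / dim_head))

    # (n, d / 2) angles, duplicated to span the full head dim, as in Rotary1D.forward
    freqs = torch.einsum('i,j->ij', torch.arange(seq_len).float(), inv_freq)
    rotations = torch.cat((freqs, freqs), dim = -1)  # (n, d)

    q = torch.randn(2, 8, seq_len, dim_head)         # (b, h, n, d)

    # the core of apply_rotations: t * cos + rotate_half(t) * sin
    q_rotated = q * rotations.cos() + rotate_half(q) * rotations.sin()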