diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index aef87f5..a42861a 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -440,6 +440,7 @@ class VideoTokenizer(Module):
             heads = 8,
         ),
         ff_kwargs: dict = dict(),
+        decoder_pos_mlp_depth = 2,
         channels = 3,
         per_image_patch_mask_prob = (0., 0.9), # probability of patch masking appears to be per image probabilities drawn uniformly between 0. and 0.9 - if you are a phd student and think i'm mistakened, please open an issue
     ):
@@ -494,6 +495,15 @@ class VideoTokenizer(Module):
 
         # decoder
 
+        # parameterize the decoder positional embeddings for MAE style training so it can be resolution agnostic
+
+        self.to_decoder_pos_emb = create_mlp(
+            dim_in = 2,
+            dim = dim * 2,
+            dim_out = dim,
+            depth = decoder_pos_mlp_depth,
+        )
+
         decoder_layers = []
 
         for _ in range(decoder_depth):
@@ -511,7 +521,8 @@ class VideoTokenizer(Module):
         return_latents = False,
         mask_patches = None
     ):
-        patch_size = self.patch_size
+        batch, time = video.shape[0], video.shape[2]
+        patch_size, device = self.patch_size, video.device
 
         *_, height, width = video.shape
 
@@ -521,6 +532,10 @@ class VideoTokenizer(Module):
 
         tokens = self.patch_to_tokens(video)
 
+        # get some dimensions
+
+        num_patch_height, num_patch_width, _ = tokens.shape[-3:]
+
         # masking
 
         mask_patches = default(mask_patches, self.training)
@@ -559,15 +574,29 @@ class VideoTokenizer(Module):
 
         # latent bottleneck
 
+        tokens = inverse_pack_time(tokens)
+        tokens = tokens[..., -1, :]
+
         latents = self.encoded_to_latents(tokens)
 
         if return_latents:
-            latents = inverse_pack_time(latents)
-            return latents[..., -1, :]
+            return latents
 
-        tokens = self.latents_to_decoder(latents)
+        latent_tokens = self.latents_to_decoder(latents)
 
-        # decoder
+        # generate decoder positional embedding and concat the latent token
+
+        spatial_pos_height = torch.linspace(-1., 1., num_patch_height, device = device)
+        spatial_pos_width = torch.linspace(-1., 1., num_patch_width, device = device)
+
+        space_height_width_coor = stack(torch.meshgrid(spatial_pos_height, spatial_pos_width, indexing = 'ij'), dim = -1)
+
+        decoder_pos_emb = self.to_decoder_pos_emb(space_height_width_coor)
+        decoder_pos_emb = repeat(decoder_pos_emb, '... -> b t ...', b = batch, t = time)
+
+        tokens, _ = pack((decoder_pos_emb, latent_tokens), 'b * d')
+
+        # decoder attention
 
         for attn, ff in self.decoder_layers:
             tokens = attn(tokens) + tokens
diff --git a/pyproject.toml b/pyproject.toml
index 5b56333..215d378 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -27,6 +27,7 @@ classifiers=[
 
 dependencies = [
     "accelerate",
+    "assoc-scan",
+    "einx>=0.3.0",
     "einops>=0.8.1",
     "hl-gauss-pytorch",
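For context on the `to_decoder_pos_emb` addition above: the decoder positional embedding is computed by an MLP from continuous (height, width) coordinates normalized to [-1, 1], rather than looked up from a fixed-size table, so the same module can be queried at any patch-grid resolution during MAE-style decoding. Below is a minimal standalone sketch of that construction, not the repository's code: the `nn.Sequential` stand-in for `create_mlp`, the hypothetical `decoder_pos_emb` helper, and `dim = 512` are assumptions for illustration.

```python
# sketch only: approximates the diff's create_mlp(dim_in = 2, dim = dim * 2, dim_out = dim, depth = 2)

import torch
from torch import nn
from einops import repeat

dim = 512

to_decoder_pos_emb = nn.Sequential(
    nn.Linear(2, dim * 2),
    nn.SiLU(),
    nn.Linear(dim * 2, dim),
)

def decoder_pos_emb(num_patch_height, num_patch_width, batch, time, device = None):
    # normalized patch coordinates in [-1, 1] along each spatial axis
    ys = torch.linspace(-1., 1., num_patch_height, device = device)
    xs = torch.linspace(-1., 1., num_patch_width, device = device)

    # (h, w, 2) grid of (y, x) coordinates
    coords = torch.stack(torch.meshgrid(ys, xs, indexing = 'ij'), dim = -1)

    # (h, w, dim) embedding, broadcast over batch and time
    pos = to_decoder_pos_emb(coords)
    return repeat(pos, 'h w d -> b t h w d', b = batch, t = time)

# the same MLP handles different patch-grid resolutions at inference time
emb_a = decoder_pos_emb(16, 16, batch = 2, time = 4)  # (2, 4, 16, 16, 512)
emb_b = decoder_pos_emb(24, 32, batch = 2, time = 4)  # (2, 4, 24, 32, 512)
```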