diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index de0425c..4112820 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -20,8 +20,6 @@
 from x_mlps_pytorch.ensemble import Ensemble
 
 from assoc_scan import AssocScan
 
-from accelerate import Accelerator
-
 # ein related
 # b - batch
@@ -710,6 +708,8 @@ class VideoTokenizer(Module):
         dim,
         dim_latent,
         patch_size,
+        image_height = None,
+        image_width = None,
         num_latent_tokens = 4,
         encoder_depth = 4,
         decoder_depth = 4,
@@ -795,6 +795,9 @@ class VideoTokenizer(Module):
 
         # decoder
 
+        self.image_height = image_height
+        self.image_width = image_width
+
         # parameterize the decoder positional embeddings for MAE style training so it can be resolution agnostic
 
         self.to_decoder_pos_emb = create_mlp(
@@ -825,6 +828,10 @@ class VideoTokenizer(Module):
         if self.has_lpips_loss:
             self.lpips = LPIPSLoss(lpips_loss_network)
 
+    @property
+    def device(self):
+        return self.zero.device
+
     @torch.no_grad()
     def tokenize(
         self,
@@ -833,6 +840,102 @@ class VideoTokenizer(Module):
         self.eval()
         return self.forward(video, return_latents = True)
 
+    def get_rotary_pos_emb(
+        self,
+        time,
+        num_patch_height,
+        num_patch_width
+    ):
+        device = self.device
+
+        positions = stack(torch.meshgrid(
+            arange(time, device = device),
+            arange(num_patch_height, device = device),
+            arange(num_patch_width, device = device)
+        ), dim = -1)
+
+        positions = rearrange(positions, 't h w p -> t (h w) p')
+
+        # give the latents an out of bounds position and assume the network will figure it out
+
+        positions = pad_at_dim(positions, (0, self.num_latent_tokens), dim = -2, value = -1) # todo - make this value configurable, and ultimately craft own flash attention function where certain positions can be unrotated
+
+        positions = rearrange(positions, 't hw p -> (t hw) p')
+
+        return self.spacetime_rotary(positions)
+
+    def decode(
+        self,
+        latents, # (b t n d)
+        height = None,
+        width = None,
+        rotary_pos_emb = None
+    ): # (b c t h w)
+
+        height = default(height, self.image_height)
+        width = default(width, self.image_width)
+
+        assert exists(height) and exists(width), 'image height and width need to be passed in when decoding latents'
+
+        batch, time, device = *latents.shape[:2], latents.device
+
+        use_flex = latents.is_cuda and exists(flex_attention)
+
+        num_patch_height = height // self.patch_size
+        num_patch_width = width // self.patch_size
+
+        if not exists(rotary_pos_emb):
+            rotary_pos_emb = self.get_rotary_pos_emb(time, num_patch_height, num_patch_width)
+
+        # latents to tokens
+
+        latent_tokens = self.latents_to_decoder(latents)
+
+        # generate decoder positional embedding and concat the latent token
+
+        spatial_pos_height = torch.linspace(-1., 1., num_patch_height, device = device)
+        spatial_pos_width = torch.linspace(-1., 1., num_patch_width, device = device)
+
+        space_height_width_coor = stack(torch.meshgrid(spatial_pos_height, spatial_pos_width, indexing = 'ij'), dim = -1)
+
+        decoder_pos_emb = self.to_decoder_pos_emb(space_height_width_coor)
+        decoder_pos_emb = repeat(decoder_pos_emb, '... -> b t ...', b = batch, t = time)
+
+        tokens, packed_latent_shape = pack((decoder_pos_emb, latent_tokens), 'b t * d')
+
+        # pack time
+
+        tokens, inverse_pack_time = pack_one(tokens, 'b * d')
+
+        seq_len = tokens.shape[-2]
+
+        # decoder attend
+
+        decoder_attend_fn = get_attend_fn(use_flex, seq_len, seq_len, special_attend_only_itself = True)
+
+        # decoder attention
+
+        for attn, ff in self.decoder_layers:
+            tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = decoder_attend_fn) + tokens
+
+            tokens = ff(tokens) + tokens
+
+        tokens = self.decoder_norm(tokens)
+
+        # unpack time
+
+        tokens = inverse_pack_time(tokens)
+
+        # unpack latents
+
+        tokens, latent_tokens = unpack(tokens, packed_latent_shape, 'b t * d')
+
+        # project back to patches
+
+        recon_video = self.tokens_to_patch(tokens)
+
+        return recon_video
+
     def forward(
         self,
         video, # (b c t h w)
@@ -855,21 +958,7 @@ class VideoTokenizer(Module):
 
         # rotary positions
 
-        positions = stack(torch.meshgrid(
-            arange(time, device = device),
-            arange(num_patch_height, device = device),
-            arange(num_patch_width, device = device)
-        ), dim = -1)
-
-        positions = rearrange(positions, 't h w p -> t (h w) p')
-
-        # give the latents an out of bounds position and assume the network will figure it out
-
-        positions = pad_at_dim(positions, (0, self.num_latent_tokens), dim = -2, value = -1) # todo - make this value configurable, and ultimately craft own flash attention function where certain positions can be unrotated
-
-        positions = rearrange(positions, 't hw p -> (t hw) p')
-
-        rotary_pos_emb = self.spacetime_rotary(positions)
+        rotary_pos_emb = self.get_rotary_pos_emb(time, num_patch_height, num_patch_width)
 
         # masking
 
@@ -941,48 +1030,7 @@ class VideoTokenizer(Module):
         if return_latents:
             return latents
 
-        latent_tokens = self.latents_to_decoder(latents)
-
-        # generate decoder positional embedding and concat the latent token
-
-        spatial_pos_height = torch.linspace(-1., 1., num_patch_height, device = device)
-        spatial_pos_width = torch.linspace(-1., 1., num_patch_width, device = device)
-
-        space_height_width_coor = stack(torch.meshgrid(spatial_pos_height, spatial_pos_width, indexing = 'ij'), dim = -1)
-
-        decoder_pos_emb = self.to_decoder_pos_emb(space_height_width_coor)
-        decoder_pos_emb = repeat(decoder_pos_emb, '... -> b t ...', b = batch, t = time)
-
-        tokens, packed_latent_shape = pack((decoder_pos_emb, latent_tokens), 'b t * d')
-
-        # pack time
-
-        tokens, inverse_pack_time = pack_one(tokens, 'b * d')
-
-        # decoder attend
-
-        decoder_attend_fn = get_attend_fn(use_flex, seq_len, seq_len, special_attend_only_itself = True)
-
-        # decoder attention
-
-        for attn, ff in self.decoder_layers:
-            tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = decoder_attend_fn) + tokens
-
-            tokens = ff(tokens) + tokens
-
-        tokens = self.decoder_norm(tokens)
-
-        # unpack time
-
-        tokens = inverse_pack_time(tokens)
-
-        # unpack latents
-
-        tokens, latent_tokens = unpack(tokens, packed_latent_shape, 'b t * d')
-
-        # project back to patches
-
-        recon_video = self.tokens_to_patch(tokens)
+        recon_video = self.decode(latents, height = height, width = width, rotary_pos_emb = rotary_pos_emb)
 
         # losses
 
@@ -1180,6 +1228,18 @@ class DynamicsModel(Module):
 
         return list(set(params) - set(self.video_tokenizer.parameters()))
 
+    def generate(
+        self,
+        num_frames,
+        num_steps = 4,
+        image_height = None,
+        image_width = None
+    ): # (b t n d) | (b c t h w)
+
+        assert log2(num_steps).is_integer(), 'number of steps must be a power of 2'
+
+        raise NotImplementedError
+
     def forward(
         self,
         *,
diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py
index 696af1c..888885c 100644
--- a/tests/test_dreamer.py
+++ b/tests/test_dreamer.py
@@ -40,6 +40,9 @@ def test_e2e(
     latents = tokenizer(video, return_latents = True)
     assert latents.shape[-1] == 16
 
+    recon = tokenizer.decode(latents, 256, 256)
+    assert recon.shape == video.shape
+
     query_heads, heads = (16, 4) if grouped_query_attn else (8, 8)
 
     dynamics = DynamicsModel(
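For reference, a minimal round-trip sketch of the new tokenize / decode path (not part of the diff). The import path is inferred from the file layout, and every concrete value below is illustrative; only dim, dim_latent, patch_size and the new image_height / image_width arguments are taken from the hunks above, and the tokenizer may require further constructor arguments not shown here.

    import torch
    from dreamer4.dreamer4 import VideoTokenizer

    tokenizer = VideoTokenizer(
        dim = 64,            # hypothetical sizes
        dim_latent = 16,
        patch_size = 16,
        image_height = 256,  # new: stored as decode-time defaults
        image_width = 256
    )

    video = torch.randn(1, 3, 4, 256, 256)  # (b c t h w)

    latents = tokenizer.tokenize(video)     # (b t n d)

    # height / width may now be omitted, falling back to the stored defaults
    recon = tokenizer.decode(latents)

    assert recon.shape == video.shape

Note that forward still hands its own rotary_pos_emb to decode, so the embeddings are computed once per reconstruction rather than once per call.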
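The factored-out get_rotary_pos_emb lays down one (t, h, w) coordinate per patch and gives the latent tokens appended to each frame a sentinel position of -1. A shape walk-through on toy sizes, using torch.nn.functional.pad as a stand-in for the repo's pad_at_dim helper:

    import torch
    import torch.nn.functional as F
    from einops import rearrange

    time, patches_h, patches_w, num_latent_tokens = 2, 2, 2, 4

    # one (t, h, w) coordinate triple per patch position
    positions = torch.stack(torch.meshgrid(
        torch.arange(time),
        torch.arange(patches_h),
        torch.arange(patches_w),
        indexing = 'ij'  # explicit here; the method relies on torch's default
    ), dim = -1)                                               # (2, 2, 2, 3)

    positions = rearrange(positions, 't h w p -> t (h w) p')   # (2, 4, 3)

    # latent tokens appended per frame receive the out-of-bounds position
    positions = F.pad(positions, (0, 0, 0, num_latent_tokens), value = -1)  # (2, 8, 3)

    positions = rearrange(positions, 't hw p -> (t hw) p')     # (16, 3), fed to spacetime_rotary

    assert (positions[4:8] == -1).all()  # frame 0's four latent slots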
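The num_steps assertion in DynamicsModel.generate accepts only powers of two; float log2 is exact for those, so the predicate behaves as intended:

    from math import log2

    for n in (1, 2, 4, 8, 1024):
        assert log2(n).is_integer()      # powers of two pass

    assert not log2(6).is_integer()      # anything else fails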