diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index a337820..88dac6e 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -698,6 +698,7 @@ class DynamicsModel(Module):
         self,
         dim,
         dim_latent,
+        video_tokenizer: VideoTokenizer | None = None,
         num_signal_levels = 500,
         num_step_sizes = 32,
         num_spatial_tokens = 32, # latents were projected into spatial tokens, and presumably pooled back for the final prediction (or one special one does the x-prediction)
@@ -714,6 +715,10 @@ class DynamicsModel(Module):
     ):
         super().__init__()
 
+        # can accept raw video if tokenizer is passed in
+
+        self.video_tokenizer = video_tokenizer
+
         # spatial and register tokens
 
         self.latents_to_spatial_tokens = Sequential(
@@ -769,12 +774,34 @@ class DynamicsModel(Module):
             Linear(dim, dim_latent)
         )
 
+    def parameters(self):
+        # exclude the frozen tokenizer's weights from optimization
+        params = super().parameters()
+
+        if not exists(self.video_tokenizer):
+            return params
+
+        return list(set(params) - set(self.video_tokenizer.parameters()))
+
     def forward(
         self,
-        latents, # (b t d)
+        *,
+        video = None,
+        latents = None, # (b t d)
         signal_levels = None, # (b t)
         step_sizes = None # (b t)
     ):
+        # handle video or latents
+
+        assert exists(video) ^ exists(latents)
+
+        if exists(video):
+            assert exists(self.video_tokenizer), 'video_tokenizer must be passed in if training from raw video on dynamics model'
+
+            with torch.no_grad():
+                self.video_tokenizer.eval()
+                latents = self.video_tokenizer(video, return_latents = True)
+
         # flow related
 
         assert not (exists(signal_levels) ^ exists(step_sizes))
diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py
index 8fc606a..c14a6d5 100644
--- a/tests/test_dreamer.py
+++ b/tests/test_dreamer.py
@@ -3,26 +3,29 @@ param = pytest.mark.parametrize
 
 import torch
 
 @param('pred_orig_latent', (False, True))
-@param('gqa', (False, True))
+@param('grouped_query_attn', (False, True))
+@param('dynamics_with_video_input', (False, True))
 def test_e2e(
     pred_orig_latent,
-    gqa
+    grouped_query_attn,
+    dynamics_with_video_input
 ):
     from dreamer4.dreamer4 import VideoTokenizer, DynamicsModel
 
     tokenizer = VideoTokenizer(512, dim_latent = 32, patch_size = 32)
 
-    x = torch.randn(2, 3, 4, 256, 256)
+    video = torch.randn(2, 3, 4, 256, 256)
 
-    loss = tokenizer(x)
+    loss = tokenizer(video)
     assert loss.numel() == 1
 
-    latents = tokenizer(x, return_latents = True)
+    latents = tokenizer(video, return_latents = True)
     assert latents.shape[-1] == 32
 
-    query_heads, heads = (16, 4) if gqa else (8, 8)
+    query_heads, heads = (16, 4) if grouped_query_attn else (8, 8)
 
     dynamics = DynamicsModel(
         512,
+        video_tokenizer = tokenizer,
         dim_latent = 32,
         num_signal_levels = 500,
         num_step_sizes = 32,
@@ -36,7 +39,12 @@
     signal_levels = torch.randint(0, 500, (2, 4))
     step_sizes = torch.randint(0, 32, (2, 4))
 
-    flow_loss = dynamics(latents, signal_levels = signal_levels, step_sizes = step_sizes)
+    if dynamics_with_video_input:
+        dynamics_input = dict(video = video)
+    else:
+        dynamics_input = dict(latents = latents)
+
+    flow_loss = dynamics(**dynamics_input, signal_levels = signal_levels, step_sizes = step_sizes)
     assert flow_loss.numel() == 1
 
 def test_symexp_two_hot():
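
Below is a minimal usage sketch of the new raw-video path. It is not part of the patch; the constructor values and tensor shapes simply mirror the test above.

    import torch
    from dreamer4.dreamer4 import VideoTokenizer, DynamicsModel

    tokenizer = VideoTokenizer(512, dim_latent = 32, patch_size = 32)

    # passing the tokenizer in lets the dynamics model consume raw video;
    # it is run frozen inside forward and excluded from parameters()
    dynamics = DynamicsModel(
        512,
        video_tokenizer = tokenizer,
        dim_latent = 32,
        num_signal_levels = 500,
        num_step_sizes = 32
    )

    video = torch.randn(2, 3, 4, 256, 256)         # (batch, channels, time, height, width)
    signal_levels = torch.randint(0, 500, (2, 4))  # (batch, time)
    step_sizes = torch.randint(0, 32, (2, 4))      # (batch, time)

    # video and latents are now mutually exclusive keyword-only arguments
    flow_loss = dynamics(video = video, signal_levels = signal_levels, step_sizes = step_sizes)
    flow_loss.backward()

Overriding parameters() this way means an optimizer built from dynamics.parameters() never touches the tokenizer's weights, even though the tokenizer is a registered submodule.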