From 85eea216fdb59002b38fbd0b204fad86c790b248 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Sat, 4 Oct 2025 06:59:09 -0700 Subject: [PATCH] cleanup --- dreamer4/dreamer4.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index 88dac6e..c938db8 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -580,6 +580,14 @@ class VideoTokenizer(Module): self.decoder_layers = ModuleList(decoder_layers) self.decoder_norm = RMSNorm(dim) + @torch.no_grad() + def tokenize( + self, + video + ): + self.eval() + return self.forward(video, return_latents = True) + def forward( self, video, # (b c t h w) @@ -797,9 +805,7 @@ class DynamicsModel(Module): if exists(video): assert exists(self.video_tokenizer), 'video_tokenizer must be passed in if training from raw video on dynamics model' - with torch.no_grad(): - self.video_tokenizer.eval() - latents = self.video_tokenizer(video, return_latents = True) + latents = self.video_tokenizer.tokenize(video) # flow related