This commit is contained in:
lucidrains 2025-10-04 06:59:09 -07:00
parent 895a867a66
commit 85eea216fd

View File

@ -580,6 +580,14 @@ class VideoTokenizer(Module):
self.decoder_layers = ModuleList(decoder_layers)
self.decoder_norm = RMSNorm(dim)
@torch.no_grad()
def tokenize(
self,
video
):
self.eval()
return self.forward(video, return_latents = True)
def forward(
self,
video, # (b c t h w)
@ -797,9 +805,7 @@ class DynamicsModel(Module):
if exists(video):
assert exists(self.video_tokenizer), 'video_tokenizer must be passed in if training from raw video on dynamics model'
with torch.no_grad():
self.video_tokenizer.eval()
latents = self.video_tokenizer(video, return_latents = True)
latents = self.video_tokenizer.tokenize(video)
# flow related