cleanup
This commit is contained in:
parent
895a867a66
commit
85eea216fd
@ -580,6 +580,14 @@ class VideoTokenizer(Module):
|
||||
self.decoder_layers = ModuleList(decoder_layers)
|
||||
self.decoder_norm = RMSNorm(dim)
|
||||
|
||||
@torch.no_grad()
|
||||
def tokenize(
|
||||
self,
|
||||
video
|
||||
):
|
||||
self.eval()
|
||||
return self.forward(video, return_latents = True)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
video, # (b c t h w)
|
||||
@ -797,9 +805,7 @@ class DynamicsModel(Module):
|
||||
if exists(video):
|
||||
assert exists(self.video_tokenizer), 'video_tokenizer must be passed in if training from raw video on dynamics model'
|
||||
|
||||
with torch.no_grad():
|
||||
self.video_tokenizer.eval()
|
||||
latents = self.video_tokenizer(video, return_latents = True)
|
||||
latents = self.video_tokenizer.tokenize(video)
|
||||
|
||||
# flow related
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user