cleanup
This commit is contained in:
parent
895a867a66
commit
85eea216fd
@ -580,6 +580,14 @@ class VideoTokenizer(Module):
|
|||||||
self.decoder_layers = ModuleList(decoder_layers)
|
self.decoder_layers = ModuleList(decoder_layers)
|
||||||
self.decoder_norm = RMSNorm(dim)
|
self.decoder_norm = RMSNorm(dim)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def tokenize(
|
||||||
|
self,
|
||||||
|
video
|
||||||
|
):
|
||||||
|
self.eval()
|
||||||
|
return self.forward(video, return_latents = True)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
video, # (b c t h w)
|
video, # (b c t h w)
|
||||||
@ -797,9 +805,7 @@ class DynamicsModel(Module):
|
|||||||
if exists(video):
|
if exists(video):
|
||||||
assert exists(self.video_tokenizer), 'video_tokenizer must be passed in if training from raw video on dynamics model'
|
assert exists(self.video_tokenizer), 'video_tokenizer must be passed in if training from raw video on dynamics model'
|
||||||
|
|
||||||
with torch.no_grad():
|
latents = self.video_tokenizer.tokenize(video)
|
||||||
self.video_tokenizer.eval()
|
|
||||||
latents = self.video_tokenizer(video, return_latents = True)
|
|
||||||
|
|
||||||
# flow related
|
# flow related
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user