allow for image pretraining on video tokenizer

lucidrains 2025-12-04 10:34:15 -08:00
parent 9efe269688
commit 5bb027b386
3 changed files with 48 additions and 6 deletions

@@ -71,6 +71,8 @@ except ImportError:
 LinearNoBias = partial(Linear, bias = False)

+VideoTokenizerIntermediates = namedtuple('VideoTokenizerIntermediates', ('losses', 'recon'))
+TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips', 'time_decorr', 'space_decorr'))

 WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'rewards', 'discrete_actions', 'continuous_actions', 'state_pred'))
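
With the decorrelation term split into time_decorr and space_decorr, each auxiliary loss can now be read off the namedtuple by name. A minimal logging sketch, assuming a constructed tokenizer and an image batch as in the new test at the bottom of this commit, and assuming the individual losses are scalar tensors:

    # sketch only: field-wise access to the loss breakdown
    total_loss, (losses, recon) = tokenizer(images, return_intermediates = True)

    print(f'recon:  {losses.recon.item():.4f}')
    print(f'lpips:  {losses.lpips.item():.4f}')
    print(f'decorr: {losses.time_decorr.item():.4f} (time) / {losses.space_decorr.item():.4f} (space)')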
@@ -1940,11 +1942,23 @@ class VideoTokenizer(Module):
     def forward(
         self,
-        video, # (b c t h w)
+        video_or_image, # (b c t h w) | (b c h w)
         return_latents = False,
         mask_patches = None,
-        return_all_losses = False
+        return_intermediates = False,
     ):
+        # handle image pretraining
+
+        is_image = video_or_image.ndim == 4
+
+        if is_image:
+            video = rearrange(video_or_image, 'b c h w -> b c 1 h w')
+        else:
+            video = video_or_image
+
         # shapes

         batch, _, time, height, width = video.shape
         patch_size, device = self.patch_size, video.device
@@ -2024,12 +2038,21 @@
             space_decorr_loss * self.decorr_aux_loss_weight
         )

-        if not return_all_losses:
+        if not return_intermediates:
             return total_loss

-        losses = (recon_loss, lpips_loss, decorr_loss)
+        losses = TokenizerLosses(recon_loss, lpips_loss, time_decorr_loss, space_decorr_loss)

-        return total_loss, TokenizerLosses(*losses)
+        out = losses
+
+        # handle returning of reconstructed, and image pretraining
+
+        if is_image:
+            recon_video = rearrange(recon_video, 'b c 1 h w -> b c h w')
+
+        out = (losses, recon_video)
+
+        return total_loss, out

 # dynamics model, axial space-time transformer
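
The image path amounts to a reshape on either side of the unchanged video path: a (b c h w) image batch becomes a single-frame (b c 1 h w) video before patching, and the reconstruction has its singleton time axis squeezed back out at the end. A self-contained sketch of that round trip with einops, independent of the tokenizer itself:

    import torch
    from einops import rearrange

    images = torch.randn(2, 3, 256, 256)                      # (b c h w)

    video = rearrange(images, 'b c h w -> b c 1 h w')         # treat the image as a one-frame video
    assert video.shape == (2, 3, 1, 256, 256)

    # ... encode / decode on (b c t h w) exactly as for video ...

    recon_images = rearrange(video, 'b c 1 h w -> b c h w')   # squeeze the time axis back out
    assert recon_images.shape == images.shape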

@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.1.23"
+version = "0.1.24"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

@@ -810,3 +810,22 @@ def test_epo():
     fitness = torch.randn(16,)

     dynamics.evolve_(fitness)
+
+def test_images_to_video_tokenizer():
+    import torch
+    from dreamer4 import VideoTokenizer, DynamicsWorldModel, AxialSpaceTimeTransformer
+
+    tokenizer = VideoTokenizer(
+        dim = 512,
+        dim_latent = 32,
+        patch_size = 32,
+        image_height = 256,
+        image_width = 256,
+        encoder_add_decor_aux_loss = True
+    )
+
+    images = torch.randn(2, 3, 256, 256)
+
+    loss, (losses, recon_images) = tokenizer(images, return_intermediates = True)
+    loss.backward()
+
+    assert images.shape == recon_images.shape
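
Since images only add a reshape, the same tokenizer from the test should keep accepting real video, and with return_intermediates left at its default of False the forward pass returns just the total loss, per the early return in the diff above. A usage sketch, assuming the constructor arguments from the test and that a 4-frame clip is a valid temporal length for the default settings:

    video = torch.randn(2, 3, 4, 256, 256)   # (b c t h w)

    loss = tokenizer(video)                   # default return: total loss only
    loss.backward()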