allow for image pretraining on video tokenizer
This commit is contained in:
parent
9efe269688
commit
5bb027b386
@ -71,6 +71,8 @@ except ImportError:
|
||||
|
||||
LinearNoBias = partial(Linear, bias = False)
|
||||
|
||||
VideoTokenizerIntermediates = namedtuple('VideoTokenizerIntermediates', ('losses', 'recon'))
|
||||
|
||||
TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips', 'time_decorr', 'space_decorr'))
|
||||
|
||||
WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'rewards', 'discrete_actions', 'continuous_actions', 'state_pred'))
|
||||
@ -1940,11 +1942,23 @@ class VideoTokenizer(Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
video, # (b c t h w)
|
||||
video_or_image, # (b c t h w) | (b c h w)
|
||||
return_latents = False,
|
||||
mask_patches = None,
|
||||
return_all_losses = False
|
||||
return_intermediates = False,
|
||||
):
|
||||
|
||||
# handle image pretraining
|
||||
|
||||
is_image = video_or_image.ndim == 4
|
||||
|
||||
if is_image:
|
||||
video = rearrange(video_or_image, 'b c h w -> b c 1 h w')
|
||||
else:
|
||||
video = video_or_image
|
||||
|
||||
# shapes
|
||||
|
||||
batch, _, time, height, width = video.shape
|
||||
patch_size, device = self.patch_size, video.device
|
||||
|
||||
@ -2024,12 +2038,21 @@ class VideoTokenizer(Module):
|
||||
space_decorr_loss * self.decorr_aux_loss_weight
|
||||
)
|
||||
|
||||
if not return_all_losses:
|
||||
if not return_intermediates:
|
||||
return total_loss
|
||||
|
||||
losses = (recon_loss, lpips_loss, decorr_loss)
|
||||
losses = TokenizerLosses(recon_loss, lpips_loss, time_decorr_loss, space_decorr_loss)
|
||||
|
||||
return total_loss, TokenizerLosses(*losses)
|
||||
out = losses
|
||||
|
||||
# handle returning of reconstructed, and image pretraining
|
||||
|
||||
if is_image:
|
||||
recon_video = rearrange(recon_video, 'b c 1 h w -> b c h w')
|
||||
|
||||
out = (losses, recon_video)
|
||||
|
||||
return total_loss, out
|
||||
|
||||
# dynamics model, axial space-time transformer
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "dreamer4"
|
||||
version = "0.1.23"
|
||||
version = "0.1.24"
|
||||
description = "Dreamer 4"
|
||||
authors = [
|
||||
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
||||
|
||||
@ -810,3 +810,22 @@ def test_epo():
|
||||
|
||||
fitness = torch.randn(16,)
|
||||
dynamics.evolve_(fitness)
|
||||
|
||||
def test_images_to_video_tokenizer():
|
||||
import torch
|
||||
from dreamer4 import VideoTokenizer, DynamicsWorldModel, AxialSpaceTimeTransformer
|
||||
|
||||
tokenizer = VideoTokenizer(
|
||||
dim = 512,
|
||||
dim_latent = 32,
|
||||
patch_size = 32,
|
||||
image_height = 256,
|
||||
image_width = 256,
|
||||
encoder_add_decor_aux_loss = True
|
||||
)
|
||||
|
||||
images = torch.randn(2, 3, 256, 256)
|
||||
loss, (losses, recon_images) = tokenizer(images, return_intermediates = True)
|
||||
loss.backward()
|
||||
|
||||
assert images.shape == recon_images.shape
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user