add the noising of the latent context during generation, technique i think was from EPFL, or perhaps some google group that built on top of EPFL work

This commit is contained in:
lucidrains 2025-10-07 09:37:37 -07:00
parent 36ccb08500
commit 0fdb67bafa
3 changed files with 20 additions and 4 deletions

View File

@ -1253,7 +1253,7 @@ class DynamicsModel(Module):
image_height = None,
image_width = None,
return_decoded_video = None,
context_signal_noise = 0.1 # they do a noising of the past, this was from an old diffusion world modeling paper from EPFL iirc - todo: handle this
context_signal_noise = 0.1 # they do a noising of the past, this was from an old diffusion world modeling paper from EPFL iirc
): # (b t n d) | (b c t h w)
assert log2(num_steps).is_integer(), f'number of steps {num_steps} must be a power of 2'
@ -1280,9 +1280,11 @@ class DynamicsModel(Module):
for step in range(num_steps):
signal_levels = torch.full((batch_size, 1), step * step_size, dtype = torch.long, device = self.device)
noised_latent_with_context, pack_context_shape = pack((latents, noised_latent), 'b * n d')
noised_context = latents.lerp(torch.randn_like(latents), context_signal_noise) # the paragraph after eq (8)
signal_levels_with_context = F.pad(signal_levels, (curr_time_steps, 0), value = self.max_steps - 1) # todo - handle noising of past
noised_latent_with_context, pack_context_shape = pack((noised_context, noised_latent), 'b * n d')
signal_levels_with_context = F.pad(signal_levels, (curr_time_steps, 0), value = self.max_steps - 1)
pred = self.forward(
latents = noised_latent_with_context,

View File

@ -1,3 +1,17 @@
import torch
from torch.nn import Module
from accelerate import Accelerator
from dreamer4.dreamer4 import (
VideoTokenizer,
DynamicsModel
)
class VideoTokenizerTrainer(Module):
def __init__(
self,
model: VideoTokenizer
):
super().__init__()
raise NotImplementedError

View File

@ -1,6 +1,6 @@
[project]
name = "dreamer4"
version = "0.0.3"
version = "0.0.4"
description = "Dreamer 4"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }