Add noising of the latent context during generation. I think the technique comes from EPFL, or perhaps from a Google group that built on top of the EPFL work.
This commit is contained in:
parent
36ccb08500
commit
0fdb67bafa
@ -1253,7 +1253,7 @@ class DynamicsModel(Module):
|
||||
image_height = None,
|
||||
image_width = None,
|
||||
return_decoded_video = None,
|
||||
context_signal_noise = 0.1 # they do a noising of the past, this was from an old diffusion world modeling paper from EPFL iirc - todo: handle this
|
||||
context_signal_noise = 0.1 # they do a noising of the past, this was from an old diffusion world modeling paper from EPFL iirc
|
||||
): # (b t n d) | (b c t h w)
|
||||
|
||||
assert log2(num_steps).is_integer(), f'number of steps {num_steps} must be a power of 2'
|
||||
@ -1280,9 +1280,11 @@ class DynamicsModel(Module):
|
||||
for step in range(num_steps):
|
||||
signal_levels = torch.full((batch_size, 1), step * step_size, dtype = torch.long, device = self.device)
|
||||
|
||||
noised_latent_with_context, pack_context_shape = pack((latents, noised_latent), 'b * n d')
|
||||
noised_context = latents.lerp(torch.randn_like(latents), context_signal_noise) # the paragraph after eq (8)
|
||||
|
||||
signal_levels_with_context = F.pad(signal_levels, (curr_time_steps, 0), value = self.max_steps - 1) # todo - handle noising of past
|
||||
noised_latent_with_context, pack_context_shape = pack((noised_context, noised_latent), 'b * n d')
|
||||
|
||||
signal_levels_with_context = F.pad(signal_levels, (curr_time_steps, 0), value = self.max_steps - 1)
|
||||
|
||||
pred = self.forward(
|
||||
latents = noised_latent_with_context,
|
||||
|
||||
@ -1,3 +1,17 @@
|
||||
import torch
|
||||
from torch.nn import Module
|
||||
|
||||
from accelerate import Accelerator
|
||||
|
||||
from dreamer4.dreamer4 import (
|
||||
VideoTokenizer,
|
||||
DynamicsModel
|
||||
)
|
||||
|
||||
class VideoTokenizerTrainer(Module):
    """Placeholder trainer taking a `VideoTokenizer`.

    Not implemented yet: constructing an instance always raises
    `NotImplementedError` right after the `Module` base is initialized.
    """

    def __init__(self, model: VideoTokenizer):
        # set up the Module base first, then signal that this trainer is still a stub
        super().__init__()
        raise NotImplementedError()
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "dreamer4"
|
||||
version = "0.0.3"
|
||||
version = "0.0.4"
|
||||
description = "Dreamer 4"
|
||||
authors = [
|
||||
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user