add noising of the latent context during generation, a technique i think came from EPFL, or perhaps from some google group that built on top of the EPFL work
commit 0fdb67bafa (parent 36ccb08500)
@@ -1253,7 +1253,7 @@ class DynamicsModel(Module):
         image_height = None,
         image_width = None,
         return_decoded_video = None,
-        context_signal_noise = 0.1 # they do a noising of the past, this was from an old diffusion world modeling paper from EPFL iirc - todo: handle this
+        context_signal_noise = 0.1 # they do a noising of the past, this was from an old diffusion world modeling paper from EPFL iirc
     ): # (b t n d) | (b c t h w)

         assert log2(num_steps).is_integer(), f'number of steps {num_steps} must be a power of 2'
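A quick note on the default above, outside the diff itself: torch's Tensor.lerp(end, weight) computes input + weight * (end - input), so context_signal_noise = 0.1 keeps 90% of the context signal and mixes in 10% gaussian noise. A tiny self-contained check (values made up):

    import torch

    clean = torch.ones(3)          # stand-in for clean context latents
    noise = torch.zeros(3)         # stand-in for gaussian noise
    print(clean.lerp(noise, 0.1))  # tensor([0.9000, 0.9000, 0.9000]) - 90% of the signal kept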
@@ -1280,9 +1280,11 @@ class DynamicsModel(Module):
         for step in range(num_steps):
             signal_levels = torch.full((batch_size, 1), step * step_size, dtype = torch.long, device = self.device)

-            noised_latent_with_context, pack_context_shape = pack((latents, noised_latent), 'b * n d')
-
-            signal_levels_with_context = F.pad(signal_levels, (curr_time_steps, 0), value = self.max_steps - 1) # todo - handle noising of past
+            noised_context = latents.lerp(torch.randn_like(latents), context_signal_noise) # the paragraph after eq (8)
+
+            noised_latent_with_context, pack_context_shape = pack((noised_context, noised_latent), 'b * n d')
+
+            signal_levels_with_context = F.pad(signal_levels, (curr_time_steps, 0), value = self.max_steps - 1)

             pred = self.forward(
                 latents = noised_latent_with_context,
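Putting this hunk together, here is a minimal self-contained sketch of the new generation step, not the repo's actual code: all shapes and the max_steps value are assumptions, and the free-standing tensors stand in for what DynamicsModel computes internally. Only the lerp, pack/unpack and F.pad usage mirror the diff.

    import torch
    import torch.nn.functional as F
    from einops import pack, unpack

    batch, time, num_tokens, dim = 2, 4, 16, 512   # assumed shapes
    max_steps = 64                                  # assumed stand-in for self.max_steps
    context_signal_noise = 0.1

    latents = torch.randn(batch, time, num_tokens, dim)     # clean past context (b t n d)
    noised_latent = torch.randn(batch, 1, num_tokens, dim)  # current frame being denoised

    # noise the context slightly, 10% of the way toward pure gaussian noise
    noised_context = latents.lerp(torch.randn_like(latents), context_signal_noise)

    # concatenate noised context and the current frame along the time axis
    noised_latent_with_context, pack_context_shape = pack((noised_context, noised_latent), 'b * n d')

    # context positions keep the maximum signal level via left padding,
    # i.e. they are still presented to the model as (nearly) clean
    signal_levels = torch.zeros(batch, 1, dtype = torch.long)
    signal_levels_with_context = F.pad(signal_levels, (time, 0), value = max_steps - 1)

    assert noised_latent_with_context.shape == (batch, time + 1, num_tokens, dim)
    assert signal_levels_with_context.shape == (batch, time + 1)

    # pack_context_shape lets unpack split a prediction back into context / current parts
    context_part, current_part = unpack(noised_latent_with_context, pack_context_shape, 'b * n d')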
@@ -1,3 +1,17 @@
 import torch
+from torch.nn import Module

 from accelerate import Accelerator
+
+from dreamer4.dreamer4 import (
+    VideoTokenizer,
+    DynamicsModel
+)
+
+class VideoTokenizerTrainer(Module):
+    def __init__(
+        self,
+        model: VideoTokenizer
+    ):
+        super().__init__()
+        raise NotImplementedError
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.3"
+version = "0.0.4"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }