diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index 351db4e..8435d70 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -1253,7 +1253,7 @@ class DynamicsModel(Module): image_height = None, image_width = None, return_decoded_video = None, - context_signal_noise = 0.1 # they do a noising of the past, this was from an old diffusion world modeling paper from EPFL iirc - todo: handle this + context_signal_noise = 0.1 # they do a noising of the past, this was from an old diffusion world modeling paper from EPFL iirc ): # (b t n d) | (b c t h w) assert log2(num_steps).is_integer(), f'number of steps {num_steps} must be a power of 2' @@ -1280,9 +1280,11 @@ class DynamicsModel(Module): for step in range(num_steps): signal_levels = torch.full((batch_size, 1), step * step_size, dtype = torch.long, device = self.device) - noised_latent_with_context, pack_context_shape = pack((latents, noised_latent), 'b * n d') + noised_context = latents.lerp(torch.randn_like(latents), context_signal_noise) # the paragraph after eq (8) - signal_levels_with_context = F.pad(signal_levels, (curr_time_steps, 0), value = self.max_steps - 1) # todo - handle noising of past + noised_latent_with_context, pack_context_shape = pack((noised_context, noised_latent), 'b * n d') + + signal_levels_with_context = F.pad(signal_levels, (curr_time_steps, 0), value = self.max_steps - 1) pred = self.forward( latents = noised_latent_with_context, diff --git a/dreamer4/trainers.py b/dreamer4/trainers.py index b15ff8a..c450ee5 100644 --- a/dreamer4/trainers.py +++ b/dreamer4/trainers.py @@ -1,3 +1,17 @@ import torch +from torch.nn import Module from accelerate import Accelerator + +from dreamer4.dreamer4 import ( + VideoTokenizer, + DynamicsModel +) + +class VideoTokenizerTrainer(Module): + def __init__( + self, + model: VideoTokenizer + ): + super().__init__() + raise NotImplementedError diff --git a/pyproject.toml b/pyproject.toml index ede90a7..8e92774 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dreamer4" -version = "0.0.3" +version = "0.0.4" description = "Dreamer 4" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" }