From 11762699276a9e29cee822177a931bc3aebc5725 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Tue, 7 Oct 2025 07:41:02 -0700 Subject: [PATCH] correct signal levels when doing teacher forcing generation --- dreamer4/dreamer4.py | 71 +++++++++++++++++++++++++++----------------- pyproject.toml | 2 +- 2 files changed, 44 insertions(+), 29 deletions(-) diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index c6f66c0..71d8208 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -1252,7 +1252,8 @@ class DynamicsModel(Module): batch_size = 1, image_height = None, image_width = None, - return_decoded_video = None + return_decoded_video = None, + context_signal_noise = 0.1 # they do a noising of the past, this was from an old diffusion world modeling paper from EPFL iirc - todo: handle this ): # (b t n d) | (b c t h w) assert log2(num_steps).is_integer(), f'number of steps {num_steps} must be a power of 2' @@ -1270,18 +1271,23 @@ class DynamicsModel(Module): latents = torch.empty((batch_size, 0, *latent_shape), device = self.device) + # while not all frames of the video (per latent) have been generated + while latents.shape[1] < time_steps: + curr_time_steps = latents.shape[1] noised_latent = torch.randn((batch_size, 1, *latent_shape), device = self.device) for step in range(num_steps): - signal_level = tensor(step * step_size, device = self.device) + signal_levels = torch.full((batch_size, 1), step * step_size, device = self.device) noised_latent_with_context, pack_context_shape = pack((latents, noised_latent), 'b * n d') + signal_levels_with_context = F.pad(signal_levels, (curr_time_steps, 0), value = self.max_steps - 1) # todo - handle noising of past + pred = self.forward( latents = noised_latent_with_context, - signal_levels = signal_level, + signal_levels = signal_levels_with_context, step_sizes_log2 = step_size_log2, return_pred_only = True ) @@ -1291,7 +1297,7 @@ class DynamicsModel(Module): # derive flow, based on whether in x-space or not if 
self.pred_orig_latent: - times = self.get_times_from_signal_level(signal_level, noised_latent) + times = self.get_times_from_signal_level(signal_levels, noised_latent) flow = (pred - noised_latent) / (1. - times) else: flow = pred @@ -1350,51 +1356,60 @@ class DynamicsModel(Module): # shape related - if exists(signal_levels) and signal_levels.ndim == 0: - signal_levels = repeat(signal_levels, '-> b', b = batch) + if exists(signal_levels): + if signal_levels.ndim == 0: + signal_levels = repeat(signal_levels, '-> b', b = batch) - if exists(step_sizes_log2): - if step_sizes_log2.ndim == 0: - step_sizes_log2 = repeat(step_sizes_log2, '-> b', b = batch) + if signal_levels.ndim == 1: + signal_levels = repeat(signal_levels, 'b -> b t', t = time) - if step_sizes_log2.ndim == 1: - step_sizes_log2 = repeat(step_sizes_log2, 'b -> b t', t = time) + if exists(step_sizes_log2) and step_sizes_log2.ndim == 0: + step_sizes_log2 = repeat(step_sizes_log2, '-> b', b = batch) # flow related assert not (exists(signal_levels) ^ exists(step_sizes_log2)) - # if neither signal levels or step sizes passed in + is_inference = exists(signal_levels) + return_pred_only = is_inference + + # if neither signal levels nor step sizes are passed in, assume training # generate them randomly for training - no_shortcut_train = random() < self.prob_no_shortcut_train + if not is_inference: - if no_shortcut_train: - # if no shortcut training, step sizes are just 1 and noising is all steps, where each step is 1 / d_min - # in original shortcut paper, they actually set d = 0 for some reason, look into that later, as there is no mention in the dreamer paper of doing this + no_shortcut_train = random() < self.prob_no_shortcut_train - step_sizes_log2 = torch.zeros((batch,), device = device).long() # zero because zero is equivalent to step size of 1 - signal_levels = torch.randint(0, self.max_steps, (batch, time), device = device) - else: + if no_shortcut_train: + # if no shortcut training, step sizes are just 1 and 
noising is all steps, where each step is 1 / d_min + # in original shortcut paper, they actually set d = 0 for some reason, look into that later, as there is no mention in the dreamer paper of doing this - # now we follow eq (4) + step_sizes_log2 = torch.zeros((batch,), device = device).long() # zero because zero is equivalent to step size of 1 + signal_levels = torch.randint(0, self.max_steps, (batch, time), device = device) + else: - step_sizes_log2 = torch.randint(1, self.num_step_sizes_log2, (batch,), device = device) - num_step_sizes = 2 ** step_sizes_log2 + # now we follow eq (4) - signal_levels = torch.randint(0, self.max_steps, (batch, time)) // num_step_sizes[:, None] * num_step_sizes[:, None] # times are discretized to step sizes + step_sizes_log2 = torch.randint(1, self.num_step_sizes_log2, (batch,), device = device) + num_step_sizes = 2 ** step_sizes_log2 - # get the noise - - noise = torch.randn_like(latents) + signal_levels = torch.randint(0, self.max_steps, (batch, time)) // num_step_sizes[:, None] * num_step_sizes[:, None] # times are discretized to step sizes # times is from 0 to 1 times = self.get_times_from_signal_level(signal_levels, latents) - # noise from 0 as noise to 1 as data + if not is_inference: + # get the noise - noised_latents = noise.lerp(latents, times) + noise = torch.randn_like(latents) + + # noise from 0 as noise to 1 as data + + noised_latents = noise.lerp(latents, times) + + else: + noised_latents = latents # reinforcement learning related diff --git a/pyproject.toml b/pyproject.toml index 71814b4..3fa384a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dreamer4" -version = "0.0.1" +version = "0.0.2" description = "Dreamer 4" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" }