correct signal levels when doing teacher forcing generation

lucidrains 2025-10-07 07:41:02 -07:00
parent c6bef85984
commit 1176269927
2 changed files with 44 additions and 29 deletions

View File

@@ -1252,7 +1252,8 @@ class DynamicsModel(Module):
         batch_size = 1,
         image_height = None,
         image_width = None,
-        return_decoded_video = None
+        return_decoded_video = None,
+        context_signal_noise = 0.1 # they do a noising of the past, this was from an old diffusion world modeling paper from EPFL iirc - todo: handle this
     ): # (b t n d) | (b c t h w)

         assert log2(num_steps).is_integer(), f'number of steps {num_steps} must be a power of 2'
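
The new context_signal_noise argument is accepted but, per the todo, not yet consumed anywhere in this diff. A minimal sketch of what wiring it up could look like, assuming the same lerp convention used elsewhere in this file (interpolation weight 1 means pure signal) and clean context latents of shape (b, t, n, d); the helper name is hypothetical:

    import torch

    def noise_past_context(context, context_signal_noise = 0.1):
        # hypothetical helper - corrupt the clean history slightly by lerping toward
        # gaussian noise, keeping (1. - context_signal_noise) worth of the original signal.
        # noise-augmenting the past this way lets generation, which conditions on the
        # model's own imperfect frames, better match what was seen during training
        noise = torch.randn_like(context)
        return noise.lerp(context, 1. - context_signal_noise)

If the context were noised this way, the context positions in the padded signal levels below would presumably also drop from self.max_steps - 1 to a level consistent with 1. - context_signal_noise.
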
@@ -1270,18 +1271,23 @@ class DynamicsModel(Module):
         latents = torch.empty((batch_size, 0, *latent_shape), device = self.device)

         # while all the frames of the video (per latent) are not generated

         while latents.shape[1] < time_steps:
+            curr_time_steps = latents.shape[1]
+
             noised_latent = torch.randn((batch_size, 1, *latent_shape), device = self.device)

             for step in range(num_steps):
-                signal_level = tensor(step * step_size, device = self.device)
+                signal_levels = torch.full((batch_size, 1), step * step_size, device = self.device)

                 noised_latent_with_context, pack_context_shape = pack((latents, noised_latent), 'b * n d')

+                signal_levels_with_context = F.pad(signal_levels, (curr_time_steps, 0), value = self.max_steps - 1) # todo - handle noising of past
+
                 pred = self.forward(
                     latents = noised_latent_with_context,
-                    signal_levels = signal_level,
+                    signal_levels = signal_levels_with_context,
                     step_sizes_log2 = step_size_log2,
                     return_pred_only = True
                 )
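
This hunk carries the substantive fix from the commit title: the frame being denoised and the already-generated context frames no longer share one scalar signal level. The current frame carries the level for this denoising step, while every context position is pinned to self.max_steps - 1, i.e. fully clean. A toy illustration of the padding with made-up sizes:

    import torch
    import torch.nn.functional as F

    batch_size, curr_time_steps, max_steps = 2, 3, 64

    # signal level of the frame currently being denoised, one column per batch sample
    signal_levels = torch.full((batch_size, 1), 40)

    # left-pad along time with the max level, marking the 3 context frames as clean
    signal_levels_with_context = F.pad(signal_levels, (curr_time_steps, 0), value = max_steps - 1)

    print(signal_levels_with_context)
    # tensor([[63, 63, 63, 40],
    #         [63, 63, 63, 40]])
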
@@ -1291,7 +1297,7 @@ class DynamicsModel(Module):
                 # derive flow, based on whether in x-space or not

                 if self.pred_orig_latent:
-                    times = self.get_times_from_signal_level(signal_level, noised_latent)
+                    times = self.get_times_from_signal_level(signal_levels, noised_latent)
                     flow = (pred - noised_latent) / (1. - times)
                 else:
                     flow = pred
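
The change from signal_level to signal_levels here just tracks the rename above; the flow conversion itself is unchanged and follows from the linear interpolation used in forward (noise.lerp(latents, times), i.e. x_t = (1 - t) * eps + t * x). For an exact x-space prediction, x - x_t = (1 - t) * (x - eps), so dividing by 1 - t recovers the flow x - eps. A quick numerical check of that identity:

    import torch

    t = torch.rand(())                       # time in [0, 1)
    x = torch.randn(4)                       # clean latent
    eps = torch.randn(4)                     # noise
    x_t = eps.lerp(x, t)                     # (1 - t) * eps + t * x

    flow = x - eps                           # true velocity of the linear path
    flow_from_x_pred = (x - x_t) / (1. - t)  # recovered from the x-space prediction

    assert torch.allclose(flow, flow_from_x_pred, atol = 1e-5)
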
@@ -1350,51 +1356,60 @@ class DynamicsModel(Module):
         # shape related

-        if exists(signal_levels) and signal_levels.ndim == 0:
-            signal_levels = repeat(signal_levels, '-> b', b = batch)
+        if exists(signal_levels):
+            if signal_levels.ndim == 0:
+                signal_levels = repeat(signal_levels, '-> b', b = batch)
+
+            if signal_levels.ndim == 1:
+                signal_levels = repeat(signal_levels, 'b -> b t', t = time)

-        if exists(step_sizes_log2) and step_sizes_log2.ndim == 0:
-            step_sizes_log2 = repeat(step_sizes_log2, '-> b', b = batch)
+        if exists(step_sizes_log2):
+            if step_sizes_log2.ndim == 0:
+                step_sizes_log2 = repeat(step_sizes_log2, '-> b', b = batch)
+
+            if step_sizes_log2.ndim == 1:
+                step_sizes_log2 = repeat(step_sizes_log2, 'b -> b t', t = time)

         # flow related

         assert not (exists(signal_levels) ^ exists(step_sizes_log2))

-        # if neither signal levels or step sizes passed in
+        is_inference = exists(signal_levels)
+        return_pred_only = is_inference
+
+        # if neither signal levels nor step sizes passed in, assume training
         # generate them randomly for training

-        no_shortcut_train = random() < self.prob_no_shortcut_train
-
-        if no_shortcut_train:
-            # if no shortcut training, step sizes are just 1 and noising is all steps, where each step is 1 / d_min
-            # in original shortcut paper, they actually set d = 0 for some reason, look into that later, as there is no mention in the dreamer paper of doing this
-            step_sizes_log2 = torch.zeros((batch,), device = device).long() # zero because zero is equivalent to step size of 1
-            signal_levels = torch.randint(0, self.max_steps, (batch, time), device = device)
-        else:
-            step_sizes_log2 = torch.randint(1, self.num_step_sizes_log2, (batch,), device = device)
-            num_step_sizes = 2 ** step_sizes_log2
-
-            # now we follow eq (4)
-            signal_levels = torch.randint(0, self.max_steps, (batch, time)) // num_step_sizes[:, None] * num_step_sizes[:, None] # times are discretized to step sizes
+        if not is_inference:
+            no_shortcut_train = random() < self.prob_no_shortcut_train
+
+            if no_shortcut_train:
+                # if no shortcut training, step sizes are just 1 and noising is all steps, where each step is 1 / d_min
+                # in original shortcut paper, they actually set d = 0 for some reason, look into that later, as there is no mention in the dreamer paper of doing this
+                step_sizes_log2 = torch.zeros((batch,), device = device).long() # zero because zero is equivalent to step size of 1
+                signal_levels = torch.randint(0, self.max_steps, (batch, time), device = device)
+            else:
+                step_sizes_log2 = torch.randint(1, self.num_step_sizes_log2, (batch,), device = device)
+                num_step_sizes = 2 ** step_sizes_log2
+
+                # now we follow eq (4)
+                signal_levels = torch.randint(0, self.max_steps, (batch, time)) // num_step_sizes[:, None] * num_step_sizes[:, None] # times are discretized to step sizes

-        # get the noise
-
-        noise = torch.randn_like(latents)
-
         # times is from 0 to 1

         times = self.get_times_from_signal_level(signal_levels, latents)

-        # noise from 0 as noise to 1 as data
-
-        noised_latents = noise.lerp(latents, times)
+        if not is_inference:
+            # get the noise
+
+            noise = torch.randn_like(latents)
+
+            # noise from 0 as noise to 1 as data
+
+            noised_latents = noise.lerp(latents, times)
+        else:
+            noised_latents = latents

         # reinforcement learning related
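
On the training side, the eq (4) discretization snaps each sampled signal level down to a multiple of the sampled power-of-two step size, so a network queried at step size d only ever sees levels reachable by d-sized jumps. A toy reproduction with made-up sizes (each sample's levels snap to that sample's own step size):

    import torch

    batch, time, max_steps, num_step_sizes_log2 = 2, 4, 64, 6

    step_sizes_log2 = torch.randint(1, num_step_sizes_log2, (batch,))
    num_step_sizes = 2 ** step_sizes_log2   # per-sample step sizes: 2, 4, ... 32

    signal_levels = torch.randint(0, max_steps, (batch, time))

    # floor-divide then multiply rounds each level down to a multiple of the
    # step size, e.g. a raw level of 37 with step size 8 becomes 32
    snapped = signal_levels // num_step_sizes[:, None] * num_step_sizes[:, None]

    print(step_sizes_log2, signal_levels, snapped, sep = '\n')
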

View File

@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.1"
+version = "0.0.2"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }