correct signal levels when doing teacher-forcing generation

This commit is contained in:
lucidrains 2025-10-07 07:41:02 -07:00
parent c6bef85984
commit 1176269927
2 changed files with 44 additions and 29 deletions

View File

@ -1252,7 +1252,8 @@ class DynamicsModel(Module):
batch_size = 1,
image_height = None,
image_width = None,
return_decoded_video = None
return_decoded_video = None,
context_signal_noise = 0.1 # they do a noising of the past, this was from an old diffusion world modeling paper from EPFL iirc - todo: handle this
): # (b t n d) | (b c t h w)
assert log2(num_steps).is_integer(), f'number of steps {num_steps} must be a power of 2'
@ -1270,18 +1271,23 @@ class DynamicsModel(Module):
latents = torch.empty((batch_size, 0, *latent_shape), device = self.device)
# while all the frames of the video (per latent) is not generated
while latents.shape[1] < time_steps:
curr_time_steps = latents.shape[1]
noised_latent = torch.randn((batch_size, 1, *latent_shape), device = self.device)
for step in range(num_steps):
signal_level = tensor(step * step_size, device = self.device)
signal_levels = torch.full((batch_size, 1), step * step_size, device = self.device)
noised_latent_with_context, pack_context_shape = pack((latents, noised_latent), 'b * n d')
signal_levels_with_context = F.pad(signal_levels, (curr_time_steps, 0), value = self.max_steps - 1) # todo - handle noising of past
pred = self.forward(
latents = noised_latent_with_context,
signal_levels = signal_level,
signal_levels = signal_levels_with_context,
step_sizes_log2 = step_size_log2,
return_pred_only = True
)
@ -1291,7 +1297,7 @@ class DynamicsModel(Module):
# derive flow, based on whether in x-space or not
if self.pred_orig_latent:
times = self.get_times_from_signal_level(signal_level, noised_latent)
times = self.get_times_from_signal_level(signal_levels, noised_latent)
flow = (pred - noised_latent) / (1. - times)
else:
flow = pred
@ -1350,23 +1356,28 @@ class DynamicsModel(Module):
# shape related
if exists(signal_levels) and signal_levels.ndim == 0:
if exists(signal_levels):
if signal_levels.ndim == 0:
signal_levels = repeat(signal_levels, '-> b', b = batch)
if exists(step_sizes_log2):
if step_sizes_log2.ndim == 0:
step_sizes_log2 = repeat(step_sizes_log2, '-> b', b = batch)
if signal_levels.ndim == 1:
signal_levels = repeat(signal_levels, 'b -> b t', t = time)
if step_sizes_log2.ndim == 1:
step_sizes_log2 = repeat(step_sizes_log2, 'b -> b t', t = time)
if exists(step_sizes_log2) and step_sizes_log2.ndim == 0:
step_sizes_log2 = repeat(step_sizes_log2, '-> b', b = batch)
# flow related
assert not (exists(signal_levels) ^ exists(step_sizes_log2))
# if neither signal levels or step sizes passed in
is_inference = exists(signal_levels)
return_pred_only = is_inference
# if neither signal levels or step sizes passed in, assume training
# generate them randomly for training
if not is_inference:
no_shortcut_train = random() < self.prob_no_shortcut_train
if no_shortcut_train:
@ -1384,18 +1395,22 @@ class DynamicsModel(Module):
signal_levels = torch.randint(0, self.max_steps, (batch, time)) // num_step_sizes[:, None] * num_step_sizes[:, None] # times are discretized to step sizes
# get the noise
noise = torch.randn_like(latents)
# times is from 0 to 1
times = self.get_times_from_signal_level(signal_levels, latents)
if not is_inference:
# get the noise
noise = torch.randn_like(latents)
# noise from 0 as noise to 1 as data
noised_latents = noise.lerp(latents, times)
else:
noised_latents = latents
# reinforcement learning related
agent_tokens = repeat(self.action_learned_embed, 'd -> b d', b = batch)

View File

@ -1,6 +1,6 @@
[project]
name = "dreamer4"
version = "0.0.1"
version = "0.0.2"
description = "Dreamer 4"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }