generating video with raw teacher forcing

lucidrains 2025-10-07 07:22:57 -07:00
parent 83ba9a285a
commit c6bef85984
2 changed files with 110 additions and 13 deletions


@@ -1220,6 +1220,22 @@ class DynamicsModel(Module):
         self.register_buffer('zero', tensor(0.), persistent = False)

     @property
     def device(self):
         return self.zero.device

+    def get_times_from_signal_level(
+        self,
+        signal_levels,
+        align_dims_left_to = None
+    ):
+        times = signal_levels.float() / self.max_steps
+
+        if not exists(align_dims_left_to):
+            return times
+
+        return align_dims_left(times, align_dims_left_to)
+
     def parameters(self):
         params = super().parameters()
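The new helper maps integer signal levels in [0, max_steps) to continuous flow times in [0, 1), optionally right-padding singleton dimensions so the times broadcast against a latent tensor. A minimal standalone sketch, assuming exists and align_dims_left carry their usual definitions in this codebase; the align_dims_left shown here is an illustrative reimplementation, not the repo's:

import torch

def exists(v):
    return v is not None

def align_dims_left(t, align_to):
    # right-pad t with singleton dims so it broadcasts against align_to,
    # e.g. (b,) aligned to a (b, t, n, d) tensor becomes (b, 1, 1, 1)
    padding = align_to.ndim - t.ndim
    return t.reshape(*t.shape, *((1,) * padding))

max_steps = 32
signal_levels = torch.tensor([8, 16, 24])   # (b,) discrete signal levels
latents = torch.randn(3, 10, 64, 512)       # (b, t, n, d)

times = signal_levels.float() / max_steps   # (b,) continuous times in [0, 1)
times = align_dims_left(times, latents)     # (b, 1, 1, 1), broadcastable

assert times.shape == (3, 1, 1, 1)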
@@ -1228,25 +1244,87 @@ class DynamicsModel(Module):
         return list(set(params) - set(self.video_tokenizer.parameters()))

     @torch.no_grad()
     def generate(
         self,
-        num_frames,
+        time_steps,
+        num_steps = 4,
         batch_size = 1,
         image_height = None,
-        image_width = None
+        image_width = None,
+        return_decoded_video = None
     ): # (b t n d) | (b c t h w)

-        assert log(num_steps).is_integer(), f'number of steps must be a power of 2'
+        assert log2(num_steps).is_integer(), f'number of steps {num_steps} must be a power of 2'
+        assert 0 < num_steps <= self.max_steps, f'number of steps {num_steps} must be between 0 and {self.max_steps}'

-        raise NotImplementedError
+        latent_shape = self.latent_shape
+
+        # derive step size
+
+        step_size = self.max_steps // num_steps
+        step_size_log2 = tensor(log2(step_size), dtype = torch.long, device = self.device)
+
+        # denoising
+        # teacher forcing to start with
+
+        latents = torch.empty((batch_size, 0, *latent_shape), device = self.device)
+
+        while latents.shape[1] < time_steps:
+            noised_latent = torch.randn((batch_size, 1, *latent_shape), device = self.device)
+
+            for step in range(num_steps):
+                signal_level = tensor(step * step_size, device = self.device)
+
+                noised_latent_with_context, pack_context_shape = pack((latents, noised_latent), 'b * n d')
+
+                pred = self.forward(
+                    latents = noised_latent_with_context,
+                    signal_levels = signal_level,
+                    step_sizes_log2 = step_size_log2,
+                    return_pred_only = True
+                )
+
+                _, pred = unpack(pred, pack_context_shape, 'b * n d')
+
+                # derive flow, based on whether in x-space or not
+
+                if self.pred_orig_latent:
+                    times = self.get_times_from_signal_level(signal_level, noised_latent)
+                    flow = (pred - noised_latent) / (1. - times)
+                else:
+                    flow = pred
+
+                # denoise
+
+                noised_latent += flow * (step_size / self.max_steps)
+
+            latents = cat((latents, noised_latent), dim = 1)
+
+        # returning video
+
+        has_tokenizer = exists(self.video_tokenizer)
+        return_decoded_video = default(return_decoded_video, has_tokenizer)
+
+        if not return_decoded_video:
+            return latents
+
+        generated_video = self.video_tokenizer.decode(
+            latents,
+            height = image_height,
+            width = image_width
+        )
+
+        return generated_video

     def forward(
         self,
         *,
         video = None,
         latents = None, # (b t n d) | (b t d)
-        signal_levels = None, # (b t)
-        step_sizes_log2 = None, # (b)
+        signal_levels = None, # () | (b) | (b t)
+        step_sizes_log2 = None, # () | (b)
         tasks = None, # (b)
         rewards = None, # (b t)
         return_pred_only = False,
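The sampling loop in generate above produces one frame at a time: already-denoised frames are packed in as clean context (the raw teacher forcing of the commit title), a freshly noised frame is appended, and num_steps Euler steps of size step_size / max_steps carry it from noise to data. A self-contained sketch of just that integration logic, with a stub model standing in for self.forward and illustrative shapes throughout; none of this is the repo's API:

import torch

max_steps = 32
num_steps = 4                      # power of 2 that divides max_steps
step_size = max_steps // num_steps
latent_shape = (64, 512)           # (n, d) tokens per frame - illustrative
pred_orig_latent = True            # x-space prediction vs velocity (v-space)

def model(latents_with_context, signal_level):
    # stub for self.forward(..., return_pred_only = True)
    return torch.zeros_like(latents_with_context)

batch_size, time_steps = 2, 10
latents = torch.empty((batch_size, 0, *latent_shape))

while latents.shape[1] < time_steps:
    # fresh noise for the next frame
    noised = torch.randn((batch_size, 1, *latent_shape))

    for step in range(num_steps):
        signal_level = step * step_size
        t = signal_level / max_steps                 # time in [0, 1)

        # raw teacher forcing: clean history concatenated before the noised frame
        packed = torch.cat((latents, noised), dim = 1)
        pred = model(packed, signal_level)[:, -1:]   # keep only the current frame

        if pred_orig_latent:
            flow = (pred - noised) / (1. - t)        # x0 prediction -> velocity
        else:
            flow = pred                              # model predicts velocity directly

        noised = noised + flow * (step_size / max_steps)  # Euler step

    # frame is fully denoised, append it to the context
    latents = torch.cat((latents, noised), dim = 1)

assert latents.shape == (batch_size, time_steps, *latent_shape)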
@@ -1270,6 +1348,18 @@ class DynamicsModel(Module):
         batch, time, device = *latents.shape[:2], latents.device

         # shape related

+        if exists(signal_levels) and signal_levels.ndim == 0:
+            signal_levels = repeat(signal_levels, '-> b', b = batch)
+
+        if exists(step_sizes_log2):
+            if step_sizes_log2.ndim == 0:
+                step_sizes_log2 = repeat(step_sizes_log2, '-> b', b = batch)
+
+            if step_sizes_log2.ndim == 1:
+                step_sizes_log2 = repeat(step_sizes_log2, 'b -> b t', t = time)

         # flow related

         assert not (exists(signal_levels) ^ exists(step_sizes_log2))
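forward now accepts scalar, per-batch, or per-timestep conditioning, as the updated shape comments in the signature indicate: scalars are broadcast to (b,) and step sizes further to (b, t). A quick shape check of that einops broadcasting, with illustrative values:

import torch
from einops import repeat

batch, time = 2, 10

signal_levels = torch.tensor(8)     # scalar ()
step_sizes_log2 = torch.tensor(3)   # scalar ()

if signal_levels.ndim == 0:
    signal_levels = repeat(signal_levels, '-> b', b = batch)          # (b,)

if step_sizes_log2.ndim == 0:
    step_sizes_log2 = repeat(step_sizes_log2, '-> b', b = batch)      # (b,)

if step_sizes_log2.ndim == 1:
    step_sizes_log2 = repeat(step_sizes_log2, 'b -> b t', t = time)   # (b, t)

assert signal_levels.shape == (2,) and step_sizes_log2.shape == (2, 10)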
@@ -1300,11 +1390,7 @@ class DynamicsModel(Module):
         # times is from 0 to 1

-        def get_times_from_signal_level(signal_levels):
-            times = signal_levels.float() / self.max_steps
-            return align_dims_left(times, latents)
-
-        times = get_times_from_signal_level(signal_levels)
+        times = self.get_times_from_signal_level(signal_levels, latents)

         # noise from 0 as noise to 1 as data
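The x-space branches throughout this commit rely on the identity that, under linear interpolation x_t = (1 - t) * noise + t * x0, the flow x0 - noise can be recovered from a clean-latent prediction as (pred - x_t) / (1 - t). A quick numeric check:

import torch

# x0 - x_t = x0 - (1 - t) * noise - t * x0 = (1 - t) * (x0 - noise),
# so dividing by (1 - t) recovers the flow x0 - noise from an x0 prediction

noise = torch.randn(4, 8)
x0 = torch.randn(4, 8)
t = torch.tensor(0.25)

x_t = (1. - t) * noise + t * x0
flow = x0 - noise

assert torch.allclose((x0 - x_t) / (1. - t), flow, atol = 1e-6)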
@@ -1468,7 +1554,7 @@ class DynamicsModel(Module):
             if is_v_space_pred:
                 first_step_pred_flow = first_step_pred
             else:
-                first_times = get_times_from_signal_level(signal_levels)
+                first_times = self.get_times_from_signal_level(signal_levels, noised_latents)
                 first_step_pred_flow = (first_step_pred - noised_latents) / (1. - first_times)

             # take a half step
@@ -1485,7 +1571,7 @@ class DynamicsModel(Module):
             if is_v_space_pred:
                 second_step_pred_flow = second_step_pred
             else:
-                second_times = get_times_from_signal_level(signal_levels_plus_half_step)
+                second_times = self.get_times_from_signal_level(signal_levels_plus_half_step, denoised_latent)
                 second_step_pred_flow = (second_step_pred - denoised_latent) / (1. - second_times)

             # pred target is sg(b' + b'') / 2
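These two hunks assemble what looks like a shortcut-model-style self-consistency target: the flow predicted for a step of size 2d is regressed onto the stop-gradient average of two consecutive flows of step size d. A minimal sketch of that objective with a stub predictor; this is illustrative, not the repo's training code:

import torch
import torch.nn.functional as F

def model(x, t, step_size):
    # stub velocity predictor - stands in for self.forward with step_sizes_log2
    return torch.zeros_like(x)

x_t = torch.randn(2, 8)    # noised latents - illustrative shape
t = 0.25                   # current time
d = 0.125                  # half step, as a fraction of total time

with torch.no_grad():
    first_flow = model(x_t, t, d)                    # b'
    x_half = x_t + first_flow * d                    # take a half step
    second_flow = model(x_half, t + d, d)            # b''
    target = (first_flow + second_flow) / 2          # sg((b' + b'') / 2)

pred_flow = model(x_t, t, 2 * d)                     # one prediction at twice the step size
consistency_loss = F.mse_loss(pred_flow, target)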


@@ -88,6 +88,17 @@ def test_e2e(
     assert flow_loss.numel() == 1

     # generating

+    generated_video = dynamics.generate(
+        time_steps = 10,
+        image_height = 128,
+        image_width = 128,
+        batch_size = 2
+    )
+
+    assert generated_video.shape == (2, 3, 10, 128, 128)

     # rl

     rewards = torch.randn((2, 4)) * 100.
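Per the return_decoded_video branch added above, the same call can also return raw latents of shape (b, t, n, d) when no decoding is wanted. A hypothetical variant of the test call:

# returns latents of shape (2, 10, n, d) rather than decoded video;
# no tokenizer decode, so image_height / image_width are not needed
latents = dynamics.generate(
    time_steps = 10,
    batch_size = 2,
    return_decoded_video = False
)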