generating video with raw teacher forcing
parent 83ba9a285a
commit c6bef85984
@@ -1220,6 +1220,22 @@ class DynamicsModel(Module):
 
+        self.register_buffer('zero', tensor(0.), persistent = False)
+
+    @property
+    def device(self):
+        return self.zero.device
+
+    def get_times_from_signal_level(
+        self,
+        signal_levels,
+        align_dims_left_to = None
+    ):
+        times = signal_levels.float() / self.max_steps
+
+        if not exists(align_dims_left_to):
+            return times
+
+        return align_dims_left(times, align_dims_left_to)
+
     def parameter(self):
         params = super().parameters()
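For orientation, a minimal sketch of what the new `get_times_from_signal_level` helper computes, assuming `max_steps = 8`; the reshape below stands in for `align_dims_left`, which (judging by its name and usage here) appends singleton dims so `times` broadcasts against the reference tensor:

```python
import torch

max_steps = 8
signal_levels = torch.tensor([0, 2, 4, 6])   # discrete signal levels, one per batch element

# map discrete levels to continuous times in [0, 1)
times = signal_levels.float() / max_steps    # tensor([0.0000, 0.2500, 0.5000, 0.7500])

# align_dims_left_to: pad trailing singleton dims so times broadcasts
# against latents of shape (b, t, n, d) -- (4,) becomes (4, 1, 1, 1)
latents = torch.randn(4, 16, 32, 64)
aligned = times.reshape(times.shape + (1,) * (latents.ndim - times.ndim))
assert aligned.shape == (4, 1, 1, 1)
```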
@@ -1228,25 +1244,87 @@ class DynamicsModel(Module):
 
         return list(set(params) - set(self.video_tokenizer.parameters()))
 
     @torch.no_grad()
     def generate(
         self,
-        num_frames,
+        time_steps,
         num_steps = 4,
         batch_size = 1,
         image_height = None,
-        image_width = None
+        image_width = None,
+        return_decoded_video = None
     ): # (b t n d) | (b c t h w)
 
-        assert log(num_steps).is_integer(), f'number of steps must be a power of 2'
+        assert log2(num_steps).is_integer(), f'number of steps {num_steps} must be a power of 2'
+        assert 0 < num_steps <= self.max_steps, f'number of steps {num_steps} must be between 0 and {self.max_steps}'
 
-        raise NotImplementedError
+        latent_shape = self.latent_shape
+
+        # derive step size
+
+        step_size = self.max_steps // num_steps
+        step_size_log2 = tensor(log2(step_size), dtype = torch.long, device = self.device)
+
+        # denoising
+        # teacher forcing to start with
+
+        latents = torch.empty((batch_size, 0, *latent_shape), device = self.device)
+
+        while latents.shape[1] < time_steps:
+
+            noised_latent = torch.randn((batch_size, 1, *latent_shape), device = self.device)
+
+            for step in range(num_steps):
+                signal_level = tensor(step * step_size, device = self.device)
+
+                noised_latent_with_context, pack_context_shape = pack((latents, noised_latent), 'b * n d')
+
+                pred = self.forward(
+                    latents = noised_latent_with_context,
+                    signal_levels = signal_level,
+                    step_sizes_log2 = step_size_log2,
+                    return_pred_only = True
+                )
+
+                _, pred = unpack(pred, pack_context_shape, 'b * n d')
+
+                # derive flow, based on whether in x-space or not
+
+                if self.pred_orig_latent:
+                    times = self.get_times_from_signal_level(signal_level, noised_latent)
+                    flow = (pred - noised_latent) / (1. - times)
+                else:
+                    flow = pred
+
+                # denoise
+
+                noised_latent += flow * (step_size / self.max_steps)
+
+            latents = cat((latents, noised_latent), dim = 1)
+
+        # returning video
+
+        has_tokenizer = exists(self.video_tokenizer)
+        return_decoded_video = default(return_decoded_video, has_tokenizer)
+
+        if not return_decoded_video:
+            return latents
+
+        generated_video = self.video_tokenizer.decode(
+            latents,
+            height = image_height,
+            width = image_width
+        )
+
+        return generated_video
 
     def forward(
         self,
         *,
         video = None,
         latents = None, # (b t n d) | (b t d)
-        signal_levels = None, # (b t)
-        step_sizes_log2 = None, # (b)
+        signal_levels = None, # () | (b) | (b t)
+        step_sizes_log2 = None, # () | (b)
         tasks = None, # (b)
         rewards = None, # (b t)
         return_pred_only = False,
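The sampling loop above is a plain Euler integration of the flow, generating one frame at a time with all previously generated frames as teacher-forced context. A self-contained sketch of the per-frame denoising math, assuming x-space prediction (`pred_orig_latent = True`) and a stand-in for the network's clean-latent prediction:

```python
import torch

max_steps, num_steps = 8, 4
step_size = max_steps // num_steps           # 2 signal levels per Euler step

noised = torch.randn(1, 1, 32, 64)           # pure noise at signal level 0

for step in range(num_steps):
    t = (step * step_size) / max_steps       # current time in [0, 1)

    pred_x0 = noised * 0.9                   # stand-in for the model's predicted clean latent

    # with x_t = (1 - t) * noise + t * x0, the flow dx/dt = x0 - noise,
    # which is recovered from the x-space prediction as (x0 - x_t) / (1 - t)
    flow = (pred_x0 - noised) / (1. - t)

    # Euler step over dt = step_size / max_steps
    noised = noised + flow * (step_size / max_steps)
```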
@@ -1270,6 +1348,18 @@ class DynamicsModel(Module):
 
         batch, time, device = *latents.shape[:2], latents.device
 
+        # shape related
+
+        if exists(signal_levels) and signal_levels.ndim == 0:
+            signal_levels = repeat(signal_levels, '-> b', b = batch)
+
+        if exists(step_sizes_log2):
+            if step_sizes_log2.ndim == 0:
+                step_sizes_log2 = repeat(step_sizes_log2, '-> b', b = batch)
+
+            if step_sizes_log2.ndim == 1:
+                step_sizes_log2 = repeat(step_sizes_log2, 'b -> b t', t = time)
+
         # flow related
 
         assert not (exists(signal_levels) ^ exists(step_sizes_log2))
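The scalar promotion above follows standard einops usage; a quick illustration of how a scalar `signal_levels` or `step_sizes_log2` is broadcast up to per-batch and per-time shapes:

```python
import torch
from einops import repeat

batch, time = 2, 3

signal_levels = torch.tensor(4)                            # ndim == 0
signal_levels = repeat(signal_levels, '-> b', b = batch)   # scalar -> (b)

step_sizes_log2 = torch.tensor(1)                          # ndim == 0
step_sizes_log2 = repeat(step_sizes_log2, '-> b', b = batch)
step_sizes_log2 = repeat(step_sizes_log2, 'b -> b t', t = time)

assert signal_levels.shape == (2,)
assert step_sizes_log2.shape == (2, 3)
```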
@@ -1300,11 +1390,7 @@ class DynamicsModel(Module):
 
         # times is from 0 to 1
 
-        def get_times_from_signal_level(signal_levels):
-            times = signal_levels.float() / self.max_steps
-            return align_dims_left(times, latents)
-
-        times = get_times_from_signal_level(signal_levels)
+        times = self.get_times_from_signal_level(signal_levels, latents)
 
         # noise from 0 as noise to 1 as data
 
@@ -1468,7 +1554,7 @@ class DynamicsModel(Module):
         if is_v_space_pred:
             first_step_pred_flow = first_step_pred
         else:
-            first_times = get_times_from_signal_level(signal_levels)
+            first_times = self.get_times_from_signal_level(signal_levels, noised_latents)
             first_step_pred_flow = (first_step_pred - noised_latents) / (1. - first_times)
 
         # take a half step
@@ -1485,7 +1571,7 @@ class DynamicsModel(Module):
         if is_v_space_pred:
             second_step_pred_flow = second_step_pred
         else:
-            second_times = get_times_from_signal_level(signal_levels_plus_half_step)
+            second_times = self.get_times_from_signal_level(signal_levels_plus_half_step, denoised_latent)
             second_step_pred_flow = (second_step_pred - denoised_latent) / (1. - second_times)
 
         # pred target is sg(b' + b'') / 2
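Per the comment above, the training target averages the two half-step flows under a stop-gradient. A schematic sketch of that target, with `detach` playing the role of `sg` and random tensors standing in for the two flow predictions:

```python
import torch

# b'  : flow predicted at the current signal level
# b'' : flow predicted after taking a half step
first_step_pred_flow = torch.randn(2, 4, 32, 64)
second_step_pred_flow = torch.randn(2, 4, 32, 64)

# pred target is sg(b' + b'') / 2 -- gradients are blocked through the
# two finer-step predictions that form the teacher signal
target = ((first_step_pred_flow + second_step_pred_flow) / 2).detach()
```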
@@ -88,6 +88,17 @@ def test_e2e(
 
     assert flow_loss.numel() == 1
 
+    # generating
+
+    generated_video = dynamics.generate(
+        time_steps = 10,
+        image_height = 128,
+        image_width = 128,
+        batch_size = 2
+    )
+
+    assert generated_video.shape == (2, 3, 10, 128, 128)
+
     # rl
 
     rewards = torch.randn((2, 4)) * 100.