correct signal levels when doing teacher forcing generation
parent c6bef85984
commit 1176269927
@@ -1252,7 +1252,8 @@ class DynamicsModel(Module):
batch_size = 1,
image_height = None,
image_width = None,
return_decoded_video = None
return_decoded_video = None,
context_signal_noise = 0.1 # they do a noising of the past, this was from an old diffusion world modeling paper from EPFL iirc - todo: handle this
): # (b t n d) | (b c t h w)

assert log2(num_steps).is_integer(), f'number of steps {num_steps} must be a power of 2'
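The new context_signal_noise argument is only introduced as a todo here, but the comment hints at the intent: slightly corrupt the already-generated past frames so the model does not over-trust its own context. Purely as an illustration of what that could look like (this is not what the commit implements; max_steps and all shapes are assumptions), the context could be held at a signal level just below the maximum:

import torch

context_signal_noise = 0.1   # the new argument above
max_steps = 64               # assumed number of discrete signal levels

past_latents = torch.randn(2, 3, 4, 8)   # (b t n d) clean context frames
noise = torch.randn_like(past_latents)

# keep the context mostly clean: signal sits at 1 - context_signal_noise instead of 1
signal = 1. - context_signal_noise
slightly_noised_context = noise.lerp(past_latents, signal)

# the matching discrete signal level the context could be conditioned on
context_signal_level = int(signal * (max_steps - 1))   # 56 here, just below the clean level of 63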
@@ -1270,18 +1271,23 @@ class DynamicsModel(Module):

latents = torch.empty((batch_size, 0, *latent_shape), device = self.device)

# while all the frames of the video (per latent) is not generated

while latents.shape[1] < time_steps:

curr_time_steps = latents.shape[1]
noised_latent = torch.randn((batch_size, 1, *latent_shape), device = self.device)

for step in range(num_steps):
signal_level = tensor(step * step_size, device = self.device)
signal_levels = torch.full((batch_size, 1), step * step_size, device = self.device)

noised_latent_with_context, pack_context_shape = pack((latents, noised_latent), 'b * n d')

signal_levels_with_context = F.pad(signal_levels, (curr_time_steps, 0), value = self.max_steps - 1) # todo - handle noising of past

pred = self.forward(
latents = noised_latent_with_context,
signal_levels = signal_level,
signal_levels = signal_levels_with_context,
step_sizes_log2 = step_size_log2,
return_pred_only = True
)
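The generation loop above is teacher forced frame by frame: the latents already produced are packed in front of the frame currently being denoised, and their signal levels are pinned at the maximum (fully clean) value, while only the newest frame sweeps through the denoising steps (noising of the past is still marked as a todo). A minimal sketch of that layout with made-up shapes and a stand-in for max_steps, not the repo's exact code:

import torch
import torch.nn.functional as F
from einops import pack

batch_size, num_tokens, dim = 2, 4, 8
max_steps = 64        # assumed number of discrete signal levels
curr_time_steps = 3   # frames already generated

past_latents = torch.randn(batch_size, curr_time_steps, num_tokens, dim)  # clean context frames
noised_latent = torch.randn(batch_size, 1, num_tokens, dim)               # frame being denoised

# pack the context frames and the new frame along the time dimension
latents_with_context, _ = pack((past_latents, noised_latent), 'b * n d')

# the new frame carries the current denoising signal level ...
signal_levels = torch.full((batch_size, 1), 5)

# ... while the past frames are left-padded with the maximum (clean) signal level
signal_levels_with_context = F.pad(signal_levels, (curr_time_steps, 0), value = max_steps - 1)

print(latents_with_context.shape)     # torch.Size([2, 4, 4, 8])
print(signal_levels_with_context[0])  # tensor([63, 63, 63,  5])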
@@ -1291,7 +1297,7 @@ class DynamicsModel(Module):
# derive flow, based on whether in x-space or not

if self.pred_orig_latent:
times = self.get_times_from_signal_level(signal_level, noised_latent)
times = self.get_times_from_signal_level(signal_levels, noised_latent)
flow = (pred - noised_latent) / (1. - times)
else:
flow = pred
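For the flow derivation above: when the model predicts the original latent (x-space) rather than the velocity, the noised latent lies on a straight line between noise and data, so dividing the gap between the prediction and the current noised latent by the remaining time 1 - t recovers that line's constant velocity. A small numerical sanity check of the identity, illustrative only:

import torch

x0, noise = torch.randn(4), torch.randn(4)
t = torch.tensor(0.3)

x_t = noise.lerp(x0, t)                    # interpolate: 0 -> noise, 1 -> data
flow_direct = x0 - noise                   # true velocity of the straight path
flow_from_pred = (x0 - x_t) / (1. - t)     # recovered from an exact x-space prediction

assert torch.allclose(flow_direct, flow_from_pred, atol = 1e-6)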
@@ -1350,51 +1356,60 @@ class DynamicsModel(Module):

# shape related

if exists(signal_levels) and signal_levels.ndim == 0:
signal_levels = repeat(signal_levels, '-> b', b = batch)
if exists(signal_levels):
if signal_levels.ndim == 0:
signal_levels = repeat(signal_levels, '-> b', b = batch)

if exists(step_sizes_log2):
if step_sizes_log2.ndim == 0:
step_sizes_log2 = repeat(step_sizes_log2, '-> b', b = batch)
if signal_levels.ndim == 1:
signal_levels = repeat(signal_levels, 'b -> b t', t = time)

if step_sizes_log2.ndim == 1:
step_sizes_log2 = repeat(step_sizes_log2, 'b -> b t', t = time)
if exists(step_sizes_log2) and step_sizes_log2.ndim == 0:
step_sizes_log2 = repeat(step_sizes_log2, '-> b', b = batch)

# flow related

assert not (exists(signal_levels) ^ exists(step_sizes_log2))

# if neither signal levels or step sizes passed in
is_inference = exists(signal_levels)
return_pred_only = is_inference

# if neither signal levels or step sizes passed in, assume training
# generate them randomly for training

no_shortcut_train = random() < self.prob_no_shortcut_train
if not is_inference:

if no_shortcut_train:
# if no shortcut training, step sizes are just 1 and noising is all steps, where each step is 1 / d_min
# in original shortcut paper, they actually set d = 0 for some reason, look into that later, as there is no mention in the dreamer paper of doing this
no_shortcut_train = random() < self.prob_no_shortcut_train

step_sizes_log2 = torch.zeros((batch,), device = device).long() # zero because zero is equivalent to step size of 1
signal_levels = torch.randint(0, self.max_steps, (batch, time), device = device)
else:
if no_shortcut_train:
# if no shortcut training, step sizes are just 1 and noising is all steps, where each step is 1 / d_min
# in original shortcut paper, they actually set d = 0 for some reason, look into that later, as there is no mention in the dreamer paper of doing this

# now we follow eq (4)
step_sizes_log2 = torch.zeros((batch,), device = device).long() # zero because zero is equivalent to step size of 1
signal_levels = torch.randint(0, self.max_steps, (batch, time), device = device)
else:

step_sizes_log2 = torch.randint(1, self.num_step_sizes_log2, (batch,), device = device)
num_step_sizes = 2 ** step_sizes_log2
# now we follow eq (4)

signal_levels = torch.randint(0, self.max_steps, (batch, time)) // num_step_sizes[:, None] * num_step_sizes[:, None] # times are discretized to step sizes
step_sizes_log2 = torch.randint(1, self.num_step_sizes_log2, (batch,), device = device)
num_step_sizes = 2 ** step_sizes_log2

# get the noise

noise = torch.randn_like(latents)
signal_levels = torch.randint(0, self.max_steps, (batch, time)) // num_step_sizes[:, None] * num_step_sizes[:, None] # times are discretized to step sizes

# times is from 0 to 1

times = self.get_times_from_signal_level(signal_levels, latents)

# noise from 0 as noise to 1 as data
if not is_inference:
# get the noise

noised_latents = noise.lerp(latents, times)
noise = torch.randn_like(latents)

# noise from 0 as noise to 1 as data

noised_latents = noise.lerp(latents, times)

else:
noised_latents = latents

# reinforcement learning related
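In the training branch of the hunk above, a step size is sampled as a power of two and the per-frame signal levels are rounded down to multiples of that step size (the eq (4) the comments refer to), after which the latents are noised by linearly interpolating between Gaussian noise and data. A rough sketch of that sampling with assumed hyperparameters, and an assumed linear mapping from signal level to time standing in for get_times_from_signal_level:

import torch

batch, time, max_steps, num_step_sizes_log2 = 4, 6, 64, 6

# sample a log2 step size per sample, giving step sizes of 2, 4, ... as powers of two
step_sizes_log2 = torch.randint(1, num_step_sizes_log2, (batch,))
num_step_sizes = 2 ** step_sizes_log2

# draw per-frame signal levels uniformly, then floor them to multiples of the step size
signal_levels = torch.randint(0, max_steps, (batch, time)) // num_step_sizes[:, None] * num_step_sizes[:, None]

# map discrete signal levels to times in [0, 1) and noise the latents along a straight line
times = signal_levels.float() / max_steps    # assumed mapping, stand-in for get_times_from_signal_level
latents = torch.randn(batch, time, 8, 16)    # dummy clean latents, (b t n d)
noise = torch.randn_like(latents)

noised_latents = noise.lerp(latents, times[..., None, None])  # 0 is pure noise, 1 is data

print(noised_latents.shape)  # torch.Size([4, 6, 8, 16])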
@@ -1,6 +1,6 @@
[project]
name = "dreamer4"
version = "0.0.1"
version = "0.0.2"
description = "Dreamer 4"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }