add the discretized signal level + step size embeddings necessary for diffusion forcing + shortcut

This commit is contained in:
lucidrains 2025-10-02 07:39:34 -07:00
parent bb7a5d1680
commit 8b66b703e0
3 changed files with 45 additions and 10 deletions

View File

@ -9,6 +9,8 @@ import torch.nn.functional as F
from torch.nn import Module, ModuleList, Parameter, Sequential, Linear, RMSNorm, Identity
from torch import cat, stack, arange, tensor, Tensor, is_tensor
from x_mlps_pytorch import create_mlp
from accelerate import Accelerator
# ein related
@ -460,7 +462,8 @@ class VideoTokenizer(Module):
latents = self.encoded_to_latents(tokens)
if return_latents:
return latents
latents = inverse_pack_time(latents)
return latents[..., -1, :]
tokens = self.latents_to_decoder(latents)
@ -501,6 +504,8 @@ class DynamicsModel(Module):
self,
dim,
dim_latent,
num_signal_levels = 500,
num_step_sizes = 32,
num_spatial_tokens = 32, # latents were projected into spatial tokens, and presumably pooled back for the final prediction (or one special one does the x-prediction)
num_register_tokens = 8, # they claim register tokens led to better temporal consistency
depth = 4,
@ -522,6 +527,14 @@ class DynamicsModel(Module):
self.register_tokens = Parameter(torch.randn(num_register_tokens, dim) * 1e-2)
# signal and step sizes
assert divisible_by(dim, 2)
dim_half = dim // 2
self.signal_levels_embed = nn.Embedding(num_signal_levels, dim_half)
self.step_sizes_embed = nn.Embedding(num_step_sizes, dim_half)
# they sum all the actions into a single token
self.action_learned_embed = Parameter(torch.randn(dim) * 1e-2)
@ -555,19 +568,36 @@ class DynamicsModel(Module):
def forward(
self,
latents # (b t d)
latents, # (b t d)
signal_levels = None, # (b t)
step_sizes = None # (b t)
):
space_tokens = self.latents_to_spatial_tokens(latents)
# pack to tokens
# [latent space tokens] [register] [actions / agent]
# [signal + step size embed] [latent space tokens] [register] [actions / agent]
registers = repeat(self.register_tokens, 's d -> b t s d', b = latents.shape[0], t = latents.shape[1])
agent_token = repeat(self.action_learned_embed, 'd -> b t 1 d', b = latents.shape[0], t = latents.shape[1])
agent_token = repeat(self.action_learned_embed, 'd -> b t d', b = latents.shape[0], t = latents.shape[1])
tokens, packed_tokens_shape = pack([space_tokens, registers, agent_token], 'b t * d')
# determine signal + step size embed for their diffusion forcing + shortcut
assert not (exists(signal_levels) ^ exists(step_sizes))
if exists(signal_levels):
signal_embed = self.signal_levels_embed(signal_levels)
step_size_embed = self.step_sizes_embed(step_sizes)
flow_token = cat((signal_embed, step_size_embed), dim = -1)
flow_token = rearrange(flow_token, 'b t d -> b t d')
else:
flow_token = registers[..., 0:0, :]
# pack to tokens for attending
tokens, packed_tokens_shape = pack([flow_token, space_tokens, registers, agent_token], 'b t * d')
# attention
@ -583,7 +613,7 @@ class DynamicsModel(Module):
# unpack
space_tokens, register_tokens, agent_token = unpack(tokens, packed_tokens_shape, 'b t * d')
flow_token, space_tokens, register_tokens, agent_token = unpack(tokens, packed_tokens_shape, 'b t * d')
# pooling

View File

@ -29,7 +29,8 @@ dependencies = [
"accelerate",
"einx>=0.3.0",
"einops>=0.8.1",
"torch>=2.4"
"torch>=2.4",
"x-mlps-pytorch"
]
[project.urls]

View File

@ -5,7 +5,7 @@ def test_e2e():
from dreamer4.dreamer4 import VideoTokenizer, DynamicsModel
tokenizer = VideoTokenizer(512, dim_latent = 32, patch_size = 32)
x = torch.randn(1, 3, 4, 256, 256)
x = torch.randn(2, 3, 4, 256, 256)
loss = tokenizer(x)
assert loss.numel() == 1
@ -13,6 +13,10 @@ def test_e2e():
latents = tokenizer(x, return_latents = True)
assert latents.shape[-1] == 32
dynamics = DynamicsModel(512, dim_latent = 32)
pred = dynamics(latents)
dynamics = DynamicsModel(512, dim_latent = 32, num_signal_levels = 500, num_step_sizes = 32)
signal_levels = torch.randint(0, 500, (2, 4))
step_sizes = torch.randint(0, 32, (2, 4))
pred = dynamics(latents, signal_levels = signal_levels, step_sizes = step_sizes)
assert pred.shape == latents.shape