add the discretized signal level + step size embeddings necessary for diffusion forcing + shortcut

parent bb7a5d1680
commit 8b66b703e0
@@ -9,6 +9,8 @@ import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Parameter, Sequential, Linear, RMSNorm, Identity
 from torch import cat, stack, arange, tensor, Tensor, is_tensor
 
+from x_mlps_pytorch import create_mlp
+
 from accelerate import Accelerator
 
 # ein related
@@ -460,7 +462,8 @@ class VideoTokenizer(Module):
         latents = self.encoded_to_latents(tokens)
 
         if return_latents:
-            return latents
+            latents = inverse_pack_time(latents)
+            return latents[..., -1, :]
 
         tokens = self.latents_to_decoder(latents)
 
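The new return path reshapes before returning. A minimal sketch of the implied shape bookkeeping, assuming the tokenizer flattens the time axis into its token sequence and `inverse_pack_time` is the inverse of that flattening (the helper and shapes here are illustrative assumptions; only the final indexing mirrors the diff):

    import torch
    from einops import pack, unpack

    b, t, n, d = 2, 4, 33, 32   # batch, frames, tokens per frame, dim_latent

    # assumption: encoding flattened (t, n) into one token axis
    latents = torch.randn(b, t, n, d)
    flat, ps = pack([latents], 'b * d')       # (b, t * n, d)

    # inverse_pack_time would undo that flattening
    latents, = unpack(flat, ps, 'b * d')      # back to (b, t, n, d)

    # [..., -1, :] then keeps one latent per frame -> (b, t, d),
    # the shape DynamicsModel.forward expects for its `latents` argument
    per_frame = latents[..., -1, :]
    assert per_frame.shape == (b, t, d)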
@@ -501,6 +504,8 @@ class DynamicsModel(Module):
         self,
         dim,
         dim_latent,
+        num_signal_levels = 500,
+        num_step_sizes = 32,
         num_spatial_tokens = 32,  # latents were projected into spatial tokens, and presumably pooled back for the final prediction (or one special one does the x-prediction)
         num_register_tokens = 8,  # they claim register tokens led to better temporal consistency
         depth = 4,
@@ -522,6 +527,14 @@ class DynamicsModel(Module):
 
         self.register_tokens = Parameter(torch.randn(num_register_tokens, dim) * 1e-2)
 
+        # signal and step sizes
+
+        assert divisible_by(dim, 2)
+        dim_half = dim // 2
+
+        self.signal_levels_embed = nn.Embedding(num_signal_levels, dim_half)
+        self.step_sizes_embed = nn.Embedding(num_step_sizes, dim_half)
+
         # they sum all the actions into a single token
 
         self.action_learned_embed = Parameter(torch.randn(dim) * 1e-2)
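The conditioning for diffusion forcing + shortcut is two lookup tables, each at half the model width, whose outputs are concatenated back up to `dim`. A self-contained sketch of that construction, using the defaults above:

    import torch
    import torch.nn as nn

    dim, num_signal_levels, num_step_sizes = 512, 500, 32
    dim_half = dim // 2

    signal_levels_embed = nn.Embedding(num_signal_levels, dim_half)
    step_sizes_embed = nn.Embedding(num_step_sizes, dim_half)

    signal_levels = torch.randint(0, num_signal_levels, (2, 4))  # (b, t)
    step_sizes = torch.randint(0, num_step_sizes, (2, 4))        # (b, t)

    # two half-width embeddings concatenated into one dim-wide flow token per frame
    flow_token = torch.cat((signal_levels_embed(signal_levels), step_sizes_embed(step_sizes)), dim = -1)
    assert flow_token.shape == (2, 4, dim)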
@@ -555,19 +568,36 @@ class DynamicsModel(Module):
 
     def forward(
         self,
-        latents               # (b t d)
+        latents,              # (b t d)
+        signal_levels = None, # (b t)
+        step_sizes = None     # (b t)
     ):
 
         space_tokens = self.latents_to_spatial_tokens(latents)
 
         # pack to tokens
-        # [latent space tokens] [register] [actions / agent]
+        # [signal + step size embed] [latent space tokens] [register] [actions / agent]
 
         registers = repeat(self.register_tokens, 's d -> b t s d', b = latents.shape[0], t = latents.shape[1])
 
-        agent_token = repeat(self.action_learned_embed, 'd -> b t 1 d', b = latents.shape[0], t = latents.shape[1])
+        agent_token = repeat(self.action_learned_embed, 'd -> b t d', b = latents.shape[0], t = latents.shape[1])
 
-        tokens, packed_tokens_shape = pack([space_tokens, registers, agent_token], 'b t * d')
+        # determine signal + step size embed for their diffusion forcing + shortcut
+
+        assert not (exists(signal_levels) ^ exists(step_sizes))
+
+        if exists(signal_levels):
+            signal_embed = self.signal_levels_embed(signal_levels)
+            step_size_embed = self.step_sizes_embed(step_sizes)
+
+            flow_token = cat((signal_embed, step_size_embed), dim = -1)
+            flow_token = rearrange(flow_token, 'b t d -> b t 1 d')  # give the single flow token an explicit token axis
+        else:
+            flow_token = registers[..., 0:0, :]  # empty slice - packs as zero tokens when unconditioned
+
+        # pack to tokens for attending
+
+        tokens, packed_tokens_shape = pack([flow_token, space_tokens, registers, agent_token], 'b t * d')
 
         # attention
 
@@ -583,7 +613,7 @@ class DynamicsModel(Module):
 
         # unpack
 
-        space_tokens, register_tokens, agent_token = unpack(tokens, packed_tokens_shape, 'b t * d')
+        flow_token, space_tokens, register_tokens, agent_token = unpack(tokens, packed_tokens_shape, 'b t * d')
 
         # pooling
 
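The pack/unpack bookkeeping is what makes the flow token optional: einops' `pack` concatenates token groups of differing rank along the `*` axis and records each group's shape, so `unpack` can split them back apart, and an empty slice contributes zero tokens. A runnable sketch with the default token counts (1 flow, 32 spatial, 8 register, 1 agent):

    import torch
    from einops import pack, unpack

    b, t, d = 2, 4, 512

    flow = torch.randn(b, t, 1, d)     # one conditioning token per frame
    space = torch.randn(b, t, 32, d)   # spatial tokens
    regs = torch.randn(b, t, 8, d)     # register tokens
    agent = torch.randn(b, t, d)       # a bare (b t d) tensor packs as one token

    tokens, ps = pack([flow, space, regs, agent], 'b t * d')
    assert tokens.shape == (b, t, 1 + 32 + 8 + 1, d)

    flow, space, regs, agent = unpack(tokens, ps, 'b t * d')
    assert agent.shape == (b, t, d)    # original ranks are restored

    # the unconditioned branch: an empty slice packs as zero tokens
    no_flow = regs[..., 0:0, :]        # (b, t, 0, d)
    tokens, ps = pack([no_flow, space, regs, agent], 'b t * d')
    assert tokens.shape == (b, t, 32 + 8 + 1, d)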
@@ -29,7 +29,8 @@ dependencies = [
     "accelerate",
     "einx>=0.3.0",
     "einops>=0.8.1",
-    "torch>=2.4"
+    "torch>=2.4",
+    "x-mlps-pytorch"
 ]
 
 [project.urls]
@@ -5,7 +5,7 @@ def test_e2e():
     from dreamer4.dreamer4 import VideoTokenizer, DynamicsModel
 
     tokenizer = VideoTokenizer(512, dim_latent = 32, patch_size = 32)
-    x = torch.randn(1, 3, 4, 256, 256)
+    x = torch.randn(2, 3, 4, 256, 256)
 
     loss = tokenizer(x)
     assert loss.numel() == 1
@@ -13,6 +13,10 @@ def test_e2e():
     latents = tokenizer(x, return_latents = True)
     assert latents.shape[-1] == 32
 
-    dynamics = DynamicsModel(512, dim_latent = 32)
-    pred = dynamics(latents)
+    dynamics = DynamicsModel(512, dim_latent = 32, num_signal_levels = 500, num_step_sizes = 32)
+
+    signal_levels = torch.randint(0, 500, (2, 4))
+    step_sizes = torch.randint(0, 32, (2, 4))
+
+    pred = dynamics(latents, signal_levels = signal_levels, step_sizes = step_sizes)
     assert pred.shape == latents.shape
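For context on the two indices the test now passes: under diffusion forcing each frame carries its own discretized signal (noise) level, while shortcut models additionally condition on the integration step size. A hedged sketch of how such conditioning tensors might be drawn during training; the per-frame independence and the per-sample shared step size are assumptions about the recipe, not something this commit specifies:

    import torch

    b, t = 2, 4
    num_signal_levels, num_step_sizes = 500, 32

    # diffusion forcing: an independent discretized signal level per frame
    signal_levels = torch.randint(0, num_signal_levels, (b, t))

    # shortcut: a discretized step size, shared across a sample's frames (assumption)
    step_sizes = torch.randint(0, num_step_sizes, (b, 1)).expand(b, t)

Both tensors are plain integer indices into the two nn.Embedding tables, matching the (b t) shapes annotated on the forward signature.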