sketch out the axial space time transformer in dynamics model

This commit is contained in:
lucidrains 2025-10-02 07:17:58 -07:00
parent 0285bba821
commit bb7a5d1680
2 changed files with 108 additions and 23 deletions

View File

@ -264,6 +264,8 @@ class Attention(Module):
kv_cache = None,
return_kv_cache = False
):
tokens, inverse_packed_batch = pack_one(tokens, '* n d')
tokens = self.norm(tokens)
q, k, v = (self.to_q(tokens), *self.to_kv(tokens).chunk(2, dim = -1))
@ -311,6 +313,8 @@ class Attention(Module):
out = self.to_out(out)
out = inverse_packed_batch(out)
if not return_kv_cache:
return out
@ -351,7 +355,10 @@ class VideoTokenizer(Module):
patch_size,
encoder_depth = 4,
decoder_depth = 4,
attn_kwargs: dict = dict(),
attn_kwargs: dict = dict(
dim_head = 64,
heads = 8,
),
ff_kwargs: dict = dict(),
channels = 3
):
@ -487,14 +494,103 @@ class VideoTokenizer(Module):
return recon_loss
# dynamics model
# dynamics model, axial space-time transformer
class DynamicsModel(Module):
def __init__(
self
self,
dim,
dim_latent,
num_spatial_tokens = 32, # latents were projected into spatial tokens, and presumably pooled back for the final prediction (or one special one does the x-prediction)
num_register_tokens = 8, # they claim register tokens led to better temporal consistency
depth = 4,
time_block_every = 4, # every 4th block is time
attn_kwargs: dict = dict(
dim_head = 64,
heads = 8,
),
ff_kwargs: dict = dict()
):
super().__init__()
# spatial and register tokens
self.latents_to_spatial_tokens = Sequential(
Linear(dim_latent, dim * num_spatial_tokens),
Rearrange('... (tokens d) -> ... tokens d', tokens = num_spatial_tokens)
)
self.register_tokens = Parameter(torch.randn(num_register_tokens, dim) * 1e-2)
# they sum all the actions into a single token
self.action_learned_embed = Parameter(torch.randn(dim) * 1e-2)
# transformer
layers = []
for i in range(depth):
layer_index = i + 1
is_time_block = divisible_by(layer_index, time_block_every)
rearrange_to_attend = Rearrange('b t s d -> b s t d') if is_time_block else Identity()
rearrange_from_attend = Rearrange('b s t d -> b t s d') if is_time_block else Identity()
layers.append(ModuleList([
rearrange_to_attend,
rearrange_from_attend,
Attention(dim = dim, **attn_kwargs),
SwiGLUFeedforward(dim = dim, **ff_kwargs)
]))
self.layers = ModuleList(layers)
# to prediction
self.to_pred = Sequential(
RMSNorm(dim),
Linear(dim, dim_latent)
)
def forward(
self,
latents # (b t d)
):
space_tokens = self.latents_to_spatial_tokens(latents)
# pack to tokens
# [latent space tokens] [register] [actions / agent]
registers = repeat(self.register_tokens, 's d -> b t s d', b = latents.shape[0], t = latents.shape[1])
agent_token = repeat(self.action_learned_embed, 'd -> b t 1 d', b = latents.shape[0], t = latents.shape[1])
tokens, packed_tokens_shape = pack([space_tokens, registers, agent_token], 'b t * d')
# attention
for pre_attn_rearrange, post_attn_rearrange, attn, ff in self.layers:
tokens = pre_attn_rearrange(tokens)
tokens = attn(tokens) + tokens
tokens = post_attn_rearrange(tokens)
tokens = ff(tokens) + tokens
# unpack
space_tokens, register_tokens, agent_token = unpack(tokens, packed_tokens_shape, 'b t * d')
# pooling
pooled = reduce(space_tokens, 'b t s d -> b t d', 'mean')
return self.to_pred(pooled)
# dreamer
class Dreamer(Module):

View File

@ -1,29 +1,18 @@
import pytest
import torch
def test_attn():
from dreamer4.dreamer4 import Attention
def test_e2e():
from dreamer4.dreamer4 import VideoTokenizer, DynamicsModel
x = torch.randn(1, 1024, 512)
attn = Attention(512)
assert attn(x).shape == x.shape
def test_ff():
from dreamer4.dreamer4 import SwiGLUFeedforward
x = torch.randn(1, 1024, 512)
ff = SwiGLUFeedforward(512)
assert ff(x).shape == x.shape
def test_tokenizer():
from dreamer4.dreamer4 import VideoTokenizer
tokenizer = VideoTokenizer(512, dim_latent = 32, patch_size = 16)
x = torch.randn(1, 3, 16, 256, 256)
tokenizer = VideoTokenizer(512, dim_latent = 32, patch_size = 32)
x = torch.randn(1, 3, 4, 256, 256)
loss = tokenizer(x)
assert loss.numel() == 1
latents = tokenizer(x, return_latents = True)
assert latents.shape[-1] == 32
assert latents.shape[-1] == 32
dynamics = DynamicsModel(512, dim_latent = 32)
pred = dynamics(latents)
assert pred.shape == latents.shape