sketch out the axial space time transformer in dynamics model
This commit is contained in:
parent
0285bba821
commit
bb7a5d1680
@ -264,6 +264,8 @@ class Attention(Module):
|
|||||||
kv_cache = None,
|
kv_cache = None,
|
||||||
return_kv_cache = False
|
return_kv_cache = False
|
||||||
):
|
):
|
||||||
|
tokens, inverse_packed_batch = pack_one(tokens, '* n d')
|
||||||
|
|
||||||
tokens = self.norm(tokens)
|
tokens = self.norm(tokens)
|
||||||
|
|
||||||
q, k, v = (self.to_q(tokens), *self.to_kv(tokens).chunk(2, dim = -1))
|
q, k, v = (self.to_q(tokens), *self.to_kv(tokens).chunk(2, dim = -1))
|
||||||
@ -311,6 +313,8 @@ class Attention(Module):
|
|||||||
|
|
||||||
out = self.to_out(out)
|
out = self.to_out(out)
|
||||||
|
|
||||||
|
out = inverse_packed_batch(out)
|
||||||
|
|
||||||
if not return_kv_cache:
|
if not return_kv_cache:
|
||||||
return out
|
return out
|
||||||
|
|
||||||
@ -351,7 +355,10 @@ class VideoTokenizer(Module):
|
|||||||
patch_size,
|
patch_size,
|
||||||
encoder_depth = 4,
|
encoder_depth = 4,
|
||||||
decoder_depth = 4,
|
decoder_depth = 4,
|
||||||
attn_kwargs: dict = dict(),
|
attn_kwargs: dict = dict(
|
||||||
|
dim_head = 64,
|
||||||
|
heads = 8,
|
||||||
|
),
|
||||||
ff_kwargs: dict = dict(),
|
ff_kwargs: dict = dict(),
|
||||||
channels = 3
|
channels = 3
|
||||||
):
|
):
|
||||||
@ -487,14 +494,103 @@ class VideoTokenizer(Module):
|
|||||||
|
|
||||||
return recon_loss
|
return recon_loss
|
||||||
|
|
||||||
# dynamics model
|
# dynamics model, axial space-time transformer
|
||||||
|
|
||||||
class DynamicsModel(Module):
|
class DynamicsModel(Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self
|
self,
|
||||||
|
dim,
|
||||||
|
dim_latent,
|
||||||
|
num_spatial_tokens = 32, # latents were projected into spatial tokens, and presumably pooled back for the final prediction (or one special one does the x-prediction)
|
||||||
|
num_register_tokens = 8, # they claim register tokens led to better temporal consistency
|
||||||
|
depth = 4,
|
||||||
|
time_block_every = 4, # every 4th block is time
|
||||||
|
attn_kwargs: dict = dict(
|
||||||
|
dim_head = 64,
|
||||||
|
heads = 8,
|
||||||
|
),
|
||||||
|
ff_kwargs: dict = dict()
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
# spatial and register tokens
|
||||||
|
|
||||||
|
self.latents_to_spatial_tokens = Sequential(
|
||||||
|
Linear(dim_latent, dim * num_spatial_tokens),
|
||||||
|
Rearrange('... (tokens d) -> ... tokens d', tokens = num_spatial_tokens)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.register_tokens = Parameter(torch.randn(num_register_tokens, dim) * 1e-2)
|
||||||
|
|
||||||
|
# they sum all the actions into a single token
|
||||||
|
|
||||||
|
self.action_learned_embed = Parameter(torch.randn(dim) * 1e-2)
|
||||||
|
|
||||||
|
# transformer
|
||||||
|
|
||||||
|
layers = []
|
||||||
|
|
||||||
|
for i in range(depth):
|
||||||
|
layer_index = i + 1
|
||||||
|
is_time_block = divisible_by(layer_index, time_block_every)
|
||||||
|
|
||||||
|
rearrange_to_attend = Rearrange('b t s d -> b s t d') if is_time_block else Identity()
|
||||||
|
rearrange_from_attend = Rearrange('b s t d -> b t s d') if is_time_block else Identity()
|
||||||
|
|
||||||
|
layers.append(ModuleList([
|
||||||
|
rearrange_to_attend,
|
||||||
|
rearrange_from_attend,
|
||||||
|
Attention(dim = dim, **attn_kwargs),
|
||||||
|
SwiGLUFeedforward(dim = dim, **ff_kwargs)
|
||||||
|
]))
|
||||||
|
|
||||||
|
self.layers = ModuleList(layers)
|
||||||
|
|
||||||
|
# to prediction
|
||||||
|
|
||||||
|
self.to_pred = Sequential(
|
||||||
|
RMSNorm(dim),
|
||||||
|
Linear(dim, dim_latent)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
latents # (b t d)
|
||||||
|
):
|
||||||
|
|
||||||
|
space_tokens = self.latents_to_spatial_tokens(latents)
|
||||||
|
|
||||||
|
# pack to tokens
|
||||||
|
# [latent space tokens] [register] [actions / agent]
|
||||||
|
|
||||||
|
registers = repeat(self.register_tokens, 's d -> b t s d', b = latents.shape[0], t = latents.shape[1])
|
||||||
|
|
||||||
|
agent_token = repeat(self.action_learned_embed, 'd -> b t 1 d', b = latents.shape[0], t = latents.shape[1])
|
||||||
|
|
||||||
|
tokens, packed_tokens_shape = pack([space_tokens, registers, agent_token], 'b t * d')
|
||||||
|
|
||||||
|
# attention
|
||||||
|
|
||||||
|
for pre_attn_rearrange, post_attn_rearrange, attn, ff in self.layers:
|
||||||
|
|
||||||
|
tokens = pre_attn_rearrange(tokens)
|
||||||
|
|
||||||
|
tokens = attn(tokens) + tokens
|
||||||
|
|
||||||
|
tokens = post_attn_rearrange(tokens)
|
||||||
|
|
||||||
|
tokens = ff(tokens) + tokens
|
||||||
|
|
||||||
|
# unpack
|
||||||
|
|
||||||
|
space_tokens, register_tokens, agent_token = unpack(tokens, packed_tokens_shape, 'b t * d')
|
||||||
|
|
||||||
|
# pooling
|
||||||
|
|
||||||
|
pooled = reduce(space_tokens, 'b t s d -> b t d', 'mean')
|
||||||
|
|
||||||
|
return self.to_pred(pooled)
|
||||||
|
|
||||||
# dreamer
|
# dreamer
|
||||||
|
|
||||||
class Dreamer(Module):
|
class Dreamer(Module):
|
||||||
|
|||||||
@ -1,29 +1,18 @@
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
def test_attn():
|
def test_e2e():
|
||||||
from dreamer4.dreamer4 import Attention
|
from dreamer4.dreamer4 import VideoTokenizer, DynamicsModel
|
||||||
|
|
||||||
x = torch.randn(1, 1024, 512)
|
tokenizer = VideoTokenizer(512, dim_latent = 32, patch_size = 32)
|
||||||
attn = Attention(512)
|
x = torch.randn(1, 3, 4, 256, 256)
|
||||||
|
|
||||||
assert attn(x).shape == x.shape
|
|
||||||
|
|
||||||
def test_ff():
|
|
||||||
from dreamer4.dreamer4 import SwiGLUFeedforward
|
|
||||||
x = torch.randn(1, 1024, 512)
|
|
||||||
ff = SwiGLUFeedforward(512)
|
|
||||||
|
|
||||||
assert ff(x).shape == x.shape
|
|
||||||
|
|
||||||
def test_tokenizer():
|
|
||||||
from dreamer4.dreamer4 import VideoTokenizer
|
|
||||||
|
|
||||||
tokenizer = VideoTokenizer(512, dim_latent = 32, patch_size = 16)
|
|
||||||
x = torch.randn(1, 3, 16, 256, 256)
|
|
||||||
|
|
||||||
loss = tokenizer(x)
|
loss = tokenizer(x)
|
||||||
assert loss.numel() == 1
|
assert loss.numel() == 1
|
||||||
|
|
||||||
latents = tokenizer(x, return_latents = True)
|
latents = tokenizer(x, return_latents = True)
|
||||||
assert latents.shape[-1] == 32
|
assert latents.shape[-1] == 32
|
||||||
|
|
||||||
|
dynamics = DynamicsModel(512, dim_latent = 32)
|
||||||
|
pred = dynamics(latents)
|
||||||
|
assert pred.shape == latents.shape
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user