a researcher on discord pointed out that the tokenizer also uses the axial space time transformer. redo it without the 3d rotary and block-causal attention, greatly simplifying the implementation

lucidrains 2025-10-16 09:40:14 -07:00
parent 2ccb290e26
commit d74f09f0b3
3 changed files with 158 additions and 290 deletions
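In practice the change means the tokenizer now reuses the same alternating pattern as the dynamics model: most layers attend bidirectionally over the spatial axis within each frame, and every `time_block_every`-th layer attends causally over the temporal axis, so no 3d rotary embedding or block-causal mask over the flattened spacetime sequence is needed. A minimal sketch of that axial pattern, with plain scaled-dot-product attention standing in for the repo's Attention module and hyper connections omitted (shapes are illustrative):

import torch
import torch.nn.functional as F
from einops import rearrange

def axial_attend(tokens, heads = 8, time_layer = False):
    # tokens: (batch, time, space, dim) - fold the axis we are not attending over into batch
    b = tokens.shape[0]
    pattern = 'b t s d -> (b s) t d' if time_layer else 'b t s d -> (b t) s d'
    x = rearrange(tokens, pattern)

    q = k = v = rearrange(x, 'b n (h d) -> b h n d', h = heads)
    out = F.scaled_dot_product_attention(q, k, v, is_causal = time_layer)  # causal only along time
    out = rearrange(out, 'b h n d -> b n (h d)')

    inverse = '(b s) t d -> b t s d' if time_layer else '(b t) s d -> b t s d'
    return rearrange(out, inverse, b = b)

tokens = torch.randn(2, 16, 64, 512)               # (batch, time, space, dim)
tokens = axial_attend(tokens, time_layer = False)  # spatial layer, bidirectional within a frame
tokens = axial_attend(tokens, time_layer = True)   # temporal layer, causal across frames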

View File

@ -19,12 +19,3 @@ Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v
url = {https://arxiv.org/abs/2509.24527},
}
```
```bibtex
@misc{xiong2025ndrope,
author = {Jerry Xiong},
title = {On n-dimensional rotary positional embeddings},
year = {2025},
url = {https://jerryxio.ng/posts/nd-rope/}
}
```

View File

@ -647,64 +647,7 @@ def calc_gae(
return returns
# golden gate rotary - Jerry Xiong, PhD student at UIUC
# https://jerryxio.ng/posts/nd-rope/
def _phi(m):
x = 2.
for _ in range(10):
x = (1. + x) ** (1. / (m + 1.))
return x
def make_directions(n, d):
g = _phi(d)
alpha = (1.0 / g) ** arange(1, d + 1, dtype = torch.float64)
i = arange(1, n + 1, dtype = torch.float64).unsqueeze(1)
z = torch.fmod(i * alpha, 1.0)
directions = torch.erfinv(2.0 * z - 1.0)
directions = l2norm(directions)
return directions.float()
class GoldenGateRoPENd(Module):
def __init__(
self,
dim_pos,
heads,
dim_head,
rope_min_freq = 1.,
rope_max_freq = 10000.,
rope_p_zero_freqs = 0., # proportion of frequencies set to 0
):
super().__init__()
assert divisible_by(dim_head, 2)
n_freqs = dim_head // 2
n_zero_freqs = round(rope_p_zero_freqs * n_freqs)
omega = cat((
torch.zeros(n_zero_freqs),
rope_min_freq * (rope_max_freq / rope_min_freq) ** torch.linspace(0, 1, n_freqs - n_zero_freqs),
))
directions = make_directions(heads * n_freqs, dim_pos)
directions = rearrange(directions, '(h f) p -> h f p', h = heads)
omega_expanded = rearrange(omega, 'f -> f 1')
self.register_buffer('freqs', directions * omega_expanded) # shape: (h, f, p)
def forward(
self,
pos # (b n p)
):
freqs = rearrange(self.freqs, 'h f p -> h 1 f p')
positions = rearrange(pos.float(), 'n p -> 1 n 1 p')
# thetas for freqs and positions (batch, head, seq, freq)
theta = reduce(freqs * positions, 'h n f p -> h n f', 'sum')
return cat((theta, theta), dim = -1)
# rotary embeddings for time
class Rotary1D(Module):
def __init__(
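For context on the removal above: GoldenGateRoPENd turns n-dimensional positions into per-head rotation angles of shape (heads, seq, dim_head), which are applied to queries and keys with the usual rotary recipe. A sketch of that application step, with rotate_half and apply_rotary as illustrative helpers rather than the repo's actual functions:

import torch

def rotate_half(x):
    x1, x2 = x.chunk(2, dim = -1)
    return torch.cat((-x2, x1), dim = -1)

def apply_rotary(t, theta):
    # t: (batch, heads, seq, dim_head), theta: (heads, seq, dim_head) as returned by GoldenGateRoPENd
    return t * theta.cos() + rotate_half(t) * theta.sin()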
@ -1070,6 +1013,123 @@ class SwiGLUFeedforward(Module):
return self.proj_out(x)
# axial space time transformer
class AxialSpaceTimeTransformer(Module):
def __init__(
self,
dim,
depth,
attn_dim_head = 64,
attn_softclamp_value = 50.,
time_block_every = 4,
attn_kwargs: dict = dict(),
ff_kwargs: dict = dict(),
num_residual_streams = 1,
num_special_spatial_tokens = 1,
special_attend_only_itself = False, # set to True for the video tokenizer decoder (latents attend only to themselves, while the spatial modality tokens attend to the latents and everything else)
final_norm = True
):
super().__init__()
# hyper connections
hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, dim = dim)
# attention
self.attn_softclamp_value = attn_softclamp_value
# attention masking
self.special_attend_only_itself = special_attend_only_itself
# time rotary embedding
self.time_rotary = Rotary1D(attn_dim_head)
# transformer
layers = []
is_time = []
for i in range(depth):
layer_index = i + 1
is_time_block = divisible_by(layer_index, time_block_every)
is_time.append(is_time_block)
rearrange_to_attend = Rearrange('b t s d -> b s t d') if is_time_block else Identity()
rearrange_from_attend = Rearrange('b s t d -> b t s d') if is_time_block else Identity()
layers.append(ModuleList([
rearrange_to_attend,
rearrange_from_attend,
hyper_conn(branch = Attention(dim = dim, dim_head = attn_dim_head, **attn_kwargs)),
hyper_conn(branch = SwiGLUFeedforward(dim = dim, **ff_kwargs))
]))
self.layers = ModuleList(layers)
self.is_time = is_time
# final norm
self.final_norm = nn.RMSNorm(dim) if final_norm else nn.Identity()
# special tokens
self.num_special_spatial_tokens = num_special_spatial_tokens
def forward(
self,
tokens # (b t s d)
):
batch, time, space_seq_len, _, device = *tokens.shape, tokens.device
assert tokens.ndim == 4
# attend functions for space and time
use_flex = exists(flex_attention) and tokens.is_cuda
attend_kwargs = dict(use_flex = use_flex, softclamp_value = self.attn_softclamp_value, special_attend_only_itself = self.special_attend_only_itself, device = device)
space_attend = get_attend_fn(causal = False, seq_len = space_seq_len, k_seq_len = space_seq_len, num_special_tokens = self.num_special_spatial_tokens, **attend_kwargs) # special tokens (e.g. the agent token for reinforcement learning, or the tokenizer latents) sit on the right-hand side of the spatial axis and cannot be attended to by the modality tokens
time_attend = get_attend_fn(causal = True, seq_len = time, k_seq_len = time, **attend_kwargs)
# rotary
rotary_pos_emb = self.time_rotary(time)
# attention
tokens = self.expand_streams(tokens)
for (pre_attn_rearrange, post_attn_rearrange, attn, ff), layer_is_time in zip(self.layers, self.is_time):
tokens = pre_attn_rearrange(tokens)
# when it is an axial time attention block, attention should be causal
attend_fn = time_attend if layer_is_time else space_attend
layer_rotary_pos_emb = rotary_pos_emb if layer_is_time else None
# attention layer
tokens = attn(tokens, rotary_pos_emb = layer_rotary_pos_emb, attend_fn = attend_fn) + tokens
tokens = post_attn_rearrange(tokens)
# feedforward layer
tokens = ff(tokens) + tokens
tokens = self.reduce_streams(tokens)
return self.final_norm(tokens)
# video tokenizer
class VideoTokenizer(Module):
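A usage sketch for the new module as configured in the tokenizer below; sizes are illustrative, and the special latent tokens are assumed to be packed onto the right of the spatial axis:

import torch

transformer = AxialSpaceTimeTransformer(
    dim = 512,
    depth = 8,
    time_block_every = 4,            # every 4th layer attends causally over time
    num_special_spatial_tokens = 4   # e.g. 4 latent tokens to the right of the patch tokens
)

tokens = torch.randn(1, 16, 64 + 4, 512)   # (batch, time, patches + latents, dim)
out = transformer(tokens)                  # (1, 16, 68, 512)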
@ -1083,6 +1143,7 @@ class VideoTokenizer(Module):
num_latent_tokens = 4,
encoder_depth = 4,
decoder_depth = 4,
time_block_every = 4,
attn_kwargs: dict = dict(),
attn_dim_head = 64,
attn_heads = 8,
@ -1133,32 +1194,19 @@ class VideoTokenizer(Module):
Rearrange('b t h w (p1 p2 c) -> b c t (h p1) (w p2)', p1 = patch_size, p2 = patch_size),
)
# 3d rotations
# encoder space / time transformer
self.spacetime_rotary = GoldenGateRoPENd(
dim_pos = 3,
heads = attn_heads,
dim_head = attn_dim_head,
**nd_rotary_kwargs
self.encoder_transformer = AxialSpaceTimeTransformer(
dim = dim,
depth = encoder_depth,
attn_dim_head = attn_dim_head,
attn_softclamp_value = attn_softclamp_value,
time_block_every = time_block_every,
num_special_spatial_tokens = num_latent_tokens,
num_residual_streams = num_residual_streams,
final_norm = True
)
# attention related
self.attn_softclamp_value = attn_softclamp_value
# encoder
encoder_layers = []
for _ in range(encoder_depth):
encoder_layers.append(ModuleList([
hyper_conn(branch = Attention(dim = dim, heads = attn_heads, dim_head = attn_dim_head, **attn_kwargs)),
hyper_conn(branch = SwiGLUFeedforward(dim = dim, **ff_kwargs))
]))
self.encoder_layers = ModuleList(encoder_layers)
self.encoder_norm = RMSNorm(dim)
# latents
self.encoded_to_latents = Sequential(
@ -1182,16 +1230,18 @@ class VideoTokenizer(Module):
depth = decoder_pos_mlp_depth,
)
decoder_layers = []
# decoder transformer
for _ in range(decoder_depth):
decoder_layers.append(ModuleList([
hyper_conn(branch = Attention(dim = dim, heads = attn_heads, dim_head = attn_dim_head, **attn_kwargs)),
hyper_conn(branch = SwiGLUFeedforward(dim = dim, **ff_kwargs))
]))
self.decoder_layers = ModuleList(decoder_layers)
self.decoder_norm = RMSNorm(dim)
self.decoder_transformer = AxialSpaceTimeTransformer(
dim = dim,
depth = decoder_depth,
attn_dim_head = attn_dim_head,
attn_softclamp_value = attn_softclamp_value,
time_block_every = time_block_every,
num_special_spatial_tokens = num_latent_tokens,
num_residual_streams = num_residual_streams,
final_norm = True
)
# loss related
@ -1215,36 +1265,11 @@ class VideoTokenizer(Module):
self.eval()
return self.forward(video, return_latents = True)
def get_rotary_pos_emb(
self,
time,
num_patch_height,
num_patch_width
):
device = self.device
positions = stack(torch.meshgrid(
arange(time, device = device),
arange(num_patch_height, device = device),
arange(num_patch_width, device = device)
), dim = -1)
positions = rearrange(positions, 't h w p -> t (h w) p')
# give the latents an out of bounds position and assume the network will figure it out
positions = pad_at_dim(positions, (0, self.num_latent_tokens), dim = -2, value = -1) # todo - make this value configurable, and ultimately craft our own flash attention function where certain positions can be left unrotated
positions = rearrange(positions, 't hw p -> (t hw) p')
return self.spacetime_rotary(positions)
def decode(
self,
latents, # (b t n d)
height = None,
width = None,
rotary_pos_emb = None
): # (b c t h w)
height = default(height, self.image_height)
@ -1259,9 +1284,6 @@ class VideoTokenizer(Module):
num_patch_height = height // self.patch_size
num_patch_width = width // self.patch_size
if not exists(rotary_pos_emb):
rotary_pos_emb = self.get_rotary_pos_emb(time, num_patch_height, num_patch_width)
# latents to tokens
latent_tokens = self.latents_to_decoder(latents)
@ -1278,43 +1300,9 @@ class VideoTokenizer(Module):
tokens, packed_latent_shape = pack((decoder_pos_emb, latent_tokens), 'b t * d')
space_seq_len = tokens.shape[-2]
# pack time
tokens, inverse_pack_time = pack_one(tokens, 'b * d')
seq_len = tokens.shape[-2]
# decoder attend
decoder_attend_fn = get_attend_fn(
use_flex,
seq_len, seq_len,
causal = True,
causal_block_size = space_seq_len,
softclamp_value = self.attn_softclamp_value,
block_size_per_special = space_seq_len,
num_special_tokens = self.num_latent_tokens,
special_attend_only_itself = True # different than encoder
)
# decoder attention
tokens = self.expand_streams(tokens)
for attn, ff in self.decoder_layers:
tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = decoder_attend_fn)
tokens = ff(tokens)
tokens = self.reduce_streams(tokens)
tokens = self.decoder_norm(tokens)
# unpack time
tokens = inverse_pack_time(tokens)
tokens = self.decoder_transformer(tokens)
# unpack latents
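Previously the decoder flattened time into one long sequence and relied on a block-causal mask: each frame's block of spatial tokens attends to itself and to all earlier frames. A minimal boolean-mask sketch of that pattern, independent of the repo's get_attend_fn / flex attention machinery:

import torch

def block_causal_mask(time, block_size, device = None):
    # True where attention is allowed, over a flattened sequence of length time * block_size
    block_id = torch.arange(time * block_size, device = device) // block_size
    return block_id[:, None] >= block_id[None, :]

mask = block_causal_mask(time = 3, block_size = 4)   # (12, 12) boolean mask

Attending over space and time as separate axes makes this mask unnecessary, since the temporal layers can simply use plain causal attention.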
@ -1346,10 +1334,6 @@ class VideoTokenizer(Module):
num_patch_height, num_patch_width, _ = tokens.shape[-3:]
# rotary positions
rotary_pos_emb = self.get_rotary_pos_emb(time, num_patch_height, num_patch_width)
# masking
mask_patches = default(mask_patches, self.training)
@ -1374,50 +1358,12 @@ class VideoTokenizer(Module):
tokens, packed_latent_shape = pack((tokens, latents), 'b t * d')
space_seq_len = tokens.shape[-2]
# encoder attention
# pack time
tokens, inverse_pack_time = pack_one(tokens, 'b * d')
seq_len = tokens.shape[1]
# attend hyper parameters
use_flex = tokens.is_cuda and exists(flex_attention)
# encoder attend
# modality tokens can only attend to themselves while the latents can attend to everything
# similar to the agent token in the dynamics model
encoder_attend_fn = get_attend_fn(
use_flex,
seq_len, seq_len,
causal = True,
causal_block_size = space_seq_len,
softclamp_value = self.attn_softclamp_value,
block_size_per_special = space_seq_len,
num_special_tokens = self.num_latent_tokens,
special_attend_only_itself = False # different than decoder
)
# encoder
tokens = self.expand_streams(tokens)
for attn, ff in self.encoder_layers:
tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = encoder_attend_fn)
tokens = ff(tokens)
tokens = self.reduce_streams(tokens)
tokens = self.encoder_norm(tokens)
tokens = self.encoder_transformer(tokens)
# latent bottleneck
tokens = inverse_pack_time(tokens)
tokens, latents = unpack(tokens, packed_latent_shape, 'b t * d')
latents = self.encoded_to_latents(latents)
@ -1425,7 +1371,7 @@ class VideoTokenizer(Module):
if return_latents:
return latents
recon_video = self.decode(latents, height = height, width = width, rotary_pos_emb = rotary_pos_emb)
recon_video = self.decode(latents, height = height, width = width)
# losses
@ -1491,10 +1437,6 @@ class DynamicsWorldModel(Module):
):
super().__init__()
# hyper connections
hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, dim = dim)
# can accept raw video if tokenizer is passed in
self.video_tokenizer = video_tokenizer
@ -1635,37 +1577,19 @@ class DynamicsWorldModel(Module):
depth = value_head_mlp_depth,
)
# attention
# efficient axial space / time transformer
self.attn_softclamp_value = attn_softclamp_value
# time rotary embedding
self.time_rotary = Rotary1D(attn_dim_head)
# transformer
layers = []
is_time = []
for i in range(depth):
layer_index = i + 1
is_time_block = divisible_by(layer_index, time_block_every)
is_time.append(is_time_block)
rearrange_to_attend = Rearrange('b t s d -> b s t d') if is_time_block else Identity()
rearrange_from_attend = Rearrange('b s t d -> b t s d') if is_time_block else Identity()
layers.append(ModuleList([
rearrange_to_attend,
rearrange_from_attend,
hyper_conn(branch = Attention(dim = dim, dim_head = attn_dim_head, **attn_kwargs)),
hyper_conn(branch = SwiGLUFeedforward(dim = dim, **ff_kwargs))
]))
self.layers = ModuleList(layers)
self.is_time = is_time
self.transformer = AxialSpaceTimeTransformer(
dim = dim,
depth = depth,
attn_dim_head = attn_dim_head,
attn_softclamp_value = attn_softclamp_value,
attn_kwargs = attn_kwargs,
ff_kwargs = ff_kwargs,
num_residual_streams = num_residual_streams,
num_special_spatial_tokens = num_agents,
final_norm = False
)
# zero
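The dynamics model keeps its agent tokens (and the tokenizer its latents) on the right of the spatial axis, masked as described in the code comments: by default the modality tokens cannot attend to the special tokens, and with special_attend_only_itself the special tokens attend only to themselves while the modality attends to everything. A hypothetical boolean-mask sketch of that rule (the actual get_attend_fn / flex attention implementation may differ in details):

import torch

def special_spatial_mask(seq_len, num_special, special_attend_only_itself = False):
    # special tokens occupy the last `num_special` positions of the spatial axis
    is_special = torch.arange(seq_len) >= (seq_len - num_special)
    q_special, k_special = is_special[:, None], is_special[None, :]

    if special_attend_only_itself:
        # decoder-style: latents attend only to latents, modality tokens see everything
        return ~q_special | k_special

    # default: special tokens see everything, modality tokens cannot see the special tokens
    return q_special | ~k_special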
@ -2046,56 +1970,9 @@ class DynamicsWorldModel(Module):
tokens, packed_tokens_shape = pack([flow_token, space_tokens, registers, action_tokens, reward_tokens, agent_tokens], 'b t * d')
# attend functions for space and time
seq_len = tokens.shape[1]
use_flex = exists(flex_attention) and tokens.is_cuda
attend_kwargs = dict(use_flex = use_flex, softclamp_value = self.attn_softclamp_value, device = device)
space_seq_len = (
+ 1 # signal + step
+ num_action_tokens # past action tokens - todo: account for multi-agent
+ num_reward_tokens # maybe allow world model being fine-tuned in phase 3 to see rewards as state
+ self.num_agents # action / agent tokens
+ self.num_register_tokens
+ num_spatial_tokens
)
space_attend = get_attend_fn(causal = False, seq_len = space_seq_len, k_seq_len = space_seq_len, num_special_tokens = self.num_agents, **attend_kwargs) # space has an agent token on the right-hand side for reinforcement learning - it cannot be attended to by the modality tokens
time_attend = get_attend_fn(causal = True, seq_len = time, k_seq_len = time, **attend_kwargs)
# rotary
rotary_pos_emb = self.time_rotary(time)
# attention
tokens = self.expand_streams(tokens)
for (pre_attn_rearrange, post_attn_rearrange, attn, ff), layer_is_time in zip(self.layers, self.is_time):
tokens = pre_attn_rearrange(tokens)
# when it is an axial time attention block, attention should be causal
attend_fn = time_attend if layer_is_time else space_attend
layer_rotary_pos_emb = rotary_pos_emb if layer_is_time else None
# attention layer
tokens = attn(tokens, rotary_pos_emb = layer_rotary_pos_emb, attend_fn = attend_fn) + tokens
tokens = post_attn_rearrange(tokens)
# feedforward layer
tokens = ff(tokens) + tokens
tokens = self.reduce_streams(tokens)
tokens = self.transformer(tokens)
# unpack
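With the shared transformer, the dynamics model only needs to pack its per-frame token groups along the spatial axis and hand the result over; the spatial sequence length is read off the packed shape rather than computed by hand. A sketch with illustrative token counts (1 flow/step token, 64 spatial tokens, 4 registers, 1 action, 1 reward, 1 agent token, so a spatial axis of 72):

import torch
from einops import pack, unpack

b, t, d = 2, 16, 512
flow_token    = torch.randn(b, t, 1, d)
space_tokens  = torch.randn(b, t, 64, d)
registers     = torch.randn(b, t, 4, d)
action_tokens = torch.randn(b, t, 1, d)
reward_tokens = torch.randn(b, t, 1, d)
agent_tokens  = torch.randn(b, t, 1, d)

tokens, packed_shape = pack([flow_token, space_tokens, registers, action_tokens, reward_tokens, agent_tokens], 'b t * d')
print(tokens.shape)   # (2, 16, 72, 512) - the agent tokens end up on the right, as the special spatial tokens

# after the transformer, the groups can be recovered from the saved shapes
groups = unpack(tokens, packed_shape, 'b t * d')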

View File

@ -1,6 +1,6 @@
[project]
name = "dreamer4"
version = "0.0.23"
version = "0.0.24"
description = "Dreamer 4"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }