a researcher in discord pointed out that the tokenizer also uses the axial space-time transformer. redo it without the 3d rotary and block-causal attention, greatly simplifying the implementation
parent 2ccb290e26
commit d74f09f0b3
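For orientation, here is a minimal sketch (not part of the commit) of how the now-shared `AxialSpaceTimeTransformer` is driven. The import path and the hyperparameter values are assumptions; only the constructor arguments and the `(batch, time, space, dim)` token layout come from the class added in this diff.

```python
# minimal sketch, not from the commit: driving the shared axial space-time transformer
# the import path is an assumption; constructor arguments mirror the class added below
import torch
from dreamer4.dreamer4 import AxialSpaceTimeTransformer  # hypothetical import path

transformer = AxialSpaceTimeTransformer(
    dim = 512,
    depth = 8,
    attn_dim_head = 64,
    time_block_every = 4,           # every 4th layer attends causally over time, the rest over space
    num_special_spatial_tokens = 4  # e.g. the tokenizer's latent tokens at the end of the spatial axis
)

tokens = torch.randn(1, 16, 64 + 4, 512)  # (batch, time, spatial patches + latent tokens, dim)
out = transformer(tokens)                 # (1, 16, 68, 512)
```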
@@ -19,12 +19,3 @@ Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v
    url = {https://arxiv.org/abs/2509.24527},
}
```

```bibtex
@misc{xiong2025ndrope,
    author = {Jerry Xiong},
    title = {On n-dimensional rotary positional embeddings},
    year = {2025},
    url = {https://jerryxio.ng/posts/nd-rope/}
}
```

@@ -647,64 +647,7 @@ def calc_gae(

    return returns

# golden gate rotary - Jerry Xiong, PhD student at UIUC
# https://jerryxio.ng/posts/nd-rope/

def _phi(m):
    x = 2.
    for _ in range(10):
        x = (1. + x) ** (1. / (m + 1.))
    return x

def make_directions(n, d):
    g = _phi(d)
    alpha = (1.0 / g) ** arange(1, d + 1, dtype = torch.float64)
    i = arange(1, n + 1, dtype = torch.float64).unsqueeze(1)
    z = torch.fmod(i * alpha, 1.0)
    directions = torch.erfinv(2.0 * z - 1.0)
    directions = l2norm(directions)
    return directions.float()

class GoldenGateRoPENd(Module):
    def __init__(
        self,
        dim_pos,
        heads,
        dim_head,
        rope_min_freq = 1.,
        rope_max_freq = 10000.,
        rope_p_zero_freqs = 0., # proportion of frequencies set to 0
    ):
        super().__init__()
        assert divisible_by(dim_head, 2)

        n_freqs = dim_head // 2
        n_zero_freqs = round(rope_p_zero_freqs * n_freqs)

        omega = cat((
            torch.zeros(n_zero_freqs),
            rope_min_freq * (rope_max_freq / rope_min_freq) ** torch.linspace(0, 1, n_freqs - n_zero_freqs),
        ))

        directions = make_directions(heads * n_freqs, dim_pos)
        directions = rearrange(directions, '(h f) p -> h f p', h = heads)

        omega_expanded = rearrange(omega, 'f -> f 1')
        self.register_buffer('freqs', directions * omega_expanded) # shape: (h, f, p)

    def forward(
        self,
        pos # (b n p)
    ):

        freqs = rearrange(self.freqs, 'h f p -> h 1 f p')
        positions = rearrange(pos.float(), 'n p -> 1 n 1 p')

        # thetas for freqs and positions (batch, head, seq, freq)

        theta = reduce(freqs * positions, 'h n f p -> h n f', 'sum')

        return cat((theta, theta), dim = -1)
# rotary embeddings for time

class Rotary1D(Module):
    def __init__(
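For reference, the quasi-random direction generation that the removed `GoldenGateRoPENd` relied on (following Jerry Xiong's post linked above) can be run standalone. This is a self-contained restatement of the deleted helpers, with `torch.nn.functional.normalize` standing in for the repo's `l2norm` helper.

```python
# standalone sketch of the deleted golden-gate rotary direction helpers
import torch
import torch.nn.functional as F

def phi(m):
    # generalized golden ratio, fixed point of x = (1 + x) ** (1 / (m + 1))
    x = 2.
    for _ in range(10):
        x = (1. + x) ** (1. / (m + 1.))
    return x

def make_directions(n, d):
    g = phi(d)
    alpha = (1.0 / g) ** torch.arange(1, d + 1, dtype = torch.float64)
    i = torch.arange(1, n + 1, dtype = torch.float64).unsqueeze(1)
    z = torch.fmod(i * alpha, 1.0)                    # low-discrepancy points in [0, 1)^d
    directions = torch.erfinv(2.0 * z - 1.0)          # map to gaussian coordinates
    return F.normalize(directions, dim = -1).float()  # (n, d) unit directions on the sphere

dirs = make_directions(8 * 32, 3)  # heads * freqs directions in 3d (t, h, w)
print(dirs.shape)                  # torch.Size([256, 3])
```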
@@ -1070,6 +1013,123 @@ class SwiGLUFeedforward(Module):

        return self.proj_out(x)

# axial space time transformer

class AxialSpaceTimeTransformer(Module):
    def __init__(
        self,
        dim,
        depth,
        attn_dim_head = 64,
        attn_softclamp_value = 50.,
        time_block_every = 4,
        attn_kwargs: dict = dict(),
        ff_kwargs: dict = dict(),
        num_residual_streams = 1,
        num_special_spatial_tokens = 1,
        special_attend_only_itself = False, # this is set to True for the video tokenizer decoder (latents can only attend to itself while spatial modalities attend to the latents and everything)
        final_norm = True
    ):
        super().__init__()

        # hyper connections

        hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, dim = dim)

        # attention

        self.attn_softclamp_value = attn_softclamp_value

        # attention masking

        self.special_attend_only_itself = special_attend_only_itself

        # time rotary embedding

        self.time_rotary = Rotary1D(attn_dim_head)

        # transformer

        layers = []
        is_time = []

        for i in range(depth):
            layer_index = i + 1

            is_time_block = divisible_by(layer_index, time_block_every)
            is_time.append(is_time_block)

            rearrange_to_attend = Rearrange('b t s d -> b s t d') if is_time_block else Identity()
            rearrange_from_attend = Rearrange('b s t d -> b t s d') if is_time_block else Identity()

            layers.append(ModuleList([
                rearrange_to_attend,
                rearrange_from_attend,
                hyper_conn(branch = Attention(dim = dim, dim_head = attn_dim_head, **attn_kwargs)),
                hyper_conn(branch = SwiGLUFeedforward(dim = dim, **ff_kwargs))
            ]))

        self.layers = ModuleList(layers)
        self.is_time = is_time

        # final norm

        self.final_norm = nn.RMSNorm(dim) if final_norm else nn.Identity()

        # special tokens

        self.num_special_spatial_tokens = num_special_spatial_tokens

    def forward(
        self,
        tokens # (b t s d)
    ):
        batch, time, space_seq_len, _, device = *tokens.shape, tokens.device

        assert tokens.ndim == 4

        # attend functions for space and time

        use_flex = exists(flex_attention) and tokens.is_cuda

        attend_kwargs = dict(use_flex = use_flex, softclamp_value = self.attn_softclamp_value, special_attend_only_itself = self.special_attend_only_itself, device = device)

        space_attend = get_attend_fn(causal = False, seq_len = space_seq_len, k_seq_len = space_seq_len, num_special_tokens = self.num_special_spatial_tokens, **attend_kwargs) # space has an agent token on the right-hand side for reinforcement learning - cannot be attended to by modality

        time_attend = get_attend_fn(causal = True, seq_len = time, k_seq_len = time, **attend_kwargs)

        # rotary

        rotary_pos_emb = self.time_rotary(time)

        # attention

        tokens = self.expand_streams(tokens)

        for (pre_attn_rearrange, post_attn_rearrange, attn, ff), layer_is_time in zip(self.layers, self.is_time):

            tokens = pre_attn_rearrange(tokens)

            # when it is an axial time attention block, attention should be causal

            attend_fn = time_attend if layer_is_time else space_attend

            layer_rotary_pos_emb = rotary_pos_emb if layer_is_time else None

            # attention layer

            tokens = attn(tokens, rotary_pos_emb = layer_rotary_pos_emb, attend_fn = attend_fn) + tokens

            tokens = post_attn_rearrange(tokens)

            # feedforward layer

            tokens = ff(tokens) + tokens

        tokens = self.reduce_streams(tokens)

        return self.final_norm(tokens)

# video tokenizer

class VideoTokenizer(Module):
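A small sketch of the axial schedule used by the class above: with `time_block_every = 4`, every fourth layer rearranges the tokens to time-major and attends causally along time, while the remaining layers attend bidirectionally within each frame. The einops patterns are the same ones registered per layer in `__init__`.

```python
# sketch of the axial schedule from AxialSpaceTimeTransformer above
from einops.layers.torch import Rearrange

depth, time_block_every = 8, 4
is_time = [(i + 1) % time_block_every == 0 for i in range(depth)]
print(is_time)  # [False, False, False, True, False, False, False, True]

# a time layer swaps axes so attention runs along time (causal), then swaps back
to_time_major   = Rearrange('b t s d -> b s t d')
back_to_spatial = Rearrange('b s t d -> b t s d')
```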
@@ -1083,6 +1143,7 @@ class VideoTokenizer(Module):
        num_latent_tokens = 4,
        encoder_depth = 4,
        decoder_depth = 4,
        time_block_every = 4,
        attn_kwargs: dict = dict(),
        attn_dim_head = 64,
        attn_heads = 8,
@@ -1133,32 +1194,19 @@ class VideoTokenizer(Module):
            Rearrange('b t h w (p1 p2 c) -> b c t (h p1) (w p2)', p1 = patch_size, p2 = patch_size),
        )

        # 3d rotations
        # encoder space / time transformer

        self.spacetime_rotary = GoldenGateRoPENd(
            dim_pos = 3,
            heads = attn_heads,
            dim_head = attn_dim_head,
            **nd_rotary_kwargs
        self.encoder_transformer = AxialSpaceTimeTransformer(
            dim = dim,
            depth = encoder_depth,
            attn_dim_head = attn_dim_head,
            attn_softclamp_value = attn_softclamp_value,
            time_block_every = time_block_every,
            num_special_spatial_tokens = num_latent_tokens,
            num_residual_streams = num_residual_streams,
            final_norm = True
        )

        # attention related

        self.attn_softclamp_value = attn_softclamp_value

        # encoder

        encoder_layers = []

        for _ in range(encoder_depth):
            encoder_layers.append(ModuleList([
                hyper_conn(branch = Attention(dim = dim, heads = attn_heads, dim_head = attn_dim_head, **attn_kwargs)),
                hyper_conn(branch = SwiGLUFeedforward(dim = dim, **ff_kwargs))
            ]))

        self.encoder_layers = ModuleList(encoder_layers)
        self.encoder_norm = RMSNorm(dim)

        # latents

        self.encoded_to_latents = Sequential(
@@ -1182,16 +1230,18 @@ class VideoTokenizer(Module):
            depth = decoder_pos_mlp_depth,
        )

        decoder_layers = []
        # decoder transformer

        for _ in range(decoder_depth):
            decoder_layers.append(ModuleList([
                hyper_conn(branch = Attention(dim = dim, heads = attn_heads, dim_head = attn_dim_head, **attn_kwargs)),
                hyper_conn(branch = SwiGLUFeedforward(dim = dim, **ff_kwargs))
            ]))

        self.decoder_layers = ModuleList(decoder_layers)
        self.decoder_norm = RMSNorm(dim)
        self.decoder_transformer = AxialSpaceTimeTransformer(
            dim = dim,
            depth = decoder_depth,
            attn_dim_head = attn_dim_head,
            attn_softclamp_value = attn_softclamp_value,
            time_block_every = time_block_every,
            num_special_spatial_tokens = num_latent_tokens,
            num_residual_streams = num_residual_streams,
            final_norm = True
        )

        # loss related

@@ -1215,36 +1265,11 @@ class VideoTokenizer(Module):
        self.eval()
        return self.forward(video, return_latents = True)

    def get_rotary_pos_emb(
        self,
        time,
        num_patch_height,
        num_patch_width
    ):
        device = self.device

        positions = stack(torch.meshgrid(
            arange(time, device = device),
            arange(num_patch_height, device = device),
            arange(num_patch_width, device = device)
        ), dim = -1)

        positions = rearrange(positions, 't h w p -> t (h w) p')

        # give the latents an out of bounds position and assume the network will figure it out

        positions = pad_at_dim(positions, (0, self.num_latent_tokens), dim = -2, value = -1) # todo - make this value configurable, and ultimately craft own flash attention function where certain positions can be unrotated

        positions = rearrange(positions, 't hw p -> (t hw) p')

        return self.spacetime_rotary(positions)

    def decode(
        self,
        latents, # (b t n d)
        height = None,
        width = None,
        rotary_pos_emb = None
    ): # (b c t h w)

        height = default(height, self.image_height)
@@ -1259,9 +1284,6 @@ class VideoTokenizer(Module):
        num_patch_height = height // self.patch_size
        num_patch_width = width // self.patch_size

        if not exists(rotary_pos_emb):
            rotary_pos_emb = self.get_rotary_pos_emb(time, num_patch_height, num_patch_width)

        # latents to tokens

        latent_tokens = self.latents_to_decoder(latents)
@@ -1278,43 +1300,9 @@ class VideoTokenizer(Module):

        tokens, packed_latent_shape = pack((decoder_pos_emb, latent_tokens), 'b t * d')

        space_seq_len = tokens.shape[-2]

        # pack time

        tokens, inverse_pack_time = pack_one(tokens, 'b * d')

        seq_len = tokens.shape[-2]

        # decoder attend

        decoder_attend_fn = get_attend_fn(
            use_flex,
            seq_len, seq_len,
            causal = True,
            causal_block_size = space_seq_len,
            softclamp_value = self.attn_softclamp_value,
            block_size_per_special = space_seq_len,
            num_special_tokens = self.num_latent_tokens,
            special_attend_only_itself = True # different than encoder
        )

        # decoder attention

        tokens = self.expand_streams(tokens)

        for attn, ff in self.decoder_layers:
            tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = decoder_attend_fn)

            tokens = ff(tokens)

        tokens = self.reduce_streams(tokens)

        tokens = self.decoder_norm(tokens)

        # unpack time

        tokens = inverse_pack_time(tokens)
        tokens = self.decoder_transformer(tokens)

        # unpack latents

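The block-causal mask that the old flattened decoder attention needed (built through `get_attend_fn` with `causal_block_size` above) is no longer necessary, since the axial transformer is only causal along the time axis. A sketch of what that mask expressed, assuming a boolean convention where True means "may attend":

```python
# sketch of a block-causal mask: tokens may attend within their own space block and to all
# earlier blocks; the axial rewrite replaces this with plain causal masking over time only
import torch

def block_causal_mask(seq_len, block_size):
    blocks = torch.arange(seq_len) // block_size
    return blocks[:, None] >= blocks[None, :]  # (seq_len, seq_len), True = may attend

print(block_causal_mask(6, 2).int())
# tensor([[1, 1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1, 1],
#         [1, 1, 1, 1, 1, 1]])
```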
@@ -1346,10 +1334,6 @@ class VideoTokenizer(Module):

        num_patch_height, num_patch_width, _ = tokens.shape[-3:]

        # rotary positions

        rotary_pos_emb = self.get_rotary_pos_emb(time, num_patch_height, num_patch_width)

        # masking

        mask_patches = default(mask_patches, self.training)
@@ -1374,50 +1358,12 @@ class VideoTokenizer(Module):

        tokens, packed_latent_shape = pack((tokens, latents), 'b t * d')

        space_seq_len = tokens.shape[-2]
        # encoder attention

        # pack time

        tokens, inverse_pack_time = pack_one(tokens, 'b * d')

        seq_len = tokens.shape[1]

        # attend hyper parameters

        use_flex = tokens.is_cuda and exists(flex_attention)

        # encoder attend

        # modality can only attend to itself while latents can attend to everything
        # similar to agent token in dynamics model

        encoder_attend_fn = get_attend_fn(
            use_flex,
            seq_len, seq_len,
            causal = True,
            causal_block_size = space_seq_len,
            softclamp_value = self.attn_softclamp_value,
            block_size_per_special = space_seq_len,
            num_special_tokens = self.num_latent_tokens,
            special_attend_only_itself = False # different than decoder
        )

        # encoder

        tokens = self.expand_streams(tokens)

        for attn, ff in self.encoder_layers:
            tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = encoder_attend_fn)
            tokens = ff(tokens)

        tokens = self.reduce_streams(tokens)

        tokens = self.encoder_norm(tokens)
        tokens = self.encoder_transformer(tokens)

        # latent bottleneck

        tokens = inverse_pack_time(tokens)

        tokens, latents = unpack(tokens, packed_latent_shape, 'b t * d')

        latents = self.encoded_to_latents(latents)
@@ -1425,7 +1371,7 @@ class VideoTokenizer(Module):
        if return_latents:
            return latents

        recon_video = self.decode(latents, height = height, width = width, rotary_pos_emb = rotary_pos_emb)
        recon_video = self.decode(latents, height = height, width = width)

        # losses

@@ -1491,10 +1437,6 @@ class DynamicsWorldModel(Module):
    ):
        super().__init__()

        # hyper connections

        hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, dim = dim)

        # can accept raw video if tokenizer is passed in

        self.video_tokenizer = video_tokenizer
@@ -1635,37 +1577,19 @@ class DynamicsWorldModel(Module):
            depth = value_head_mlp_depth,
        )

        # attention
        # efficient axial space / time transformer

        self.attn_softclamp_value = attn_softclamp_value

        # time rotary embedding

        self.time_rotary = Rotary1D(attn_dim_head)

        # transformer

        layers = []
        is_time = []

        for i in range(depth):
            layer_index = i + 1

            is_time_block = divisible_by(layer_index, time_block_every)
            is_time.append(is_time_block)

            rearrange_to_attend = Rearrange('b t s d -> b s t d') if is_time_block else Identity()
            rearrange_from_attend = Rearrange('b s t d -> b t s d') if is_time_block else Identity()

            layers.append(ModuleList([
                rearrange_to_attend,
                rearrange_from_attend,
                hyper_conn(branch = Attention(dim = dim, dim_head = attn_dim_head, **attn_kwargs)),
                hyper_conn(branch = SwiGLUFeedforward(dim = dim, **ff_kwargs))
            ]))

        self.layers = ModuleList(layers)
        self.is_time = is_time
        self.transformer = AxialSpaceTimeTransformer(
            dim = dim,
            depth = depth,
            attn_dim_head = attn_dim_head,
            attn_softclamp_value = attn_softclamp_value,
            attn_kwargs = attn_kwargs,
            ff_kwargs = ff_kwargs,
            num_residual_streams = num_residual_streams,
            num_special_spatial_tokens = num_agents,
            final_norm = False
        )

        # zero

@@ -2046,56 +1970,9 @@ class DynamicsWorldModel(Module):

        tokens, packed_tokens_shape = pack([flow_token, space_tokens, registers, action_tokens, reward_tokens, agent_tokens], 'b t * d')

        # attend functions for space and time

        seq_len = tokens.shape[1]

        use_flex = exists(flex_attention) and tokens.is_cuda

        attend_kwargs = dict(use_flex = use_flex, softclamp_value = self.attn_softclamp_value, device = device)

        space_seq_len = (
            + 1 # signal + step
            + num_action_tokens # past action tokens - todo: account for multi-agent
            + num_reward_tokens # maybe allow world model being fine-tuned in phase 3 to see rewards as state
            + self.num_agents # action / agent tokens
            + self.num_register_tokens
            + num_spatial_tokens
        )

        space_attend = get_attend_fn(causal = False, seq_len = space_seq_len, k_seq_len = space_seq_len, num_special_tokens = self.num_agents, **attend_kwargs) # space has an agent token on the right-hand side for reinforcement learning - cannot be attended to by modality

        time_attend = get_attend_fn(causal = True, seq_len = time, k_seq_len = time, **attend_kwargs)

        # rotary

        rotary_pos_emb = self.time_rotary(time)

        # attention

        tokens = self.expand_streams(tokens)

        for (pre_attn_rearrange, post_attn_rearrange, attn, ff), layer_is_time in zip(self.layers, self.is_time):

            tokens = pre_attn_rearrange(tokens)

            # when it is an axial time attention block, attention should be causal

            attend_fn = time_attend if layer_is_time else space_attend

            layer_rotary_pos_emb = rotary_pos_emb if layer_is_time else None

            # attention layer

            tokens = attn(tokens, rotary_pos_emb = layer_rotary_pos_emb, attend_fn = attend_fn) + tokens

            tokens = post_attn_rearrange(tokens)

            # feedforward layer

            tokens = ff(tokens) + tokens

        tokens = self.reduce_streams(tokens)
        tokens = self.transformer(tokens)

        # unpack

@@ -1,6 +1,6 @@
[project]
name = "dreamer4"
version = "0.0.23"
version = "0.0.24"
description = "Dreamer 4"
authors = [
    { name = "Phil Wang", email = "lucidrains@gmail.com" }