diff --git a/README.md b/README.md
index 0c5ba3e..439fe3d 100644
--- a/README.md
+++ b/README.md
@@ -19,12 +19,3 @@ Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v
     url = {https://arxiv.org/abs/2509.24527},
 }
 ```
-
-```bibtex
-@misc{xiong2025ndrope,
-    author = {Jerry Xiong},
-    title = {On n-dimensional rotary positional embeddings},
-    year = {2025},
-    url = {https://jerryxio.ng/posts/nd-rope/}
-}
-```
diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index 062af52..8384384 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -647,64 +647,7 @@ def calc_gae(
 
     return returns
 
-# golden gate rotary - Jerry Xiong, PhD student at UIUC
-# https://jerryxio.ng/posts/nd-rope/
-
-def _phi(m):
-    x = 2.
-    for _ in range(10):
-        x = (1. + x) ** (1. / (m + 1.))
-    return x
-
-def make_directions(n, d):
-    g = _phi(d)
-    alpha = (1.0 / g) ** arange(1, d + 1, dtype = torch.float64)
-    i = arange(1, n + 1, dtype = torch.float64).unsqueeze(1)
-    z = torch.fmod(i * alpha, 1.0)
-    directions = torch.erfinv(2.0 * z - 1.0)
-    directions = l2norm(directions)
-    return directions.float()
-
-class GoldenGateRoPENd(Module):
-    def __init__(
-        self,
-        dim_pos,
-        heads,
-        dim_head,
-        rope_min_freq = 1.,
-        rope_max_freq = 10000.,
-        rope_p_zero_freqs = 0., # proportion of frequencies set to 0
-    ):
-        super().__init__()
-        assert divisible_by(dim_head, 2)
-
-        n_freqs = dim_head // 2
-        n_zero_freqs = round(rope_p_zero_freqs * n_freqs)
-
-        omega = cat((
-            torch.zeros(n_zero_freqs),
-            rope_min_freq * (rope_max_freq / rope_min_freq) ** torch.linspace(0, 1, n_freqs - n_zero_freqs),
-        ))
-
-        directions = make_directions(heads * n_freqs, dim_pos)
-        directions = rearrange(directions, '(h f) p -> h f p', h = heads)
-
-        omega_expanded = rearrange(omega, 'f -> f 1')
-        self.register_buffer('freqs', directions * omega_expanded) # shape: (h, f, p)
-
-    def forward(
-        self,
-        pos # (b n p)
-    ):
-
-        freqs = rearrange(self.freqs, 'h f p -> h 1 f p')
-        positions = rearrange(pos.float(), 'n p -> 1 n 1 p')
-
-        # thetas for freqs and positions (batch, head, seq, freq)
-
-        theta = reduce(freqs * positions, 'h n f p -> h n f', 'sum')
-
-        return cat((theta, theta), dim = -1)
+# rotary embeddings for time
 
 class Rotary1D(Module):
     def __init__(
@@ -1070,6 +1013,123 @@ class SwiGLUFeedforward(Module):
 
         return self.proj_out(x)
 
+# axial space time transformer
+
+class AxialSpaceTimeTransformer(Module):
+    def __init__(
+        self,
+        dim,
+        depth,
+        attn_dim_head = 64,
+        attn_softclamp_value = 50.,
+        time_block_every = 4,
+        attn_kwargs: dict = dict(),
+        ff_kwargs: dict = dict(),
+        num_residual_streams = 1,
+        num_special_spatial_tokens = 1,
+        special_attend_only_itself = False, # this is set to True for the video tokenizer decoder (latents can only attend to themselves while spatial modalities attend to the latents and everything)
+        final_norm = True
+    ):
+        super().__init__()
+
+        # hyper connections
+
+        hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, dim = dim)
+
+        # attention
+
+        self.attn_softclamp_value = attn_softclamp_value
+
+        # attention masking
+
+        self.special_attend_only_itself = special_attend_only_itself
+
+        # time rotary embedding
+
+        self.time_rotary = Rotary1D(attn_dim_head)
+
+        # transformer
+
+        layers = []
+        is_time = []
+
+        for i in range(depth):
+            layer_index = i + 1
+
+            is_time_block = divisible_by(layer_index, time_block_every)
+            is_time.append(is_time_block)
+
+            rearrange_to_attend = Rearrange('b t s d -> b s t d') if is_time_block else Identity()
+            rearrange_from_attend = Rearrange('b s t d -> b t s d') if is_time_block else Identity()
+
+            layers.append(ModuleList([
+                rearrange_to_attend,
+                rearrange_from_attend,
+                hyper_conn(branch = Attention(dim = dim, dim_head = attn_dim_head, **attn_kwargs)),
+                hyper_conn(branch = SwiGLUFeedforward(dim = dim, **ff_kwargs))
+            ]))
+
+        self.layers = ModuleList(layers)
+        self.is_time = is_time
+
+        # final norm
+
+        self.final_norm = nn.RMSNorm(dim) if final_norm else nn.Identity()
+
+        # special tokens
+
+        self.num_special_spatial_tokens = num_special_spatial_tokens
+
+    def forward(
+        self,
+        tokens # (b t s d)
+    ):
+        batch, time, space_seq_len, _, device = *tokens.shape, tokens.device
+
+        assert tokens.ndim == 4
+
+        # attend functions for space and time
+
+        use_flex = exists(flex_attention) and tokens.is_cuda
+
+        attend_kwargs = dict(use_flex = use_flex, softclamp_value = self.attn_softclamp_value, special_attend_only_itself = self.special_attend_only_itself, device = device)
+
+        space_attend = get_attend_fn(causal = False, seq_len = space_seq_len, k_seq_len = space_seq_len, num_special_tokens = self.num_special_spatial_tokens, **attend_kwargs) # space has an agent token on the right-hand side for reinforcement learning - it cannot be attended to by the modality tokens
+
+        time_attend = get_attend_fn(causal = True, seq_len = time, k_seq_len = time, **attend_kwargs)
+
+        # rotary
+
+        rotary_pos_emb = self.time_rotary(time)
+
+        # attention
+
+        tokens = self.expand_streams(tokens)
+
+        for (pre_attn_rearrange, post_attn_rearrange, attn, ff), layer_is_time in zip(self.layers, self.is_time):
+
+            tokens = pre_attn_rearrange(tokens)
+
+            # when it is an axial time attention block, attention should be causal
+
+            attend_fn = time_attend if layer_is_time else space_attend
+
+            layer_rotary_pos_emb = rotary_pos_emb if layer_is_time else None
+
+            # attention layer
+
+            tokens = attn(tokens, rotary_pos_emb = layer_rotary_pos_emb, attend_fn = attend_fn) + tokens
+
+            tokens = post_attn_rearrange(tokens)
+
+            # feedforward layer
+
+            tokens = ff(tokens) + tokens
+
+        tokens = self.reduce_streams(tokens)
+
+        return self.final_norm(tokens)
+
 # video tokenizer
 
 class VideoTokenizer(Module):
@@ -1083,6 +1143,7 @@ class VideoTokenizer(Module):
         num_latent_tokens = 4,
         encoder_depth = 4,
         decoder_depth = 4,
+        time_block_every = 4,
         attn_kwargs: dict = dict(),
         attn_dim_head = 64,
         attn_heads = 8,
@@ -1133,32 +1194,19 @@ class VideoTokenizer(Module):
             Rearrange('b t h w (p1 p2 c) -> b c t (h p1) (w p2)', p1 = patch_size, p2 = patch_size),
         )
 
-        # 3d rotations
+        # encoder space / time transformer
 
-        self.spacetime_rotary = GoldenGateRoPENd(
-            dim_pos = 3,
-            heads = attn_heads,
-            dim_head = attn_dim_head,
-            **nd_rotary_kwargs
+        self.encoder_transformer = AxialSpaceTimeTransformer(
+            dim = dim,
+            depth = encoder_depth,
+            attn_dim_head = attn_dim_head,
+            attn_softclamp_value = attn_softclamp_value,
+            time_block_every = time_block_every,
+            num_special_spatial_tokens = num_latent_tokens,
+            num_residual_streams = num_residual_streams,
+            final_norm = True
         )
 
-        # attention related
-
-        self.attn_softclamp_value = attn_softclamp_value
-
-        # encoder
-
-        encoder_layers = []
-
-        for _ in range(encoder_depth):
-            encoder_layers.append(ModuleList([
-                hyper_conn(branch = Attention(dim = dim, heads = attn_heads, dim_head = attn_dim_head, **attn_kwargs)),
-                hyper_conn(branch = SwiGLUFeedforward(dim = dim, **ff_kwargs))
-            ]))
-
-        self.encoder_layers = ModuleList(encoder_layers)
-        self.encoder_norm = RMSNorm(dim)
-
         # latents
 
         self.encoded_to_latents = Sequential(
@@ -1182,16 +1230,18 @@ class VideoTokenizer(Module):
             depth = decoder_pos_mlp_depth,
         )
 
-        decoder_layers = []
+        # decoder transformer
 
-        for _ in range(decoder_depth):
-            decoder_layers.append(ModuleList([
-                hyper_conn(branch = Attention(dim = dim, heads = attn_heads, dim_head = attn_dim_head, **attn_kwargs)),
-                hyper_conn(branch = SwiGLUFeedforward(dim = dim, **ff_kwargs))
-            ]))
-
-        self.decoder_layers = ModuleList(decoder_layers)
-        self.decoder_norm = RMSNorm(dim)
+        self.decoder_transformer = AxialSpaceTimeTransformer(
+            dim = dim,
+            depth = decoder_depth,
+            attn_dim_head = attn_dim_head,
+            attn_softclamp_value = attn_softclamp_value,
+            time_block_every = time_block_every,
+            num_special_spatial_tokens = num_latent_tokens,
+            num_residual_streams = num_residual_streams,
+            final_norm = True
+        )
 
         # loss related
 
@@ -1215,36 +1265,11 @@ class VideoTokenizer(Module):
         self.eval()
         return self.forward(video, return_latents = True)
 
-    def get_rotary_pos_emb(
-        self,
-        time,
-        num_patch_height,
-        num_patch_width
-    ):
-        device = self.device
-
-        positions = stack(torch.meshgrid(
-            arange(time, device = device),
-            arange(num_patch_height, device = device),
-            arange(num_patch_width, device = device)
-        ), dim = -1)
-
-        positions = rearrange(positions, 't h w p -> t (h w) p')
-
-        # give the latents an out of bounds position and assume the network will figure it out
-
-        positions = pad_at_dim(positions, (0, self.num_latent_tokens), dim = -2, value = -1) # todo - make this value configurable, and ultimately craft own flash attention function where certain positions can be unrotated
-
-        positions = rearrange(positions, 't hw p -> (t hw) p')
-
-        return self.spacetime_rotary(positions)
-
     def decode(
         self,
         latents, # (b t n d)
         height = None,
         width = None,
-        rotary_pos_emb = None
     ): # (b c t h w)
 
         height = default(height, self.image_height)
@@ -1259,9 +1284,6 @@ class VideoTokenizer(Module):
         num_patch_height = height // self.patch_size
         num_patch_width = width // self.patch_size
 
-        if not exists(rotary_pos_emb):
-            rotary_pos_emb = self.get_rotary_pos_emb(time, num_patch_height, num_patch_width)
-
         # latents to tokens
 
         latent_tokens = self.latents_to_decoder(latents)
@@ -1278,43 +1300,9 @@ class VideoTokenizer(Module):
 
         tokens, packed_latent_shape = pack((decoder_pos_emb, latent_tokens), 'b t * d')
 
-        space_seq_len = tokens.shape[-2]
-
-        # pack time
-
-        tokens, inverse_pack_time = pack_one(tokens, 'b * d')
-
-        seq_len = tokens.shape[-2]
-
-        # decoder attend
-
-        decoder_attend_fn = get_attend_fn(
-            use_flex,
-            seq_len, seq_len,
-            causal = True,
-            causal_block_size = space_seq_len,
-            softclamp_value = self.attn_softclamp_value,
-            block_size_per_special = space_seq_len,
-            num_special_tokens = self.num_latent_tokens,
-            special_attend_only_itself = True # different than encoder
-        )
-
         # decoder attention
 
-        tokens = self.expand_streams(tokens)
-
-        for attn, ff in self.decoder_layers:
-            tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = decoder_attend_fn)
-
-            tokens = ff(tokens)
-
-        tokens = self.reduce_streams(tokens)
-
-        tokens = self.decoder_norm(tokens)
-
-        # unpack time
-
-        tokens = inverse_pack_time(tokens)
+        tokens = self.decoder_transformer(tokens)
 
         # unpack latents
 
@@ -1346,10 +1334,6 @@ class VideoTokenizer(Module):
 
         num_patch_height, num_patch_width, _ = tokens.shape[-3:]
 
-        # rotary positions
-
-        rotary_pos_emb = self.get_rotary_pos_emb(time, num_patch_height, num_patch_width)
-
         # masking
 
         mask_patches = default(mask_patches, self.training)
@@ -1374,50 +1358,12 @@ class VideoTokenizer(Module):
 
         tokens, packed_latent_shape = pack((tokens, latents), 'b t * d')
 
-        space_seq_len = tokens.shape[-2]
+        # encoder attention
 
-        # pack time
-
-        tokens, inverse_pack_time = pack_one(tokens, 'b * d')
-
-        seq_len = tokens.shape[1]
-
-        # attend hyper parameters
-
-        use_flex = tokens.is_cuda and exists(flex_attention)
-
-        # encoder attend
-
-        # modality can only attend to itself while latents can attend to everything
-        # similar to agent token in dynamics model
-
-        encoder_attend_fn = get_attend_fn(
-            use_flex,
-            seq_len, seq_len,
-            causal = True,
-            causal_block_size = space_seq_len,
-            softclamp_value = self.attn_softclamp_value,
-            block_size_per_special = space_seq_len,
-            num_special_tokens = self.num_latent_tokens,
-            special_attend_only_itself = False # different than decoder
-        )
-
-        # encoder
-
-        tokens = self.expand_streams(tokens)
-
-        for attn, ff in self.encoder_layers:
-            tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = encoder_attend_fn)
-            tokens = ff(tokens)
-
-        tokens = self.reduce_streams(tokens)
-
-        tokens = self.encoder_norm(tokens)
+        tokens = self.encoder_transformer(tokens)
 
         # latent bottleneck
 
-        tokens = inverse_pack_time(tokens)
-
         tokens, latents = unpack(tokens, packed_latent_shape, 'b t * d')
 
         latents = self.encoded_to_latents(latents)
@@ -1425,7 +1371,7 @@ class VideoTokenizer(Module):
         if return_latents:
             return latents
 
-        recon_video = self.decode(latents, height = height, width = width, rotary_pos_emb = rotary_pos_emb)
+        recon_video = self.decode(latents, height = height, width = width)
 
         # losses
 
@@ -1491,10 +1437,6 @@ class DynamicsWorldModel(Module):
     ):
         super().__init__()
 
-        # hyper connections
-
-        hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, dim = dim)
-
         # can accept raw video if tokenizer is passed in
 
         self.video_tokenizer = video_tokenizer
@@ -1635,37 +1577,19 @@ class DynamicsWorldModel(Module):
             depth = value_head_mlp_depth,
         )
 
-        # attention
+        # efficient axial space / time transformer
 
-        self.attn_softclamp_value = attn_softclamp_value
-
-        # time rotary embedding
-
-        self.time_rotary = Rotary1D(attn_dim_head)
-
-        # transformer
-
-        layers = []
-        is_time = []
-
-        for i in range(depth):
-            layer_index = i + 1
-
-            is_time_block = divisible_by(layer_index, time_block_every)
-            is_time.append(is_time_block)
-
-            rearrange_to_attend = Rearrange('b t s d -> b s t d') if is_time_block else Identity()
-            rearrange_from_attend = Rearrange('b s t d -> b t s d') if is_time_block else Identity()
-
-            layers.append(ModuleList([
-                rearrange_to_attend,
-                rearrange_from_attend,
-                hyper_conn(branch = Attention(dim = dim, dim_head = attn_dim_head, **attn_kwargs)),
-                hyper_conn(branch = SwiGLUFeedforward(dim = dim, **ff_kwargs))
-            ]))
-
-        self.layers = ModuleList(layers)
-        self.is_time = is_time
+        self.transformer = AxialSpaceTimeTransformer(
+            dim = dim,
+            depth = depth,
+            attn_dim_head = attn_dim_head,
+            attn_softclamp_value = attn_softclamp_value,
+            attn_kwargs = attn_kwargs,
+            ff_kwargs = ff_kwargs,
+            num_residual_streams = num_residual_streams,
+            num_special_spatial_tokens = num_agents,
+            final_norm = False
+        )
 
         # zero
 
@@ -2046,56 +1970,9 @@ class DynamicsWorldModel(Module):
 
         tokens, packed_tokens_shape = pack([flow_token, space_tokens, registers, action_tokens, reward_tokens, agent_tokens], 'b t * d')
 
-        # attend functions for space and time
-
-        seq_len = tokens.shape[1]
-
-        use_flex = exists(flex_attention) and tokens.is_cuda
-
-        attend_kwargs = dict(use_flex = use_flex, softclamp_value = self.attn_softclamp_value, device = device)
-
-        space_seq_len = (
-            + 1                        # signal + step
-            + num_action_tokens        # past action tokens - todo: account for multi-agent
-            + num_reward_tokens        # maybe allow world model being fine-tuned in phase 3 to see rewards as state
-            + self.num_agents          # action / agent tokens
-            + self.num_register_tokens
-            + num_spatial_tokens
-        )
-
-        space_attend = get_attend_fn(causal = False, seq_len = space_seq_len, k_seq_len = space_seq_len, num_special_tokens = self.num_agents, **attend_kwargs) # space has an agent token on the right-hand side for reinforcement learning - cannot be attended to by modality
-
-        time_attend = get_attend_fn(causal = True, seq_len = time, k_seq_len = time, **attend_kwargs)
-
-        # rotary
-
-        rotary_pos_emb = self.time_rotary(time)
-
         # attention
 
-        tokens = self.expand_streams(tokens)
-
-        for (pre_attn_rearrange, post_attn_rearrange, attn, ff), layer_is_time in zip(self.layers, self.is_time):
-
-            tokens = pre_attn_rearrange(tokens)
-
-            # when is a axial time attention block, should be causal
-
-            attend_fn = time_attend if layer_is_time else space_attend
-
-            layer_rotary_pos_emb = rotary_pos_emb if layer_is_time else None
-
-            # attention layer
-
-            tokens = attn(tokens, rotary_pos_emb = layer_rotary_pos_emb, attend_fn = attend_fn) + tokens
-
-            tokens = post_attn_rearrange(tokens)
-
-            # feedforward layer
-
-            tokens = ff(tokens) + tokens
-
-        tokens = self.reduce_streams(tokens)
+        tokens = self.transformer(tokens)
 
         # unpack
 
diff --git a/pyproject.toml b/pyproject.toml
index 32e4200..540bf3f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.23"
+version = "0.0.24"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
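
For reference, a minimal usage sketch of the extracted `AxialSpaceTimeTransformer`. The dimensions below are hypothetical and chosen only for illustration; the import path assumes the class is exported from `dreamer4.dreamer4`, and CPU execution assumes the non-flex attention fallback:

```python
# a minimal sketch, assuming AxialSpaceTimeTransformer is exported from dreamer4.dreamer4
# all sizes here are hypothetical, picked only to illustrate the (batch, time, space, dim) layout

import torch
from dreamer4.dreamer4 import AxialSpaceTimeTransformer

transformer = AxialSpaceTimeTransformer(
    dim = 512,
    depth = 8,
    attn_dim_head = 64,
    time_block_every = 4,            # layers 4 and 8 attend causally over time, the rest over space
    num_special_spatial_tokens = 4   # e.g. latent tokens for the tokenizer, agent tokens for the world model
)

# 64 spatial tokens plus 4 special tokens packed on the right of the space axis

tokens = torch.randn(1, 16, 68, 512) # (batch, time, space + special, dim)

out = transformer(tokens)            # (1, 16, 68, 512)
```
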