bring in hyper connections

lucidrains 2025-10-11 06:52:57 -07:00
parent 5df3e69583
commit 563b269f8a
3 changed files with 46 additions and 17 deletions
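
This change wires lucidrains' hyper-connections package (multiple learnable residual streams) into both VideoTokenizer and DynamicsWorldModel. The pattern, visible in the hunks below, is: wrap each Attention / SwiGLUFeedforward branch with hyper_conn(branch = ...), expand the tokens into num_residual_streams residual streams before the layer stack, and reduce them back to a single stream afterwards; in the tokenizer forward passes the explicit "+ tokens" residual adds are dropped, since the wrapper now owns the residual. A minimal sketch of that pattern, assuming the hyper_connections API exactly as it is called in this diff (the LayerNorm + Linear branch is a hypothetical stand-in for the real attention / feedforward blocks):

import torch
from torch import nn
from hyper_connections import get_init_and_expand_reduce_stream_functions

dim = 64
num_residual_streams = 4

# wrapper for residual branches, plus functions to expand / reduce the streams
hyper_conn, expand_streams, reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, dim = dim)

# each branch gets wrapped - the wrapper handles the residual, which is why the
# usual `tokens = block(tokens) + tokens` becomes just `tokens = block(tokens)`
layers = nn.ModuleList([
    hyper_conn(branch = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, dim)))
    for _ in range(2)
])

tokens = torch.randn(1, 16, dim)

tokens = expand_streams(tokens)    # widen to multiple residual streams

for layer in layers:
    tokens = layer(tokens)         # residual add handled inside the hyper connection

tokens = reduce_streams(tokens)    # collapse back to a single stream

assert tokens.shape == (1, 16, dim)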

View File

@@ -17,6 +17,8 @@ from torchvision.models import VGG16_Weights
 from x_mlps_pytorch.normed_mlp import create_mlp
 from x_mlps_pytorch.ensemble import Ensemble
 
+from hyper_connections import get_init_and_expand_reduce_stream_functions
+
 from assoc_scan import AssocScan
 
 # ein related
@@ -941,7 +943,8 @@ class VideoTokenizer(Module):
             rope_min_freq = 1.,
             rope_max_freq = 10000.,
             rope_p_zero_freqs = 0.
-        )
+        ),
+        num_residual_streams = 1
     ):
         super().__init__()
@@ -953,6 +956,10 @@ class VideoTokenizer(Module):
         self.num_latent_tokens = num_latent_tokens
         self.latent_tokens = Parameter(randn(num_latent_tokens, dim) * 1e-2)
 
+        # hyper connections
+
+        hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, dim = dim)
+
         # mae masking - Kaiming He paper from long ago
 
         self.per_image_patch_mask_prob = per_image_patch_mask_prob
@@ -991,8 +998,8 @@ class VideoTokenizer(Module):
 
         for _ in range(encoder_depth):
             encoder_layers.append(ModuleList([
-                Attention(dim = dim, heads = attn_heads, dim_head = attn_dim_head, **attn_kwargs),
-                SwiGLUFeedforward(dim = dim, **ff_kwargs)
+                hyper_conn(branch = Attention(dim = dim, heads = attn_heads, dim_head = attn_dim_head, **attn_kwargs)),
+                hyper_conn(branch = SwiGLUFeedforward(dim = dim, **ff_kwargs))
             ]))
 
         self.encoder_layers = ModuleList(encoder_layers)
@@ -1025,8 +1032,8 @@ class VideoTokenizer(Module):
 
         for _ in range(decoder_depth):
             decoder_layers.append(ModuleList([
-                Attention(dim = dim, heads = attn_heads, dim_head = attn_dim_head, **attn_kwargs),
-                SwiGLUFeedforward(dim = dim, **ff_kwargs)
+                hyper_conn(branch = Attention(dim = dim, heads = attn_heads, dim_head = attn_dim_head, **attn_kwargs)),
+                hyper_conn(branch = SwiGLUFeedforward(dim = dim, **ff_kwargs))
             ]))
 
         self.decoder_layers = ModuleList(decoder_layers)
@@ -1129,10 +1136,14 @@ class VideoTokenizer(Module):
         # decoder attention
 
+        tokens = self.expand_streams(tokens)
+
         for attn, ff in self.decoder_layers:
-            tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = decoder_attend_fn) + tokens
-            tokens = ff(tokens) + tokens
+            tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = decoder_attend_fn)
+            tokens = ff(tokens)
+
+        tokens = self.reduce_streams(tokens)
 
         tokens = self.decoder_norm(tokens)
@@ -1227,9 +1238,13 @@ class VideoTokenizer(Module):
         # encoder
 
+        tokens = self.expand_streams(tokens)
+
         for attn, ff in self.encoder_layers:
-            tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = encoder_attend_fn) + tokens
-            tokens = ff(tokens) + tokens
+            tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = encoder_attend_fn)
+            tokens = ff(tokens)
+
+        tokens = self.reduce_streams(tokens)
 
         tokens = self.encoder_norm(tokens)
@@ -1304,10 +1319,15 @@ class DynamicsWorldModel(Module):
         reward_loss_weight = 0.1,
         value_head_mlp_depth = 3,
         policy_head_mlp_depth = 3,
-        num_latent_genes = 0 # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037
+        num_latent_genes = 0, # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037
+        num_residual_streams = 1
     ):
         super().__init__()
 
+        # hyper connections
+
+        hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, dim = dim)
+
         # can accept raw video if tokenizer is passed in
 
         self.video_tokenizer = video_tokenizer
@@ -1467,8 +1487,8 @@ class DynamicsWorldModel(Module):
             layers.append(ModuleList([
                 rearrange_to_attend,
                 rearrange_from_attend,
-                Attention(dim = dim, dim_head = attn_dim_head, **attn_kwargs),
-                SwiGLUFeedforward(dim = dim, **ff_kwargs)
+                hyper_conn(branch = Attention(dim = dim, dim_head = attn_dim_head, **attn_kwargs)),
+                hyper_conn(branch = SwiGLUFeedforward(dim = dim, **ff_kwargs))
             ]))
 
         self.layers = ModuleList(layers)
@@ -1856,6 +1876,8 @@ class DynamicsWorldModel(Module):
 
         # attention
 
+        tokens = self.expand_streams(tokens)
+
         for (pre_attn_rearrange, post_attn_rearrange, attn, ff), layer_is_time in zip(self.layers, self.is_time):
             tokens = pre_attn_rearrange(tokens)
@@ -1876,6 +1898,8 @@ class DynamicsWorldModel(Module):
             tokens = ff(tokens) + tokens
 
+        tokens = self.reduce_streams(tokens)
+
         # unpack
 
         flow_token, space_tokens, register_tokens, agent_tokens = unpack(tokens, packed_tokens_shape, 'b t * d')
@@ -2011,7 +2035,7 @@ class Dreamer(Module):
     def __init__(
         self,
         video_tokenizer: VideoTokenizer,
-        dynamics_model: DynamicsModel,
+        dynamics_model: DynamicsWorldModel,
         discount_factor = 0.997
     ):
         super().__init__()
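
Both VideoTokenizer and DynamicsWorldModel default to num_residual_streams = 1, which should preserve the plain single-residual-stream behavior from before this commit; a larger value (the test below exercises 4) is what actually activates the multi-stream hyper connections. The pyproject change below pins hyper-connections >= 0.2.1, the package that provides the get_init_and_expand_reduce_stream_functions import added above. In the world model's forward pass the explicit "+ tokens" feedforward residual is kept (see the context line in the hunk above), while the tokenizer forward passes drop it in favor of the residual handling inside the wrapper.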

View File

@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.11"
+version = "0.0.12"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -31,6 +31,7 @@ dependencies = [
     "einx>=0.3.0",
     "einops>=0.8.1",
     "hl-gauss-pytorch",
+    "hyper-connections>=0.2.1",
     "torch>=2.4",
     "torchvision",
     "x-mlps-pytorch>=0.0.29"

View File

@@ -10,6 +10,7 @@ import torch
 @param('num_spatial_tokens', (2, 8))
 @param('signal_and_step_passed_in', (False, True))
 @param('condition_on_actions', (False, True))
+@param('num_residual_streams', (1, 4))
 @param('add_reward_embed_to_agent_token', (False, True))
 def test_e2e(
     pred_orig_latent,
@@ -20,6 +21,7 @@ def test_e2e(
     num_spatial_tokens,
     signal_and_step_passed_in,
     condition_on_actions,
+    num_residual_streams,
     add_reward_embed_to_agent_token
 ):
     from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
@@ -31,7 +33,8 @@ def test_e2e(
         dim_latent = 16,
         patch_size = 32,
         attn_dim_head = 16,
-        num_latent_tokens = 4
+        num_latent_tokens = 4,
+        num_residual_streams = num_residual_streams
     )
 
     video = torch.randn(2, 3, 4, 256, 256)
@@ -64,7 +67,8 @@ def test_e2e(
             query_heads = query_heads,
         ),
         prob_no_shortcut_train = prob_no_shortcut_train,
-        add_reward_embed_to_agent_token = add_reward_embed_to_agent_token
+        add_reward_embed_to_agent_token = add_reward_embed_to_agent_token,
+        num_residual_streams = num_residual_streams
     )
 
     signal_levels = step_sizes_log2 = None
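
With the added @param line (presumably an alias for pytest.mark.parametrize defined at the top of this test module), the end-to-end test is now parametrized over num_residual_streams in (1, 4), so every existing combination runs both with the plain single-stream residual path and with four hyper-connection residual streams threaded through the tokenizer and the world model.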