bring in hyper connections

lucidrains 2025-10-11 06:52:57 -07:00
parent 5df3e69583
commit 563b269f8a
3 changed files with 46 additions and 17 deletions

dreamer4/dreamer4.py

@@ -17,6 +17,8 @@ from torchvision.models import VGG16_Weights
 from x_mlps_pytorch.normed_mlp import create_mlp
 from x_mlps_pytorch.ensemble import Ensemble
+from hyper_connections import get_init_and_expand_reduce_stream_functions
+
 from assoc_scan import AssocScan
 
 # ein related
@@ -941,7 +943,8 @@ class VideoTokenizer(Module):
             rope_min_freq = 1.,
             rope_max_freq = 10000.,
             rope_p_zero_freqs = 0.
-        )
+        ),
+        num_residual_streams = 1
     ):
         super().__init__()
@@ -953,6 +956,10 @@ class VideoTokenizer(Module):
         self.num_latent_tokens = num_latent_tokens
         self.latent_tokens = Parameter(randn(num_latent_tokens, dim) * 1e-2)
 
+        # hyper connections
+
+        hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, dim = dim)
+
         # mae masking - Kaiming He paper from long ago
 
         self.per_image_patch_mask_prob = per_image_patch_mask_prob
@@ -991,8 +998,8 @@ class VideoTokenizer(Module):
         for _ in range(encoder_depth):
             encoder_layers.append(ModuleList([
-                Attention(dim = dim, heads = attn_heads, dim_head = attn_dim_head, **attn_kwargs),
-                SwiGLUFeedforward(dim = dim, **ff_kwargs)
+                hyper_conn(branch = Attention(dim = dim, heads = attn_heads, dim_head = attn_dim_head, **attn_kwargs)),
+                hyper_conn(branch = SwiGLUFeedforward(dim = dim, **ff_kwargs))
             ]))
 
         self.encoder_layers = ModuleList(encoder_layers)
@@ -1025,8 +1032,8 @@ class VideoTokenizer(Module):
         for _ in range(decoder_depth):
             decoder_layers.append(ModuleList([
-                Attention(dim = dim, heads = attn_heads, dim_head = attn_dim_head, **attn_kwargs),
-                SwiGLUFeedforward(dim = dim, **ff_kwargs)
+                hyper_conn(branch = Attention(dim = dim, heads = attn_heads, dim_head = attn_dim_head, **attn_kwargs)),
+                hyper_conn(branch = SwiGLUFeedforward(dim = dim, **ff_kwargs))
             ]))
 
         self.decoder_layers = ModuleList(decoder_layers)
@@ -1129,10 +1136,14 @@ class VideoTokenizer(Module):
         # decoder attention
 
-        for attn, ff in self.decoder_layers:
-            tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = decoder_attend_fn) + tokens
-            tokens = ff(tokens) + tokens
+        tokens = self.expand_streams(tokens)
+
+        for attn, ff in self.decoder_layers:
+            tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = decoder_attend_fn)
+            tokens = ff(tokens)
+
+        tokens = self.reduce_streams(tokens)
 
         tokens = self.decoder_norm(tokens)
@@ -1227,9 +1238,13 @@ class VideoTokenizer(Module):
         # encoder
 
+        tokens = self.expand_streams(tokens)
+
         for attn, ff in self.encoder_layers:
-            tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = encoder_attend_fn) + tokens
-            tokens = ff(tokens) + tokens
+            tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = encoder_attend_fn)
+            tokens = ff(tokens)
+
+        tokens = self.reduce_streams(tokens)
 
         tokens = self.encoder_norm(tokens)
@@ -1304,10 +1319,15 @@ class DynamicsWorldModel(Module):
         reward_loss_weight = 0.1,
         value_head_mlp_depth = 3,
         policy_head_mlp_depth = 3,
-        num_latent_genes = 0 # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037
+        num_latent_genes = 0, # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037
+        num_residual_streams = 1
     ):
         super().__init__()
 
+        # hyper connections
+
+        hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, dim = dim)
+
         # can accept raw video if tokenizer is passed in
 
         self.video_tokenizer = video_tokenizer
@@ -1467,8 +1487,8 @@ class DynamicsWorldModel(Module):
             layers.append(ModuleList([
                 rearrange_to_attend,
                 rearrange_from_attend,
-                Attention(dim = dim, dim_head = attn_dim_head, **attn_kwargs),
-                SwiGLUFeedforward(dim = dim, **ff_kwargs)
+                hyper_conn(branch = Attention(dim = dim, dim_head = attn_dim_head, **attn_kwargs)),
+                hyper_conn(branch = SwiGLUFeedforward(dim = dim, **ff_kwargs))
             ]))
 
         self.layers = ModuleList(layers)
@@ -1856,6 +1876,8 @@ class DynamicsWorldModel(Module):
         # attention
 
+        tokens = self.expand_streams(tokens)
+
         for (pre_attn_rearrange, post_attn_rearrange, attn, ff), layer_is_time in zip(self.layers, self.is_time):
             tokens = pre_attn_rearrange(tokens)
@@ -1876,6 +1898,8 @@ class DynamicsWorldModel(Module):
             tokens = ff(tokens) + tokens
 
+        tokens = self.reduce_streams(tokens)
+
         # unpack
 
         flow_token, space_tokens, register_tokens, agent_tokens = unpack(tokens, packed_tokens_shape, 'b t * d')
@@ -2011,7 +2035,7 @@ class Dreamer(Module):
     def __init__(
         self,
         video_tokenizer: VideoTokenizer,
-        dynamics_model: DynamicsModel,
+        dynamics_model: DynamicsWorldModel,
        discount_factor = 0.997
     ):
         super().__init__()
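The pattern in the diffs above is the standard wiring for lucidrains' hyper-connections package: each residual branch (Attention or SwiGLUFeedforward) is wrapped by hyper_conn(branch = ...), which takes over the residual update, while the forward pass widens tokens to num_residual_streams copies before the layer stack and collapses them after. A minimal standalone sketch of that wiring, assuming only the hyper_connections API exactly as invoked in this commit (nn.Linear is a toy stand-in for the real branches):

import torch
from torch import nn
from hyper_connections import get_init_and_expand_reduce_stream_functions

dim = 16

# 4 residual streams; dim is partially applied, as in the commit above
hyper_conn, expand_streams, reduce_streams = get_init_and_expand_reduce_stream_functions(4, dim = dim)

# toy branches standing in for Attention / SwiGLUFeedforward
layers = nn.ModuleList([hyper_conn(branch = nn.Linear(dim, dim)) for _ in range(2)])

tokens = torch.randn(1, 8, dim)

tokens = expand_streams(tokens)   # widen to 4 residual streams

for layer in layers:
    tokens = layer(tokens)        # residual add handled inside the hyper connection

tokens = reduce_streams(tokens)   # collapse back to a single stream

assert tokens.shape == (1, 8, dim)

With num_residual_streams = 1 the package is designed to degrade to a plain single-stream residual, which is why 1 is the default here and keeps prior behavior.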

pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.11"
+version = "0.0.12"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -31,6 +31,7 @@ dependencies = [
     "einx>=0.3.0",
     "einops>=0.8.1",
     "hl-gauss-pytorch",
+    "hyper-connections>=0.2.1",
     "torch>=2.4",
     "torchvision",
     "x-mlps-pytorch>=0.0.29"

tests/test_dreamer4.py

@@ -10,6 +10,7 @@ import torch
 @param('num_spatial_tokens', (2, 8))
 @param('signal_and_step_passed_in', (False, True))
 @param('condition_on_actions', (False, True))
+@param('num_residual_streams', (1, 4))
 @param('add_reward_embed_to_agent_token', (False, True))
 def test_e2e(
     pred_orig_latent,
@@ -20,6 +21,7 @@ def test_e2e(
     num_spatial_tokens,
     signal_and_step_passed_in,
     condition_on_actions,
+    num_residual_streams,
     add_reward_embed_to_agent_token
 ):
     from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
@@ -31,7 +33,8 @@ def test_e2e(
         dim_latent = 16,
         patch_size = 32,
         attn_dim_head = 16,
-        num_latent_tokens = 4
+        num_latent_tokens = 4,
+        num_residual_streams = num_residual_streams
     )
 
     video = torch.randn(2, 3, 4, 256, 256)
@@ -64,7 +67,8 @@ def test_e2e(
             query_heads = query_heads,
         ),
         prob_no_shortcut_train = prob_no_shortcut_train,
-        add_reward_embed_to_agent_token = add_reward_embed_to_agent_token
+        add_reward_embed_to_agent_token = add_reward_embed_to_agent_token,
+        num_residual_streams = num_residual_streams
     )
 
     signal_levels = step_sizes_log2 = None
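For a quick smoke test outside pytest, the construction the test parametrizes can be run directly. A sketch mirroring the test's values; dim = 32 is an assumed width, since the hunk above does not show the test's actual dim argument:

import torch
from dreamer4.dreamer4 import VideoTokenizer

tokenizer = VideoTokenizer(
    dim = 32,                  # assumed; not visible in the hunk above
    dim_latent = 16,
    patch_size = 32,
    attn_dim_head = 16,
    num_latent_tokens = 4,
    num_residual_streams = 4   # > 1 switches on the extra residual streams
)

# same input shape the test feeds in, presumably (batch, channels, time, height, width)
video = torch.randn(2, 3, 4, 256, 256)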