diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index 34dacc9..98638d5 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -17,6 +17,8 @@ from torchvision.models import VGG16_Weights
 from x_mlps_pytorch.normed_mlp import create_mlp
 from x_mlps_pytorch.ensemble import Ensemble
 
+from hyper_connections import get_init_and_expand_reduce_stream_functions
+
 from assoc_scan import AssocScan
 
 # ein related
@@ -941,7 +943,8 @@ class VideoTokenizer(Module):
             rope_min_freq = 1.,
             rope_max_freq = 10000.,
             rope_p_zero_freqs = 0.
-        )
+        ),
+        num_residual_streams = 1
     ):
         super().__init__()
 
@@ -953,6 +956,10 @@ class VideoTokenizer(Module):
         self.num_latent_tokens = num_latent_tokens
         self.latent_tokens = Parameter(randn(num_latent_tokens, dim) * 1e-2)
 
+        # hyper connections
+
+        hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, dim = dim)
+
         # mae masking - Kaiming He paper from long ago
 
         self.per_image_patch_mask_prob = per_image_patch_mask_prob
@@ -991,8 +998,8 @@ class VideoTokenizer(Module):
 
         for _ in range(encoder_depth):
             encoder_layers.append(ModuleList([
-                Attention(dim = dim, heads = attn_heads, dim_head = attn_dim_head, **attn_kwargs),
-                SwiGLUFeedforward(dim = dim, **ff_kwargs)
+                hyper_conn(branch = Attention(dim = dim, heads = attn_heads, dim_head = attn_dim_head, **attn_kwargs)),
+                hyper_conn(branch = SwiGLUFeedforward(dim = dim, **ff_kwargs))
             ]))
 
         self.encoder_layers = ModuleList(encoder_layers)
@@ -1025,8 +1032,8 @@ class VideoTokenizer(Module):
 
         for _ in range(decoder_depth):
             decoder_layers.append(ModuleList([
-                Attention(dim = dim, heads = attn_heads, dim_head = attn_dim_head, **attn_kwargs),
-                SwiGLUFeedforward(dim = dim, **ff_kwargs)
+                hyper_conn(branch = Attention(dim = dim, heads = attn_heads, dim_head = attn_dim_head, **attn_kwargs)),
+                hyper_conn(branch = SwiGLUFeedforward(dim = dim, **ff_kwargs))
             ]))
 
         self.decoder_layers = ModuleList(decoder_layers)
@@ -1129,10 +1136,14 @@ class VideoTokenizer(Module):
 
         # decoder attention
 
-        for attn, ff in self.decoder_layers:
-            tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = decoder_attend_fn) + tokens
+        tokens = self.expand_streams(tokens)
 
-            tokens = ff(tokens) + tokens
+        for attn, ff in self.decoder_layers:
+            tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = decoder_attend_fn)
+
+            tokens = ff(tokens)
+
+        tokens = self.reduce_streams(tokens)
 
         tokens = self.decoder_norm(tokens)
 
@@ -1227,9 +1238,13 @@ class VideoTokenizer(Module):
 
         # encoder
 
+        tokens = self.expand_streams(tokens)
+
         for attn, ff in self.encoder_layers:
-            tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = encoder_attend_fn) + tokens
-            tokens = ff(tokens) + tokens
+            tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = encoder_attend_fn)
+            tokens = ff(tokens)
+
+        tokens = self.reduce_streams(tokens)
 
         tokens = self.encoder_norm(tokens)
 
@@ -1304,10 +1319,15 @@ class DynamicsWorldModel(Module):
         reward_loss_weight = 0.1,
         value_head_mlp_depth = 3,
         policy_head_mlp_depth = 3,
-        num_latent_genes = 0 # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037
+        num_latent_genes = 0, # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037
+        num_residual_streams = 1
     ):
         super().__init__()
 
+        # hyper connections
+
+        hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, dim = dim)
+
         # can accept raw video if tokenizer is passed in
 
         self.video_tokenizer = video_tokenizer
@@ -1467,8 +1487,8 @@ class DynamicsWorldModel(Module):
             layers.append(ModuleList([
                 rearrange_to_attend,
                 rearrange_from_attend,
-                Attention(dim = dim, dim_head = attn_dim_head, **attn_kwargs),
-                SwiGLUFeedforward(dim = dim, **ff_kwargs)
+                hyper_conn(branch = Attention(dim = dim, dim_head = attn_dim_head, **attn_kwargs)),
+                hyper_conn(branch = SwiGLUFeedforward(dim = dim, **ff_kwargs))
             ]))
 
         self.layers = ModuleList(layers)
@@ -1856,6 +1876,8 @@ class DynamicsWorldModel(Module):
 
         # attention
 
+        tokens = self.expand_streams(tokens)
+
         for (pre_attn_rearrange, post_attn_rearrange, attn, ff), layer_is_time in zip(self.layers, self.is_time):
 
             tokens = pre_attn_rearrange(tokens)
@@ -1876,6 +1898,8 @@ class DynamicsWorldModel(Module):
 
             tokens = ff(tokens) + tokens
 
+        tokens = self.reduce_streams(tokens)
+
         # unpack
 
         flow_token, space_tokens, register_tokens, agent_tokens = unpack(tokens, packed_tokens_shape, 'b t * d')
@@ -2011,7 +2035,7 @@ class Dreamer(Module):
     def __init__(
         self,
         video_tokenizer: VideoTokenizer,
-        dynamics_model: DynamicsModel,
+        dynamics_model: DynamicsWorldModel,
         discount_factor = 0.997
     ):
         super().__init__()
diff --git a/pyproject.toml b/pyproject.toml
index f6129e5..7b59ac3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.11"
+version = "0.0.12"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
 ]
@@ -31,6 +31,7 @@ dependencies = [
     "einx>=0.3.0",
     "einops>=0.8.1",
     "hl-gauss-pytorch",
+    "hyper-connections>=0.2.1",
     "torch>=2.4",
     "torchvision",
     "x-mlps-pytorch>=0.0.29"
diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py
index 1a45a79..0008797 100644
--- a/tests/test_dreamer.py
+++ b/tests/test_dreamer.py
@@ -10,6 +10,7 @@ import torch
 @param('num_spatial_tokens', (2, 8))
 @param('signal_and_step_passed_in', (False, True))
 @param('condition_on_actions', (False, True))
+@param('num_residual_streams', (1, 4))
 @param('add_reward_embed_to_agent_token', (False, True))
 def test_e2e(
     pred_orig_latent,
@@ -20,6 +21,7 @@ def test_e2e(
     num_spatial_tokens,
     signal_and_step_passed_in,
     condition_on_actions,
+    num_residual_streams,
     add_reward_embed_to_agent_token
 ):
     from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
@@ -31,7 +33,8 @@ def test_e2e(
         dim_latent = 16,
         patch_size = 32,
         attn_dim_head = 16,
-        num_latent_tokens = 4
+        num_latent_tokens = 4,
+        num_residual_streams = num_residual_streams
     )
 
     video = torch.randn(2, 3, 4, 256, 256)
@@ -64,7 +67,8 @@ def test_e2e(
             query_heads = query_heads,
         ),
         prob_no_shortcut_train = prob_no_shortcut_train,
-        add_reward_embed_to_agent_token = add_reward_embed_to_agent_token
+        add_reward_embed_to_agent_token = add_reward_embed_to_agent_token,
+        num_residual_streams = num_residual_streams
     )
 
     signal_levels = step_sizes_log2 = None
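
For reference, the pattern this patch applies in both VideoTokenizer and DynamicsWorldModel is: wrap each Attention / SwiGLUFeedforward branch in a hyper connection, which then owns the residual path (hence the dropped "+ tokens" adds in the VideoTokenizer loops), and bracket the layer loop with expand_streams / reduce_streams. Below is a minimal runnable sketch of that wiring in isolation, assuming hyper-connections>=0.2.1 as pinned in pyproject.toml; the feedforward helper is a hypothetical stand-in for the dreamer4 branches, not code from this repo.

    import torch
    from torch import nn
    from hyper_connections import get_init_and_expand_reduce_stream_functions

    dim, depth, num_residual_streams = 64, 2, 4

    # same call as in the patch - binding dim here lets each hyper_conn(branch = ...) take only the branch module
    hyper_conn, expand_streams, reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, dim = dim)

    def feedforward(dim, mult = 4):
        # hypothetical stand-in for the repo's Attention / SwiGLUFeedforward branches
        return nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim * mult),
            nn.GELU(),
            nn.Linear(dim * mult, dim)
        )

    # each wrapped branch mixes its own residual contribution across the streams
    layers = nn.ModuleList([hyper_conn(branch = feedforward(dim)) for _ in range(depth)])

    tokens = torch.randn(2, 16, dim)

    tokens = expand_streams(tokens)    # copy tokens across the residual streams

    for layer in layers:
        tokens = layer(tokens)         # no explicit "+ tokens" - the hyper connection adds the residual internally

    tokens = reduce_streams(tokens)    # merge the streams back to the original shape

    assert tokens.shape == (2, 16, dim)

With the default num_residual_streams = 1 the library should fall back to a plain residual stream, which is why the new test parametrization exercises both 1 and 4 streams.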