bring in hyper connections
This commit is contained in:
parent
5df3e69583
commit
563b269f8a
@ -17,6 +17,8 @@ from torchvision.models import VGG16_Weights
|
||||
from x_mlps_pytorch.normed_mlp import create_mlp
|
||||
from x_mlps_pytorch.ensemble import Ensemble
|
||||
|
||||
from hyper_connections import get_init_and_expand_reduce_stream_functions
|
||||
|
||||
from assoc_scan import AssocScan
|
||||
|
||||
# ein related
|
||||
@ -941,7 +943,8 @@ class VideoTokenizer(Module):
|
||||
rope_min_freq = 1.,
|
||||
rope_max_freq = 10000.,
|
||||
rope_p_zero_freqs = 0.
|
||||
)
|
||||
),
|
||||
num_residual_streams = 1
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@ -953,6 +956,10 @@ class VideoTokenizer(Module):
|
||||
self.num_latent_tokens = num_latent_tokens
|
||||
self.latent_tokens = Parameter(randn(num_latent_tokens, dim) * 1e-2)
|
||||
|
||||
# hyper connections
|
||||
|
||||
hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, dim = dim)
|
||||
|
||||
# mae masking - Kaiming He paper from long ago
|
||||
|
||||
self.per_image_patch_mask_prob = per_image_patch_mask_prob
|
||||
@ -991,8 +998,8 @@ class VideoTokenizer(Module):
|
||||
|
||||
for _ in range(encoder_depth):
|
||||
encoder_layers.append(ModuleList([
|
||||
Attention(dim = dim, heads = attn_heads, dim_head = attn_dim_head, **attn_kwargs),
|
||||
SwiGLUFeedforward(dim = dim, **ff_kwargs)
|
||||
hyper_conn(branch = Attention(dim = dim, heads = attn_heads, dim_head = attn_dim_head, **attn_kwargs)),
|
||||
hyper_conn(branch = SwiGLUFeedforward(dim = dim, **ff_kwargs))
|
||||
]))
|
||||
|
||||
self.encoder_layers = ModuleList(encoder_layers)
|
||||
@ -1025,8 +1032,8 @@ class VideoTokenizer(Module):
|
||||
|
||||
for _ in range(decoder_depth):
|
||||
decoder_layers.append(ModuleList([
|
||||
Attention(dim = dim, heads = attn_heads, dim_head = attn_dim_head, **attn_kwargs),
|
||||
SwiGLUFeedforward(dim = dim, **ff_kwargs)
|
||||
hyper_conn(branch = Attention(dim = dim, heads = attn_heads, dim_head = attn_dim_head, **attn_kwargs)),
|
||||
hyper_conn(branch = SwiGLUFeedforward(dim = dim, **ff_kwargs))
|
||||
]))
|
||||
|
||||
self.decoder_layers = ModuleList(decoder_layers)
|
||||
@ -1129,10 +1136,14 @@ class VideoTokenizer(Module):
|
||||
|
||||
# decoder attention
|
||||
|
||||
for attn, ff in self.decoder_layers:
|
||||
tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = decoder_attend_fn) + tokens
|
||||
tokens = self.expand_streams(tokens)
|
||||
|
||||
tokens = ff(tokens) + tokens
|
||||
for attn, ff in self.decoder_layers:
|
||||
tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = decoder_attend_fn)
|
||||
|
||||
tokens = ff(tokens)
|
||||
|
||||
tokens = self.reduce_streams(tokens)
|
||||
|
||||
tokens = self.decoder_norm(tokens)
|
||||
|
||||
@ -1227,9 +1238,13 @@ class VideoTokenizer(Module):
|
||||
|
||||
# encoder
|
||||
|
||||
tokens = self.expand_streams(tokens)
|
||||
|
||||
for attn, ff in self.encoder_layers:
|
||||
tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = encoder_attend_fn) + tokens
|
||||
tokens = ff(tokens) + tokens
|
||||
tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = encoder_attend_fn)
|
||||
tokens = ff(tokens)
|
||||
|
||||
tokens = self.reduce_streams(tokens)
|
||||
|
||||
tokens = self.encoder_norm(tokens)
|
||||
|
||||
@ -1304,10 +1319,15 @@ class DynamicsWorldModel(Module):
|
||||
reward_loss_weight = 0.1,
|
||||
value_head_mlp_depth = 3,
|
||||
policy_head_mlp_depth = 3,
|
||||
num_latent_genes = 0 # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037
|
||||
num_latent_genes = 0, # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037
|
||||
num_residual_streams = 1
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# hyper connections
|
||||
|
||||
hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, dim = dim)
|
||||
|
||||
# can accept raw video if tokenizer is passed in
|
||||
|
||||
self.video_tokenizer = video_tokenizer
|
||||
@ -1467,8 +1487,8 @@ class DynamicsWorldModel(Module):
|
||||
layers.append(ModuleList([
|
||||
rearrange_to_attend,
|
||||
rearrange_from_attend,
|
||||
Attention(dim = dim, dim_head = attn_dim_head, **attn_kwargs),
|
||||
SwiGLUFeedforward(dim = dim, **ff_kwargs)
|
||||
hyper_conn(branch = Attention(dim = dim, dim_head = attn_dim_head, **attn_kwargs)),
|
||||
hyper_conn(branch = SwiGLUFeedforward(dim = dim, **ff_kwargs))
|
||||
]))
|
||||
|
||||
self.layers = ModuleList(layers)
|
||||
@ -1856,6 +1876,8 @@ class DynamicsWorldModel(Module):
|
||||
|
||||
# attention
|
||||
|
||||
tokens = self.expand_streams(tokens)
|
||||
|
||||
for (pre_attn_rearrange, post_attn_rearrange, attn, ff), layer_is_time in zip(self.layers, self.is_time):
|
||||
|
||||
tokens = pre_attn_rearrange(tokens)
|
||||
@ -1876,6 +1898,8 @@ class DynamicsWorldModel(Module):
|
||||
|
||||
tokens = ff(tokens) + tokens
|
||||
|
||||
tokens = self.reduce_streams(tokens)
|
||||
|
||||
# unpack
|
||||
|
||||
flow_token, space_tokens, register_tokens, agent_tokens = unpack(tokens, packed_tokens_shape, 'b t * d')
|
||||
@ -2011,7 +2035,7 @@ class Dreamer(Module):
|
||||
def __init__(
|
||||
self,
|
||||
video_tokenizer: VideoTokenizer,
|
||||
dynamics_model: DynamicsModel,
|
||||
dynamics_model: DynamicsWorldModel,
|
||||
discount_factor = 0.997
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "dreamer4"
|
||||
version = "0.0.11"
|
||||
version = "0.0.12"
|
||||
description = "Dreamer 4"
|
||||
authors = [
|
||||
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
||||
@ -31,6 +31,7 @@ dependencies = [
|
||||
"einx>=0.3.0",
|
||||
"einops>=0.8.1",
|
||||
"hl-gauss-pytorch",
|
||||
"hyper-connections>=0.2.1",
|
||||
"torch>=2.4",
|
||||
"torchvision",
|
||||
"x-mlps-pytorch>=0.0.29"
|
||||
|
||||
@ -10,6 +10,7 @@ import torch
|
||||
@param('num_spatial_tokens', (2, 8))
|
||||
@param('signal_and_step_passed_in', (False, True))
|
||||
@param('condition_on_actions', (False, True))
|
||||
@param('num_residual_streams', (1, 4))
|
||||
@param('add_reward_embed_to_agent_token', (False, True))
|
||||
def test_e2e(
|
||||
pred_orig_latent,
|
||||
@ -20,6 +21,7 @@ def test_e2e(
|
||||
num_spatial_tokens,
|
||||
signal_and_step_passed_in,
|
||||
condition_on_actions,
|
||||
num_residual_streams,
|
||||
add_reward_embed_to_agent_token
|
||||
):
|
||||
from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
|
||||
@ -31,7 +33,8 @@ def test_e2e(
|
||||
dim_latent = 16,
|
||||
patch_size = 32,
|
||||
attn_dim_head = 16,
|
||||
num_latent_tokens = 4
|
||||
num_latent_tokens = 4,
|
||||
num_residual_streams = num_residual_streams
|
||||
)
|
||||
|
||||
video = torch.randn(2, 3, 4, 256, 256)
|
||||
@ -64,7 +67,8 @@ def test_e2e(
|
||||
query_heads = query_heads,
|
||||
),
|
||||
prob_no_shortcut_train = prob_no_shortcut_train,
|
||||
add_reward_embed_to_agent_token = add_reward_embed_to_agent_token
|
||||
add_reward_embed_to_agent_token = add_reward_embed_to_agent_token,
|
||||
num_residual_streams = num_residual_streams
|
||||
)
|
||||
|
||||
signal_levels = step_sizes_log2 = None
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user