bring in hyper connections

commit 563b269f8a
parent 5df3e69583
@@ -17,6 +17,8 @@ from torchvision.models import VGG16_Weights
 from x_mlps_pytorch.normed_mlp import create_mlp
 from x_mlps_pytorch.ensemble import Ensemble
 
+from hyper_connections import get_init_and_expand_reduce_stream_functions
+
 from assoc_scan import AssocScan
 
 # ein related
@@ -941,7 +943,8 @@ class VideoTokenizer(Module):
             rope_min_freq = 1.,
             rope_max_freq = 10000.,
             rope_p_zero_freqs = 0.
-        )
+        ),
+        num_residual_streams = 1
     ):
         super().__init__()
 
@@ -953,6 +956,10 @@ class VideoTokenizer(Module):
         self.num_latent_tokens = num_latent_tokens
         self.latent_tokens = Parameter(randn(num_latent_tokens, dim) * 1e-2)
 
+        # hyper connections
+
+        hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, dim = dim)
+
         # mae masking - Kaiming He paper from long ago
 
         self.per_image_patch_mask_prob = per_image_patch_mask_prob
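The three objects unpacked above are what the rest of this commit builds on: hyper_conn wraps each residual branch at construction time, while expand_streams / reduce_streams bracket the transformer stack in the forward pass. A minimal sketch of that pattern, based on the calls shown in this diff plus assumptions about the hyper-connections package (the toy branch, dim, and tensor sizes below are illustrative, not repo code):

import torch
from torch import nn
from hyper_connections import get_init_and_expand_reduce_stream_functions

dim, num_residual_streams = 64, 4

# factory for wrapping residual branches, plus the stream expand / reduce functions
hyper_conn, expand_streams, reduce_streams = get_init_and_expand_reduce_stream_functions(
    num_residual_streams, dim = dim
)

# wrap any (batch, seq, dim) -> (batch, seq, dim) branch, as the encoder / decoder layers do below
block = hyper_conn(branch = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, dim)))

tokens = torch.randn(2, 16, dim)

tokens = expand_streams(tokens)  # fan the tokens out into the residual streams
tokens = block(tokens)           # the wrapper mixes the branch output back into the streams
tokens = reduce_streams(tokens)  # collapse the streams back to a single representation

assert tokens.shape == (2, 16, dim)

With the default num_residual_streams = 1 the same calls still go through, so existing configurations keep working; larger values maintain several residual streams with learned mixing.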
@@ -991,8 +998,8 @@ class VideoTokenizer(Module):
 
         for _ in range(encoder_depth):
             encoder_layers.append(ModuleList([
-                Attention(dim = dim, heads = attn_heads, dim_head = attn_dim_head, **attn_kwargs),
-                SwiGLUFeedforward(dim = dim, **ff_kwargs)
+                hyper_conn(branch = Attention(dim = dim, heads = attn_heads, dim_head = attn_dim_head, **attn_kwargs)),
+                hyper_conn(branch = SwiGLUFeedforward(dim = dim, **ff_kwargs))
             ]))
 
         self.encoder_layers = ModuleList(encoder_layers)
@@ -1025,8 +1032,8 @@ class VideoTokenizer(Module):
 
         for _ in range(decoder_depth):
             decoder_layers.append(ModuleList([
-                Attention(dim = dim, heads = attn_heads, dim_head = attn_dim_head, **attn_kwargs),
-                SwiGLUFeedforward(dim = dim, **ff_kwargs)
+                hyper_conn(branch = Attention(dim = dim, heads = attn_heads, dim_head = attn_dim_head, **attn_kwargs)),
+                hyper_conn(branch = SwiGLUFeedforward(dim = dim, **ff_kwargs))
             ]))
 
         self.decoder_layers = ModuleList(decoder_layers)
@@ -1129,10 +1136,14 @@ class VideoTokenizer(Module):
 
         # decoder attention
 
+        tokens = self.expand_streams(tokens)
+
         for attn, ff in self.decoder_layers:
-            tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = decoder_attend_fn) + tokens
+            tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = decoder_attend_fn)
 
-            tokens = ff(tokens) + tokens
+            tokens = ff(tokens)
 
+        tokens = self.reduce_streams(tokens)
+
         tokens = self.decoder_norm(tokens)
 
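The encoder loop below and the DynamicsWorldModel forward later in the diff get the same treatment: the streams are expanded once before the layer loop, each wrapped branch folds its output back into the streams internally (hence the dropped `+ tokens`), and the streams are reduced once at the end. A self-contained sketch of a small stack wired this way, with toy branches standing in for Attention / SwiGLUFeedforward (assumptions, not repo code):

import torch
from torch import nn
from torch.nn import Module, ModuleList
from hyper_connections import get_init_and_expand_reduce_stream_functions

class TinyStack(Module):
    def __init__(self, dim = 64, depth = 2, num_residual_streams = 4):
        super().__init__()
        hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(
            num_residual_streams, dim = dim
        )

        # each "layer" is a pair of wrapped branches, mirroring the attention + feedforward pairs above
        self.layers = ModuleList([
            ModuleList([
                hyper_conn(branch = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, dim))),
                hyper_conn(branch = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, dim)))
            ]) for _ in range(depth)
        ])

        self.norm = nn.LayerNorm(dim)

    def forward(self, tokens):
        tokens = self.expand_streams(tokens)

        for attn, ff in self.layers:
            tokens = attn(tokens)  # no explicit `+ tokens` - the wrapper owns the residual
            tokens = ff(tokens)

        tokens = self.reduce_streams(tokens)
        return self.norm(tokens)

out = TinyStack()(torch.randn(2, 16, 64))
assert out.shape == (2, 16, 64)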
@@ -1227,9 +1238,13 @@ class VideoTokenizer(Module):
 
         # encoder
 
+        tokens = self.expand_streams(tokens)
+
         for attn, ff in self.encoder_layers:
-            tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = encoder_attend_fn) + tokens
-            tokens = ff(tokens) + tokens
+            tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = encoder_attend_fn)
+            tokens = ff(tokens)
 
+        tokens = self.reduce_streams(tokens)
+
         tokens = self.encoder_norm(tokens)
 
@@ -1304,10 +1319,15 @@ class DynamicsWorldModel(Module):
         reward_loss_weight = 0.1,
         value_head_mlp_depth = 3,
         policy_head_mlp_depth = 3,
-        num_latent_genes = 0 # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037
+        num_latent_genes = 0, # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037
+        num_residual_streams = 1
     ):
         super().__init__()
 
+        # hyper connections
+
+        hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, dim = dim)
+
         # can accept raw video if tokenizer is passed in
 
         self.video_tokenizer = video_tokenizer
@@ -1467,8 +1487,8 @@ class DynamicsWorldModel(Module):
             layers.append(ModuleList([
                 rearrange_to_attend,
                 rearrange_from_attend,
-                Attention(dim = dim, dim_head = attn_dim_head, **attn_kwargs),
-                SwiGLUFeedforward(dim = dim, **ff_kwargs)
+                hyper_conn(branch = Attention(dim = dim, dim_head = attn_dim_head, **attn_kwargs)),
+                hyper_conn(branch = SwiGLUFeedforward(dim = dim, **ff_kwargs))
             ]))
 
         self.layers = ModuleList(layers)
@@ -1856,6 +1876,8 @@ class DynamicsWorldModel(Module):
 
         # attention
 
+        tokens = self.expand_streams(tokens)
+
         for (pre_attn_rearrange, post_attn_rearrange, attn, ff), layer_is_time in zip(self.layers, self.is_time):
 
             tokens = pre_attn_rearrange(tokens)
@@ -1876,6 +1898,8 @@ class DynamicsWorldModel(Module):
 
             tokens = ff(tokens) + tokens
 
+        tokens = self.reduce_streams(tokens)
+
         # unpack
 
         flow_token, space_tokens, register_tokens, agent_tokens = unpack(tokens, packed_tokens_shape, 'b t * d')
@@ -2011,7 +2035,7 @@ class Dreamer(Module):
     def __init__(
         self,
         video_tokenizer: VideoTokenizer,
-        dynamics_model: DynamicsModel,
+        dynamics_model: DynamicsWorldModel,
         discount_factor = 0.997
     ):
         super().__init__()
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.11"
+version = "0.0.12"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -31,6 +31,7 @@ dependencies = [
     "einx>=0.3.0",
     "einops>=0.8.1",
     "hl-gauss-pytorch",
+    "hyper-connections>=0.2.1",
     "torch>=2.4",
     "torchvision",
     "x-mlps-pytorch>=0.0.29"
@@ -10,6 +10,7 @@ import torch
 @param('num_spatial_tokens', (2, 8))
 @param('signal_and_step_passed_in', (False, True))
 @param('condition_on_actions', (False, True))
+@param('num_residual_streams', (1, 4))
 @param('add_reward_embed_to_agent_token', (False, True))
 def test_e2e(
     pred_orig_latent,
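On the test side, the new decorator line simply widens the existing matrix so every end-to-end case runs with one and with four residual streams. Assuming `param` is the test file's usual alias for `pytest.mark.parametrize`, it behaves as in this standalone sketch (the names here are hypothetical):

import pytest

param = pytest.mark.parametrize  # assumed alias, mirroring the test file

@param('num_residual_streams', (1, 4))
def test_num_residual_streams(num_residual_streams):
    # the e2e test threads this value into both VideoTokenizer and DynamicsWorldModel
    assert num_residual_streams in (1, 4)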
@@ -20,6 +21,7 @@ def test_e2e(
     num_spatial_tokens,
     signal_and_step_passed_in,
     condition_on_actions,
+    num_residual_streams,
     add_reward_embed_to_agent_token
 ):
     from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
@@ -31,7 +33,8 @@ def test_e2e(
         dim_latent = 16,
         patch_size = 32,
         attn_dim_head = 16,
-        num_latent_tokens = 4
+        num_latent_tokens = 4,
+        num_residual_streams = num_residual_streams
     )
 
     video = torch.randn(2, 3, 4, 256, 256)
@@ -64,7 +67,8 @@ def test_e2e(
             query_heads = query_heads,
         ),
         prob_no_shortcut_train = prob_no_shortcut_train,
-        add_reward_embed_to_agent_token = add_reward_embed_to_agent_token
+        add_reward_embed_to_agent_token = add_reward_embed_to_agent_token,
+        num_residual_streams = num_residual_streams
     )
 
     signal_levels = step_sizes_log2 = None