diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index 15961b9..c353c54 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -9,6 +9,8 @@ import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Parameter, Sequential, Linear, RMSNorm, Identity
 from torch import cat, stack, arange, tensor, Tensor, is_tensor
 
+from x_mlps_pytorch import create_mlp
+
 from accelerate import Accelerator
 
 # ein related
@@ -460,7 +462,8 @@ class VideoTokenizer(Module):
         latents = self.encoded_to_latents(tokens)
 
         if return_latents:
-            return latents
+            latents = inverse_pack_time(latents)
+            return latents[..., -1, :]
 
         tokens = self.latents_to_decoder(latents)
 
@@ -501,6 +504,8 @@ class DynamicsModel(Module):
         self,
         dim,
         dim_latent,
+        num_signal_levels = 500,
+        num_step_sizes = 32,
         num_spatial_tokens = 32, # latents were projected into spatial tokens, and presumably pooled back for the final prediction (or one special one does the x-prediction)
         num_register_tokens = 8, # they claim register tokens led to better temporal consistency
         depth = 4,
@@ -522,6 +527,14 @@ class DynamicsModel(Module):
 
         self.register_tokens = Parameter(torch.randn(num_register_tokens, dim) * 1e-2)
 
+        # signal and step sizes
+
+        assert divisible_by(dim, 2)
+        dim_half = dim // 2
+
+        self.signal_levels_embed = nn.Embedding(num_signal_levels, dim_half)
+        self.step_sizes_embed = nn.Embedding(num_step_sizes, dim_half)
+
         # they sum all the actions into a single token
 
         self.action_learned_embed = Parameter(torch.randn(dim) * 1e-2)
@@ -555,19 +568,36 @@ class DynamicsModel(Module):
 
     def forward(
         self,
-        latents # (b t d)
+        latents,               # (b t d)
+        signal_levels = None,  # (b t)
+        step_sizes = None      # (b t)
     ):
         space_tokens = self.latents_to_spatial_tokens(latents)
 
         # pack to tokens
 
-        # [latent space tokens] [register] [actions / agent]
+        # [signal + step size embed] [latent space tokens] [register] [actions / agent]
 
         registers = repeat(self.register_tokens, 's d -> b t s d', b = latents.shape[0], t = latents.shape[1])
 
-        agent_token = repeat(self.action_learned_embed, 'd -> b t 1 d', b = latents.shape[0], t = latents.shape[1])
+        agent_token = repeat(self.action_learned_embed, 'd -> b t d', b = latents.shape[0], t = latents.shape[1])
 
-        tokens, packed_tokens_shape = pack([space_tokens, registers, agent_token], 'b t * d')
+        # determine signal + step size embed for their diffusion forcing + shortcut
+
+        assert not (exists(signal_levels) ^ exists(step_sizes))
+
+        if exists(signal_levels):
+            signal_embed = self.signal_levels_embed(signal_levels)
+            step_size_embed = self.step_sizes_embed(step_sizes)
+
+            flow_token = cat((signal_embed, step_size_embed), dim = -1)
+            flow_token = rearrange(flow_token, 'b t d -> b t 1 d') # one flow conditioning token per frame
+        else:
+            flow_token = registers[..., 0:0, :] # empty token when no signal / step size conditioning
+
+        # pack to tokens for attending
+
+        tokens, packed_tokens_shape = pack([flow_token, space_tokens, registers, agent_token], 'b t * d')
 
         # attention
 
@@ -583,7 +613,7 @@ class DynamicsModel(Module):
 
         # unpack
 
-        space_tokens, register_tokens, agent_token = unpack(tokens, packed_tokens_shape, 'b t * d')
+        flow_token, space_tokens, register_tokens, agent_token = unpack(tokens, packed_tokens_shape, 'b t * d')
 
         # pooling
 
diff --git a/pyproject.toml b/pyproject.toml
index 3b95b9f..d914450 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,7 +29,8 @@ dependencies = [
     "accelerate",
     "einx>=0.3.0",
     "einops>=0.8.1",
-    "torch>=2.4"
+    "torch>=2.4",
+    "x-mlps-pytorch"
 ]
 
 [project.urls]
diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py
index ea9c162..6257b59 100644
--- a/tests/test_dreamer.py
+++ b/tests/test_dreamer.py
@@ -5,7 +5,7 @@ def test_e2e():
     from dreamer4.dreamer4 import VideoTokenizer, DynamicsModel
 
     tokenizer = VideoTokenizer(512, dim_latent = 32, patch_size = 32)
-    x = torch.randn(1, 3, 4, 256, 256)
+    x = torch.randn(2, 3, 4, 256, 256)
 
     loss = tokenizer(x)
     assert loss.numel() == 1
@@ -13,6 +13,10 @@ def test_e2e():
     latents = tokenizer(x, return_latents = True)
     assert latents.shape[-1] == 32
 
-    dynamics = DynamicsModel(512, dim_latent = 32)
-    pred = dynamics(latents)
+    dynamics = DynamicsModel(512, dim_latent = 32, num_signal_levels = 500, num_step_sizes = 32)
+
+    signal_levels = torch.randint(0, 500, (2, 4))
+    step_sizes = torch.randint(0, 32, (2, 4))
+
+    pred = dynamics(latents, signal_levels = signal_levels, step_sizes = step_sizes)
     assert pred.shape == latents.shape
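For reference, a minimal standalone sketch (not part of the diff) of how the newly added signal-level and step-size embeddings combine into the per-frame flow token that gets packed alongside the spatial, register, and agent tokens. The names and dimensions below are illustrative only, mirroring the defaults in the diff.

```python
import torch
from torch import nn

dim = 512
dim_half = dim // 2

# each quantity is embedded into half the model dimension, as in the diff
signal_levels_embed = nn.Embedding(500, dim_half)  # num_signal_levels = 500
step_sizes_embed = nn.Embedding(32, dim_half)      # num_step_sizes = 32

signal_levels = torch.randint(0, 500, (2, 4))      # (batch, time)
step_sizes = torch.randint(0, 32, (2, 4))          # (batch, time)

# concatenating the two half-width embeddings yields one dim-wide
# conditioning token per (batch, time) position
flow_token = torch.cat((
    signal_levels_embed(signal_levels),
    step_sizes_embed(step_sizes)
), dim = -1)

assert flow_token.shape == (2, 4, dim)
```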