From e3cbcd94c6c9802d7065363a4dd2f45373392726 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Wed, 1 Oct 2025 10:25:56 -0700 Subject: [PATCH] sketch out top down --- dreamer4/__init__.py | 5 +++++ dreamer4/dreamer4.py | 51 ++++++++++++++++++++++++++++++++++---------- dreamer4/trainers.py | 3 +++ pyproject.toml | 1 + 4 files changed, 49 insertions(+), 11 deletions(-) create mode 100644 dreamer4/trainers.py diff --git a/dreamer4/__init__.py b/dreamer4/__init__.py index e69de29..59a245a 100644 --- a/dreamer4/__init__.py +++ b/dreamer4/__init__.py @@ -0,0 +1,5 @@ +from dreamer4.dreamer4 import ( + VideoTokenizer, + DynamicsModel, + Dreamer +) diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index 9198091..07643b7 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -8,6 +8,8 @@ import torch.nn.functional as F from torch.nn import Module, ModuleList, Parameter, Sequential, Linear, RMSNorm, Identity from torch import cat, stack, arange, tensor, Tensor, is_tensor +from accelerate import Accelerator + # ein related # b - batch @@ -101,12 +103,8 @@ class GoldenGateRoPENd(Module): def forward( self, - x, # (b h n d) pos # (b n p) ): - dtype = x - - x, y = x.float().chunk(2, dim = -1) # (b, h, n, f) freqs = rearrange(self.freqs, 'h f p -> 1 h 1 f p') positions = rearrange(pos.float(), 'b n p -> b 1 n 1 p') @@ -115,16 +113,25 @@ class GoldenGateRoPENd(Module): theta = reduce(freqs * positions, 'b h n f p -> b h n f', 'sum') - # apply rotations + return theta - cos_theta = torch.cos(theta) - sin_theta = torch.sin(theta) +def apply_rotations( + theta # (b h n f) +): + dtype = x - x_out = x * cos_theta - y * sin_theta - y_out = x * sin_theta + y * cos_theta + x, y = rearrange(x.float(), '... (split d) -> split ... d', split = 2) # (b, h, n, f) - out = cat((x_out, y_out), dim = -1) - return out.type_as(dtype) + # apply rotations + + cos_theta = torch.cos(theta) + sin_theta = torch.sin(theta) + + x_out = x * cos_theta - y * sin_theta + y_out = x * sin_theta + y * cos_theta + + out = rearrange([x_out, y_out], 'split ... d -> ... (split d)') + return out.type_as(dtype) # multi-head rmsnorm @@ -279,3 +286,25 @@ class SwiGLUFeedforward(Module): x = x * F.gelu(gates) return self.proj_out(x) + +# video tokenizer + +class VideoTokenizer(Module): + def __init__( + self + ): + super().__init__() + +class DynamicsModel(Module): + def __init__( + self + ): + super().__init__() + +class Dreamer(Module): + def __init__( + self, + video_tokenizer: VideoTokenizer, + dynamics_model: DynamicsModel + ): + super().__init__() diff --git a/dreamer4/trainers.py b/dreamer4/trainers.py new file mode 100644 index 0000000..b15ff8a --- /dev/null +++ b/dreamer4/trainers.py @@ -0,0 +1,3 @@ +import torch + +from accelerate import Accelerator diff --git a/pyproject.toml b/pyproject.toml index 989f95b..3b95b9f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ classifiers=[ ] dependencies = [ + "accelerate", "einx>=0.3.0", "einops>=0.8.1", "torch>=2.4"