sketch out top down

This commit is contained in:
lucidrains 2025-10-01 10:25:56 -07:00
parent 882e63511b
commit e3cbcd94c6
4 changed files with 49 additions and 11 deletions

View File

@ -0,0 +1,5 @@
from dreamer4.dreamer4 import (
VideoTokenizer,
DynamicsModel,
Dreamer
)

View File

@ -8,6 +8,8 @@ import torch.nn.functional as F
from torch.nn import Module, ModuleList, Parameter, Sequential, Linear, RMSNorm, Identity
from torch import cat, stack, arange, tensor, Tensor, is_tensor
from accelerate import Accelerator
# ein related
# b - batch
@ -101,12 +103,8 @@ class GoldenGateRoPENd(Module):
def forward(
self,
x, # (b h n d)
pos # (b n p)
):
dtype = x
x, y = x.float().chunk(2, dim = -1) # (b, h, n, f)
freqs = rearrange(self.freqs, 'h f p -> 1 h 1 f p')
positions = rearrange(pos.float(), 'b n p -> b 1 n 1 p')
@ -115,16 +113,25 @@ class GoldenGateRoPENd(Module):
theta = reduce(freqs * positions, 'b h n f p -> b h n f', 'sum')
# apply rotations
return theta
cos_theta = torch.cos(theta)
sin_theta = torch.sin(theta)
def apply_rotations(
theta # (b h n f)
):
dtype = x
x_out = x * cos_theta - y * sin_theta
y_out = x * sin_theta + y * cos_theta
x, y = rearrange(x.float(), '... (split d) -> split ... d', split = 2) # (b, h, n, f)
out = cat((x_out, y_out), dim = -1)
return out.type_as(dtype)
# apply rotations
cos_theta = torch.cos(theta)
sin_theta = torch.sin(theta)
x_out = x * cos_theta - y * sin_theta
y_out = x * sin_theta + y * cos_theta
out = rearrange([x_out, y_out], 'split ... d -> ... (split d)')
return out.type_as(dtype)
# multi-head rmsnorm
@ -279,3 +286,25 @@ class SwiGLUFeedforward(Module):
x = x * F.gelu(gates)
return self.proj_out(x)
# video tokenizer
class VideoTokenizer(Module):
def __init__(
self
):
super().__init__()
class DynamicsModel(Module):
def __init__(
self
):
super().__init__()
class Dreamer(Module):
def __init__(
self,
video_tokenizer: VideoTokenizer,
dynamics_model: DynamicsModel
):
super().__init__()

3
dreamer4/trainers.py Normal file
View File

@ -0,0 +1,3 @@
import torch
from accelerate import Accelerator

View File

@ -26,6 +26,7 @@ classifiers=[
]
dependencies = [
"accelerate",
"einx>=0.3.0",
"einops>=0.8.1",
"torch>=2.4"