sketch out top down
This commit is contained in:
parent
882e63511b
commit
e3cbcd94c6
@ -0,0 +1,5 @@
|
|||||||
|
from dreamer4.dreamer4 import (
|
||||||
|
VideoTokenizer,
|
||||||
|
DynamicsModel,
|
||||||
|
Dreamer
|
||||||
|
)
|
||||||
@ -8,6 +8,8 @@ import torch.nn.functional as F
|
|||||||
from torch.nn import Module, ModuleList, Parameter, Sequential, Linear, RMSNorm, Identity
|
from torch.nn import Module, ModuleList, Parameter, Sequential, Linear, RMSNorm, Identity
|
||||||
from torch import cat, stack, arange, tensor, Tensor, is_tensor
|
from torch import cat, stack, arange, tensor, Tensor, is_tensor
|
||||||
|
|
||||||
|
from accelerate import Accelerator
|
||||||
|
|
||||||
# ein related
|
# ein related
|
||||||
|
|
||||||
# b - batch
|
# b - batch
|
||||||
@ -101,12 +103,8 @@ class GoldenGateRoPENd(Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
x, # (b h n d)
|
|
||||||
pos # (b n p)
|
pos # (b n p)
|
||||||
):
|
):
|
||||||
dtype = x
|
|
||||||
|
|
||||||
x, y = x.float().chunk(2, dim = -1) # (b, h, n, f)
|
|
||||||
|
|
||||||
freqs = rearrange(self.freqs, 'h f p -> 1 h 1 f p')
|
freqs = rearrange(self.freqs, 'h f p -> 1 h 1 f p')
|
||||||
positions = rearrange(pos.float(), 'b n p -> b 1 n 1 p')
|
positions = rearrange(pos.float(), 'b n p -> b 1 n 1 p')
|
||||||
@ -115,16 +113,25 @@ class GoldenGateRoPENd(Module):
|
|||||||
|
|
||||||
theta = reduce(freqs * positions, 'b h n f p -> b h n f', 'sum')
|
theta = reduce(freqs * positions, 'b h n f p -> b h n f', 'sum')
|
||||||
|
|
||||||
# apply rotations
|
return theta
|
||||||
|
|
||||||
cos_theta = torch.cos(theta)
|
def apply_rotations(
|
||||||
sin_theta = torch.sin(theta)
|
theta # (b h n f)
|
||||||
|
):
|
||||||
|
dtype = x
|
||||||
|
|
||||||
x_out = x * cos_theta - y * sin_theta
|
x, y = rearrange(x.float(), '... (split d) -> split ... d', split = 2) # (b, h, n, f)
|
||||||
y_out = x * sin_theta + y * cos_theta
|
|
||||||
|
|
||||||
out = cat((x_out, y_out), dim = -1)
|
# apply rotations
|
||||||
return out.type_as(dtype)
|
|
||||||
|
cos_theta = torch.cos(theta)
|
||||||
|
sin_theta = torch.sin(theta)
|
||||||
|
|
||||||
|
x_out = x * cos_theta - y * sin_theta
|
||||||
|
y_out = x * sin_theta + y * cos_theta
|
||||||
|
|
||||||
|
out = rearrange([x_out, y_out], 'split ... d -> ... (split d)')
|
||||||
|
return out.type_as(dtype)
|
||||||
|
|
||||||
# multi-head rmsnorm
|
# multi-head rmsnorm
|
||||||
|
|
||||||
@ -279,3 +286,25 @@ class SwiGLUFeedforward(Module):
|
|||||||
x = x * F.gelu(gates)
|
x = x * F.gelu(gates)
|
||||||
|
|
||||||
return self.proj_out(x)
|
return self.proj_out(x)
|
||||||
|
|
||||||
|
# video tokenizer
|
||||||
|
|
||||||
|
class VideoTokenizer(Module):
|
||||||
|
def __init__(
|
||||||
|
self
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
class DynamicsModel(Module):
|
||||||
|
def __init__(
|
||||||
|
self
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
class Dreamer(Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
video_tokenizer: VideoTokenizer,
|
||||||
|
dynamics_model: DynamicsModel
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|||||||
3
dreamer4/trainers.py
Normal file
3
dreamer4/trainers.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
from accelerate import Accelerator
|
||||||
@ -26,6 +26,7 @@ classifiers=[
|
|||||||
]
|
]
|
||||||
|
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"accelerate",
|
||||||
"einx>=0.3.0",
|
"einx>=0.3.0",
|
||||||
"einops>=0.8.1",
|
"einops>=0.8.1",
|
||||||
"torch>=2.4"
|
"torch>=2.4"
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user