sketch out top down
This commit is contained in:
parent
882e63511b
commit
e3cbcd94c6
@ -0,0 +1,5 @@
|
||||
from dreamer4.dreamer4 import (
|
||||
VideoTokenizer,
|
||||
DynamicsModel,
|
||||
Dreamer
|
||||
)
|
||||
@ -8,6 +8,8 @@ import torch.nn.functional as F
|
||||
from torch.nn import Module, ModuleList, Parameter, Sequential, Linear, RMSNorm, Identity
|
||||
from torch import cat, stack, arange, tensor, Tensor, is_tensor
|
||||
|
||||
from accelerate import Accelerator
|
||||
|
||||
# ein related
|
||||
|
||||
# b - batch
|
||||
@ -101,12 +103,8 @@ class GoldenGateRoPENd(Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x, # (b h n d)
|
||||
pos # (b n p)
|
||||
):
|
||||
dtype = x
|
||||
|
||||
x, y = x.float().chunk(2, dim = -1) # (b, h, n, f)
|
||||
|
||||
freqs = rearrange(self.freqs, 'h f p -> 1 h 1 f p')
|
||||
positions = rearrange(pos.float(), 'b n p -> b 1 n 1 p')
|
||||
@ -115,16 +113,25 @@ class GoldenGateRoPENd(Module):
|
||||
|
||||
theta = reduce(freqs * positions, 'b h n f p -> b h n f', 'sum')
|
||||
|
||||
# apply rotations
|
||||
return theta
|
||||
|
||||
cos_theta = torch.cos(theta)
|
||||
sin_theta = torch.sin(theta)
|
||||
def apply_rotations(
    theta, # (b h n f)
    x      # (b h n d)
):
    """Rotate the two halves of the last dimension of `x` by the angles `theta`.

    The last dimension of `x` is split into a first half and a second half,
    which are treated as the two coordinates of a 2d point and rotated by
    `theta`. The math is done in float32 and the result is cast back to the
    type of the incoming tensor.

    Args:
        theta: rotation angles, annotated as (b h n f) in the original sketch.
        x: tensor to rotate, annotated as (b h n d) on the previous `forward`
            signature - presumably d == 2 * f, TODO confirm against the caller.

    Returns:
        Tensor with the same shape and type as `x`.
    """

    # NOTE(review): the sketched signature only took `theta`, yet the body read
    # `x` before assigning it, so any call raised UnboundLocalError. The tensor
    # is now an explicit second argument.

    # keep a handle on the incoming tensor so the output can be cast back to its type

    dtype = x

    x, y = rearrange(x.float(), '... (split d) -> split ... d', split = 2) # (b, h, n, f)

    # apply rotations

    cos_theta = torch.cos(theta)
    sin_theta = torch.sin(theta)

    x_out = x * cos_theta - y * sin_theta
    y_out = x * sin_theta + y * cos_theta

    # inverse of the split above - first half followed by second half

    out = rearrange([x_out, y_out], 'split ... d -> ... (split d)')
    return out.type_as(dtype)
|
||||
|
||||
# multi-head rmsnorm
|
||||
|
||||
@ -279,3 +286,25 @@ class SwiGLUFeedforward(Module):
|
||||
x = x * F.gelu(gates)
|
||||
|
||||
return self.proj_out(x)
|
||||
|
||||
# video tokenizer
|
||||
|
||||
class VideoTokenizer(Module):
    """Placeholder module for the video tokenizer.

    Only the constructor is sketched here: it initializes the base `Module`
    and defines no layers, parameters or forward pass yet.
    """

    def __init__(
        self
    ):
        super().__init__()
|
||||
|
||||
class DynamicsModel(Module):
    """Placeholder module for the dynamics model.

    Only the constructor is sketched here: it initializes the base `Module`
    and defines no layers, parameters or forward pass yet.
    """

    def __init__(
        self
    ):
        super().__init__()
|
||||
|
||||
class Dreamer(Module):
    """Top-level module that groups a video tokenizer and a dynamics model.

    Args:
        video_tokenizer: the `VideoTokenizer` instance to wrap.
        dynamics_model: the `DynamicsModel` instance to wrap.
    """

    def __init__(
        self,
        video_tokenizer: VideoTokenizer,
        dynamics_model: DynamicsModel
    ):
        super().__init__()

        # assigning a Module as an attribute registers it as a submodule, so both
        # components are included in parameters(), state_dict() and device moves
        # - the sketch previously accepted them and silently dropped them

        self.video_tokenizer = video_tokenizer
        self.dynamics_model = dynamics_model
|
||||
|
||||
3
dreamer4/trainers.py
Normal file
3
dreamer4/trainers.py
Normal file
@ -0,0 +1,3 @@
|
||||
import torch
|
||||
|
||||
from accelerate import Accelerator
|
||||
@ -26,6 +26,7 @@ classifiers=[
|
||||
]
|
||||
|
||||
dependencies = [
|
||||
"accelerate",
|
||||
"einx>=0.3.0",
|
||||
"einops>=0.8.1",
|
||||
"torch>=2.4"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user