diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index 972833b..a669d0d 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -5,7 +5,7 @@ from functools import partial import torch import torch.nn.functional as F -from torch.nn import Module, ModuleList, Linear, RMSNorm, Identity +from torch.nn import Module, ModuleList, Sequential, Linear, RMSNorm, Identity from torch import cat, stack, tensor, Tensor, is_tensor # ein related @@ -25,7 +25,7 @@ def exists(v): def default(v, d): return v if exists(v) else d -# classes +# attention class Attention(Module): def __init__( @@ -80,3 +80,28 @@ class Attention(Module): return out return out, stack((k, v)) + +# feedforward + +class SwiGLUFeedforward(Module): + def __init__( + self, + dim, + expansion_factor = 4, + pre_rmsnorm = True + ): + super().__init__() + self.norm = RMSNorm(dim) if pre_rmsnorm else Identity() + + dim_inner = int(dim * expansion_factor * 2 / 3) + + self.proj_in = Linear(dim, dim_inner * 2) + self.proj_out = Linear(dim_inner, dim) + + def forward(self, x): + x = self.norm(x) + + x, gates = self.proj_in(x).chunk(2, dim = -1) + x = x * F.gelu(gates) + + return self.proj_out(x) diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py index f365e9c..701f452 100644 --- a/tests/test_dreamer.py +++ b/tests/test_dreamer.py @@ -1,3 +1,17 @@ +import pytest +import torch -def test_dreamer(): - assert True +def test_attn(): + from dreamer4.dreamer4 import Attention + + x = torch.randn(1, 1024, 512) + attn = Attention(512) + + assert attn(x).shape == x.shape + +def test_ff(): + from dreamer4.dreamer4 import SwiGLUFeedforward + x = torch.randn(1, 1024, 512) + ff = SwiGLUFeedforward(512) + + assert ff(x).shape == x.shape