swish glu feedforward from shazeer et al

This commit is contained in:
lucidrains 2025-10-01 09:28:25 -07:00
parent 8ebb8a9661
commit e8678364ba
2 changed files with 43 additions and 4 deletions

View File

@ -5,7 +5,7 @@ from functools import partial
import torch
import torch.nn.functional as F
from torch.nn import Module, ModuleList, Linear, RMSNorm, Identity
from torch.nn import Module, ModuleList, Sequential, Linear, RMSNorm, Identity
from torch import cat, stack, tensor, Tensor, is_tensor
# ein related
@ -25,7 +25,7 @@ def exists(v):
def default(v, d):
return v if exists(v) else d
# classes
# attention
class Attention(Module):
def __init__(
@ -80,3 +80,28 @@ class Attention(Module):
return out
return out, stack((k, v))
# feedforward
class SwiGLUFeedforward(Module):
def __init__(
self,
dim,
expansion_factor = 4,
pre_rmsnorm = True
):
super().__init__()
self.norm = RMSNorm(dim) if pre_rmsnorm else Identity()
dim_inner = int(dim * expansion_factor * 2 / 3)
self.proj_in = Linear(dim, dim_inner * 2)
self.proj_out = Linear(dim_inner, dim)
def forward(self, x):
x = self.norm(x)
x, gates = self.proj_in(x).chunk(2, dim = -1)
x = x * F.gelu(gates)
return self.proj_out(x)

View File

@ -1,3 +1,17 @@
import pytest
import torch
def test_dreamer():
assert True
def test_attn():
from dreamer4.dreamer4 import Attention
x = torch.randn(1, 1024, 512)
attn = Attention(512)
assert attn(x).shape == x.shape
def test_ff():
from dreamer4.dreamer4 import SwiGLUFeedforward
x = torch.randn(1, 1024, 512)
ff = SwiGLUFeedforward(512)
assert ff(x).shape == x.shape