swish glu feedforward from shazeer et al
This commit is contained in:
parent
8ebb8a9661
commit
e8678364ba
@ -5,7 +5,7 @@ from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Module, ModuleList, Linear, RMSNorm, Identity
|
||||
from torch.nn import Module, ModuleList, Sequential, Linear, RMSNorm, Identity
|
||||
from torch import cat, stack, tensor, Tensor, is_tensor
|
||||
|
||||
# ein related
|
||||
@ -25,7 +25,7 @@ def exists(v):
|
||||
def default(v, d):
|
||||
return v if exists(v) else d
|
||||
|
||||
# classes
|
||||
# attention
|
||||
|
||||
class Attention(Module):
|
||||
def __init__(
|
||||
@ -80,3 +80,28 @@ class Attention(Module):
|
||||
return out
|
||||
|
||||
return out, stack((k, v))
|
||||
|
||||
# feedforward
|
||||
|
||||
class SwiGLUFeedforward(Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
expansion_factor = 4,
|
||||
pre_rmsnorm = True
|
||||
):
|
||||
super().__init__()
|
||||
self.norm = RMSNorm(dim) if pre_rmsnorm else Identity()
|
||||
|
||||
dim_inner = int(dim * expansion_factor * 2 / 3)
|
||||
|
||||
self.proj_in = Linear(dim, dim_inner * 2)
|
||||
self.proj_out = Linear(dim_inner, dim)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
|
||||
x, gates = self.proj_in(x).chunk(2, dim = -1)
|
||||
x = x * F.gelu(gates)
|
||||
|
||||
return self.proj_out(x)
|
||||
|
||||
@ -1,3 +1,17 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
def test_dreamer():
|
||||
assert True
|
||||
def test_attn():
|
||||
from dreamer4.dreamer4 import Attention
|
||||
|
||||
x = torch.randn(1, 1024, 512)
|
||||
attn = Attention(512)
|
||||
|
||||
assert attn(x).shape == x.shape
|
||||
|
||||
def test_ff():
|
||||
from dreamer4.dreamer4 import SwiGLUFeedforward
|
||||
x = torch.randn(1, 1024, 512)
|
||||
ff = SwiGLUFeedforward(512)
|
||||
|
||||
assert ff(x).shape == x.shape
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user