swish glu feedforward from shazeer et al
This commit is contained in:
parent
8ebb8a9661
commit
e8678364ba
@ -5,7 +5,7 @@ from functools import partial
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch.nn import Module, ModuleList, Linear, RMSNorm, Identity
|
from torch.nn import Module, ModuleList, Sequential, Linear, RMSNorm, Identity
|
||||||
from torch import cat, stack, tensor, Tensor, is_tensor
|
from torch import cat, stack, tensor, Tensor, is_tensor
|
||||||
|
|
||||||
# ein related
|
# ein related
|
||||||
@ -25,7 +25,7 @@ def exists(v):
|
|||||||
def default(v, d):
    """Return `v` when `exists(v)` holds, otherwise fall back to `d`."""
    if exists(v):
        return v

    return d
|
||||||
|
|
||||||
# classes
|
# attention
|
||||||
|
|
||||||
class Attention(Module):
|
class Attention(Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -80,3 +80,28 @@ class Attention(Module):
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
return out, stack((k, v))
|
return out, stack((k, v))
|
||||||
|
|
||||||
|
# feedforward
|
||||||
|
|
||||||
|
class SwiGLUFeedforward(Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
expansion_factor = 4,
|
||||||
|
pre_rmsnorm = True
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.norm = RMSNorm(dim) if pre_rmsnorm else Identity()
|
||||||
|
|
||||||
|
dim_inner = int(dim * expansion_factor * 2 / 3)
|
||||||
|
|
||||||
|
self.proj_in = Linear(dim, dim_inner * 2)
|
||||||
|
self.proj_out = Linear(dim_inner, dim)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.norm(x)
|
||||||
|
|
||||||
|
x, gates = self.proj_in(x).chunk(2, dim = -1)
|
||||||
|
x = x * F.gelu(gates)
|
||||||
|
|
||||||
|
return self.proj_out(x)
|
||||||
|
|||||||
@ -1,3 +1,17 @@
|
|||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
def test_attn():
    """Check that Attention(512) keeps the shape of a (1, 1024, 512) input."""
    from dreamer4.dreamer4 import Attention

    x = torch.randn(1, 1024, 512)
    attn = Attention(512)

    out = attn(x)
    assert out.shape == x.shape
|
||||||
|
|
||||||
|
def test_ff():
    """Check that SwiGLUFeedforward(512) keeps the shape of a (1, 1024, 512) input."""
    from dreamer4.dreamer4 import SwiGLUFeedforward

    x = torch.randn(1, 1024, 512)
    ff = SwiGLUFeedforward(512)

    out = ff(x)
    assert out.shape == x.shape
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user