dreamer4/dreamer4/dreamer4.py
2025-10-01 09:28:25 -07:00

108 lines
2.4 KiB
Python

from __future__ import annotations
import math
from functools import partial
import torch
import torch.nn.functional as F
from torch.nn import Module, ModuleList, Sequential, Linear, RMSNorm, Identity
from torch import cat, stack, tensor, Tensor, is_tensor
# ein related
from einops import einsum, rearrange, repeat, reduce
from einops.layers.torch import Rearrange
# constants
# Factory for bias-free linear layers: LinearNoBias(in_features, out_features)
# builds torch.nn.Linear(in_features, out_features, bias = False).
# Kept as a functools.partial (not a def) so a caller can still override
# `bias` by keyword, and so the result is a plain nn.Linear instance.
LinearNoBias = partial(Linear, bias = False)
# helpers
def exists(v):
    """Return True when `v` is anything other than None (falsy values count)."""
    return v is not None

def default(v, d):
    """Return `v` if it exists (is not None), otherwise the fallback `d`."""
    if not exists(v):
        return d
    return v
# attention
class Attention(Module):
    """
    Multi-head self-attention with an optional key/value cache.

    Queries, keys and values are all projected from the same (optionally
    RMS-normalized) tokens. No mask is applied, so every query attends to
    every key, including any keys carried in through `kv_cache`.

    Args:
        dim: feature size of the input tokens and of the output.
        dim_head: feature size of each attention head.
        heads: number of attention heads.
        pre_rmsnorm: if True, RMS-normalize the tokens before projecting;
            otherwise they are used unchanged.
    """

    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8,
        pre_rmsnorm = True
    ):
        super().__init__()

        if pre_rmsnorm:
            self.norm = RMSNorm(dim)
        else:
            self.norm = Identity()

        # queries are multiplied by 1 / sqrt(dim_head) before the dot product
        self.scale = dim_head ** -0.5

        self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
        self.merge_heads = Rearrange('b h n d -> b n (h d)')

        inner_dim = heads * dim_head

        self.to_q = LinearNoBias(dim, inner_dim)
        self.to_kv = LinearNoBias(dim, inner_dim * 2)   # keys and values share one projection
        self.to_out = LinearNoBias(inner_dim, dim)

    def forward(
        self,
        tokens,
        kv_cache = None,
        return_kv_cache = False
    ):
        """
        Attend over `tokens` (plus any cached keys/values).

        Args:
            tokens: (batch, seq, dim) tensor; the 'b n (h d)' head split
                requires a 3-d input.
            kv_cache: optional pair of (keys, values) to prepend along the
                sequence axis, each shaped (batch, heads, cached_len, dim_head);
                the stacked tensor returned by a previous call with
                return_kv_cache = True unpacks into exactly this pair.
            return_kv_cache: if True, also return the keys/values used here.

        Returns:
            The attended output of shape (batch, seq, dim), or, when
            return_kv_cache is True, (output, stack((keys, values))) where the
            cache has shape (2, batch, heads, cached_len + seq, dim_head).
        """
        normed = self.norm(tokens)

        queries = self.to_q(normed)
        keys, values = self.to_kv(normed).chunk(2, dim = -1)

        queries, keys, values = (self.split_heads(t) for t in (queries, keys, values))

        # prepend previously computed keys / values along the sequence axis
        if kv_cache is not None:
            cached_keys, cached_values = kv_cache
            keys = cat((cached_keys, keys), dim = -2)
            values = cat((cached_values, values), dim = -2)

        scores = einsum(queries * self.scale, keys, 'b h i d, b h j d -> b h i j')
        weights = scores.softmax(dim = -1)
        attended = einsum(weights, values, 'b h i j, b h j d -> b h i d')

        out = self.to_out(self.merge_heads(attended))

        if return_kv_cache:
            return out, stack((keys, values))

        return out
# feedforward
class SwiGLUFeedforward(Module):
def __init__(
self,
dim,
expansion_factor = 4,
pre_rmsnorm = True
):
super().__init__()
self.norm = RMSNorm(dim) if pre_rmsnorm else Identity()
dim_inner = int(dim * expansion_factor * 2 / 3)
self.proj_in = Linear(dim, dim_inner * 2)
self.proj_out = Linear(dim_inner, dim)
def forward(self, x):
x = self.norm(x)
x, gates = self.proj_in(x).chunk(2, dim = -1)
x = x * F.gelu(gates)
return self.proj_out(x)