diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index a669d0d..b10ec76 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -5,11 +5,12 @@ from functools import partial import torch import torch.nn.functional as F -from torch.nn import Module, ModuleList, Sequential, Linear, RMSNorm, Identity +from torch.nn import Module, ModuleList, Parameter, Sequential, Linear, RMSNorm, Identity from torch import cat, stack, tensor, Tensor, is_tensor # ein related +import einx from einops import einsum, rearrange, repeat, reduce from einops.layers.torch import Rearrange @@ -25,6 +26,32 @@ def exists(v): def default(v, d): return v if exists(v) else d +def l2norm(t): + return F.normalize(t, dim = -1, p = 2) + +def softclamp(t, value = 50.): + return (t / value).tanh() * value + +# multi-head rmsnorm + +class MultiHeadRMSNorm(Module): + def __init__( + self, + dim_head, + heads = 8 + ): + super().__init__() + self.scale = dim_head ** 0.5 + self.gamma = Parameter(torch.zeros(heads, dim_head)) # weight decay friendly + + def forward( + self, + x + ): + normed = l2norm(x) + scale = (self.gamma + 1.) * self.scale + return einx.multiply('... h n d, h d', normed, scale) + # attention class Attention(Module): @@ -33,6 +60,7 @@ class Attention(Module): dim, dim_head = 64, heads = 8, + softclamp_value = 50., pre_rmsnorm = True ): super().__init__() @@ -47,6 +75,13 @@ class Attention(Module): self.to_kv = LinearNoBias(dim, dim_inner * 2) self.to_out = LinearNoBias(dim_inner, dim) + # stability related + + self.q_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = heads) + self.k_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = heads) + + self.softclamp_value = softclamp_value + def forward( self, tokens, @@ -57,23 +92,47 @@ class Attention(Module): q, k, v = (self.to_q(tokens), *self.to_kv(tokens).chunk(2, dim = -1)) + # split heads + q, k, v = map(self.split_heads, (q, k, v)) + # qk rmsnorm + + q = self.q_heads_rmsnorm(q) + k = self.k_heads_rmsnorm(k) + + # caching + if exists(kv_cache): ck, cv = kv_cache k = cat((ck, k), dim = -2) v = cat((cv, v), dim = -2) - q = q * self.scale + # similarity sim = einsum(q, k, 'b h i d, b h j d -> b h i j') + # softclamping a la gemma 3 + + if exists(self.softclamp_value): + sim = softclamp(sim, self.softclamp_value) + + # scale and attention + + sim = sim * self.scale + attn = sim.softmax(dim = -1) + # aggregate + out = einsum(attn, v, 'b h i j, b h j d -> b h i d') + # merge heads + out = self.merge_heads(out) + # combine heads + out = self.to_out(out) if not return_kv_cache: diff --git a/pyproject.toml b/pyproject.toml index a82c3cd..989f95b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ classifiers=[ ] dependencies = [ + "einx>=0.3.0", "einops>=0.8.1", "torch>=2.4" ]