Allow applying RMSNorm to only the keys in attention

This commit is contained in:
lucidrains 2025-10-20 11:20:49 -07:00
parent 1345326656
commit a7e0c395c3
2 changed files with 6 additions and 4 deletions

View File

@ -1078,7 +1078,9 @@ class Attention(Module):
query_heads = None,
heads = 8,
pre_rmsnorm = True,
gate_values = True
gate_values = True,
rmsnorm_query = False, # a paper claims that it is better to just norm only the keys https://openreview.net/forum?id=HkztQWZfl2
rmsnorm_key = True
):
super().__init__()
self.norm = RMSNorm(dim) if pre_rmsnorm else Identity()
@ -1114,8 +1116,8 @@ class Attention(Module):
# stability related
self.q_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = query_heads)
self.k_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = heads)
self.q_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = query_heads) if rmsnorm_query else nn.Identity()
self.k_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = heads) if rmsnorm_key else nn.Identity()
def muon_parameters(self):
return [

View File

@ -1,6 +1,6 @@
[project]
name = "dreamer4"
version = "0.0.47"
version = "0.0.48"
description = "Dreamer 4"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }