allow applying rmsnorm to only the keys (and not the queries) in attention

This commit is contained in:
lucidrains 2025-10-20 11:20:49 -07:00
parent 1345326656
commit a7e0c395c3
2 changed files with 6 additions and 4 deletions

View File

@ -1078,7 +1078,9 @@ class Attention(Module):
query_heads = None, query_heads = None,
heads = 8, heads = 8,
pre_rmsnorm = True, pre_rmsnorm = True,
gate_values = True gate_values = True,
rmsnorm_query = False, # a paper claims that it is better to just norm only the keys https://openreview.net/forum?id=HkztQWZfl2
rmsnorm_key = True
): ):
super().__init__() super().__init__()
self.norm = RMSNorm(dim) if pre_rmsnorm else Identity() self.norm = RMSNorm(dim) if pre_rmsnorm else Identity()
@ -1114,8 +1116,8 @@ class Attention(Module):
# stability related # stability related
self.q_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = query_heads) self.q_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = query_heads) if rmsnorm_query else nn.Identity()
self.k_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = heads) self.k_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = heads) if rmsnorm_key else nn.Identity()
def muon_parameters(self): def muon_parameters(self):
return [ return [

View File

@ -1,6 +1,6 @@
[project] [project]
name = "dreamer4" name = "dreamer4"
version = "0.0.47" version = "0.0.48"
description = "Dreamer 4" description = "Dreamer 4"
authors = [ authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" } { name = "Phil Wang", email = "lucidrains@gmail.com" }