allow applying rmsnorm to only the keys in attention
This commit is contained in:
parent
1345326656
commit
a7e0c395c3
@ -1078,7 +1078,9 @@ class Attention(Module):
|
||||
query_heads = None,
|
||||
heads = 8,
|
||||
pre_rmsnorm = True,
|
||||
gate_values = True
|
||||
gate_values = True,
|
||||
rmsnorm_query = False, # a paper claims it is better to norm only the keys - https://openreview.net/forum?id=HkztQWZfl2
|
||||
rmsnorm_key = True
|
||||
):
|
||||
super().__init__()
|
||||
self.norm = RMSNorm(dim) if pre_rmsnorm else Identity()
|
||||
@ -1114,8 +1116,8 @@ class Attention(Module):
|
||||
|
||||
# stability related
|
||||
|
||||
self.q_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = query_heads)
|
||||
self.k_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = heads)
|
||||
self.q_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = query_heads) if rmsnorm_query else nn.Identity()
|
||||
self.k_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = heads) if rmsnorm_key else nn.Identity()
|
||||
|
||||
def muon_parameters(self):
|
||||
return [
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "dreamer4"
|
||||
version = "0.0.47"
|
||||
version = "0.0.48"
|
||||
description = "Dreamer 4"
|
||||
authors = [
|
||||
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user