diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index fbf7d29..b661a24 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -1078,7 +1078,9 @@ class Attention(Module): query_heads = None, heads = 8, pre_rmsnorm = True, - gate_values = True + gate_values = True, + rmsnorm_query = False, # a paper claims that it is better to just norm only the keys https://openreview.net/forum?id=HkztQWZfl2 + rmsnorm_key = True ): super().__init__() self.norm = RMSNorm(dim) if pre_rmsnorm else Identity() @@ -1114,8 +1116,8 @@ class Attention(Module): # stability related - self.q_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = query_heads) - self.k_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = heads) + self.q_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = query_heads) if rmsnorm_query else nn.Identity() + self.k_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = heads) if rmsnorm_key else nn.Identity() def muon_parameters(self): return [ diff --git a/pyproject.toml b/pyproject.toml index 5275757..4269618 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dreamer4" -version = "0.0.47" +version = "0.0.48" description = "Dreamer 4" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" }