From f7bdaddbbb79015e8d4352a37520d5fffcfafc0e Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Wed, 8 Oct 2025 06:11:02 -0700
Subject: [PATCH] one more incision before knocking out reward decoding

---
 dreamer4/dreamer4.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index 2f1c20f..c174ad2 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -215,8 +215,10 @@ class SymExpTwoHot(Module):
 
     def bins_to_scalar_value(
         self,
-        two_hot_encoding # (... l)
+        logits, # (... l)
+        normalize = False
     ):
+        two_hot_encoding = logits.softmax(dim = -1) if normalize else logits
         return einsum(two_hot_encoding, self.bin_values, '... l, l -> ...')
 
     def forward(