one more incision before knocking out reward decoding

lucidrains 2025-10-08 06:11:02 -07:00
parent c056835aea
commit f7bdaddbbb

@@ -215,8 +215,10 @@ class SymExpTwoHot(Module):
 
     def bins_to_scalar_value(
         self,
-        two_hot_encoding # (... l)
+        logits, # (... l)
+        normalize = False
     ):
+        two_hot_encoding = logits.softmax(dim = -1) if normalize else logits
         return einsum(two_hot_encoding, self.bin_values, '... l, l -> ...')
 
     def forward(
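
For context, a minimal, self-contained sketch of the decode path after this change, assuming evenly spaced symexp-transformed bins; the bin count, value range, and free-function form below are illustrative assumptions, not taken from this diff.

import torch
from einops import einsum

def symexp(x):
    # inverse of symlog: sign(x) * (exp(|x|) - 1)
    return torch.sign(x) * (torch.exp(torch.abs(x)) - 1.)

num_bins = 255                                             # assumed bin count
bin_values = symexp(torch.linspace(-20., 20., num_bins))   # assumed bin layout

def bins_to_scalar_value(logits, normalize = False):
    # with normalize = True, raw logits are softmaxed into a distribution
    # over the bins before taking the expectation against the bin values
    two_hot_encoding = logits.softmax(dim = -1) if normalize else logits
    return einsum(two_hot_encoding, bin_values, '... l, l -> ...')

logits = torch.randn(2, num_bins)                          # raw network output over bins
value = bins_to_scalar_value(logits, normalize = True)     # softmax handled internally
print(value.shape)                                         # torch.Size([2])

With normalize = True, callers can hand raw logits straight to the decoder instead of softmaxing them first.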