one more incision before knocking out reward decoding
This commit is contained in:
parent
c056835aea
commit
f7bdaddbbb
@ -215,8 +215,10 @@ class SymExpTwoHot(Module):
|
||||
|
||||
def bins_to_scalar_value(
|
||||
self,
|
||||
two_hot_encoding # (... l)
|
||||
logits, # (... l)
|
||||
normalize = False
|
||||
):
|
||||
two_hot_encoding = logits.softmax(dim = -1) if normalize else logits
|
||||
return einsum(two_hot_encoding, self.bin_values, '... l, l -> ...')
|
||||
|
||||
def forward(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user