one more incision before knocking out reward decoding
This commit is contained in:
parent
c056835aea
commit
f7bdaddbbb
@ -215,8 +215,10 @@ class SymExpTwoHot(Module):
|
|||||||
|
|
||||||
def bins_to_scalar_value(
|
def bins_to_scalar_value(
|
||||||
self,
|
self,
|
||||||
two_hot_encoding # (... l)
|
logits, # (... l)
|
||||||
|
normalize = False
|
||||||
):
|
):
|
||||||
|
two_hot_encoding = logits.softmax(dim = -1) if normalize else logits
|
||||||
return einsum(two_hot_encoding, self.bin_values, '... l, l -> ...')
|
return einsum(two_hot_encoding, self.bin_values, '... l, l -> ...')
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user