last commit for the day
This commit is contained in:
parent
8d1cd311bb
commit
2a896ab01d
@ -74,6 +74,32 @@ def l2norm(t):
|
|||||||
def softclamp(t, value = 50.):
    """Smoothly squash `t` into the open interval (-value, value).

    Divides by `value`, applies tanh, then scales back up by `value`.
    `t` must support division and expose a `.tanh()` method.
    """
    scaled = t / value
    return scaled.tanh() * value
|
||||||
|
|
||||||
|
# reinforcement learning related
|
||||||
|
|
||||||
|
# rewards
|
||||||
|
|
||||||
|
class SymExpTwoHot(Module):
    """Holds symmetric-exponentially spaced bin values for scalar targets.

    The bins are laid out uniformly over `range` and then mapped through
    symexp(x) = sign(x) * (exp(|x|) - 1), so they are dense near zero and
    sparse at large magnitudes. The result is stored in the non-trainable
    buffer `bin_values`.

    NOTE(review): the class name suggests a two-hot encoding over these
    bins; the encode / decode methods are not implemented yet.
    """

    def __init__(
        self,
        range = (-20., 20.),
        bins = 255
    ):
        """
        Args:
            range: `(min, max)` bounds of the bins before the symexp mapping.
                   (The name shadows the builtin but is kept so existing
                   keyword callers keep working.)
            bins: number of bins laid out uniformly between the bounds.
        """
        super().__init__()

        # honor the requested bounds - previously hard-coded to (-20., 20.),
        # which silently ignored the `range` argument
        min_value, max_value = range

        values = torch.linspace(min_value, max_value, bins)
        values = values.sign() * (torch.exp(values.abs()) - 1.)

        self.register_buffer('bin_values', values)

    def logits_to_scalar_value(
        self,
        logits # (... l)
    ):
        """Not implemented yet - always raises NotImplementedError."""
        raise NotImplementedError

    def forward(self, x):
        """Not implemented yet - always raises NotImplementedError."""
        raise NotImplementedError
|
||||||
|
|
||||||
# golden gate rotary - Jerry Xiong, PhD student at UIUC
|
# golden gate rotary - Jerry Xiong, PhD student at UIUC
|
||||||
# https://jerryxio.ng/posts/nd-rope/
|
# https://jerryxio.ng/posts/nd-rope/
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user