# provenance: diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
# index d494f77..30029c2 100644, hunk @@ -74,6 +74,32 @@ (hunk context: def l2norm(t))

def softclamp(t, value = 50.):
    # scaled tanh - output is bounded to (-value, value)
    return (t / value).tanh() * value

# reinforcement learning related

# rewards

class SymExpTwoHot(Module):
    """
    Holds a fixed set of bin values for a discretized scalar (the surrounding
    comments place this under rewards).

    The bins are spaced uniformly across `range` and then passed through
    symexp, sign(x) * (exp(|x|) - 1), so they are dense near zero and grow
    exponentially towards both ends.

    NOTE(review): the name suggests two-hot encoding / decoding against these
    bins, but `forward` and `logits_to_scalar_value` are still stubs.

    Args:
        range: (min, max) bounds of the bins *before* symexp is applied.
        bins: number of bin values generated across `range`.
    """

    def __init__(
        self,
        range = (-20., 20.),
        bins = 255
    ):
        super().__init__()

        # honor the `range` argument - previously the bounds were hard-coded
        # to (-20., 20.) and a caller supplied range was silently ignored

        min_value, max_value = range

        values = torch.linspace(min_value, max_value, bins)
        values = values.sign() * (torch.exp(values.abs()) - 1.)

        # buffer (not a parameter) so it follows the module across devices
        # and is saved in the state dict, without receiving gradients

        self.register_buffer('bin_values', values)

    def logits_to_scalar_value(
        self,
        logits # (... l)
    ):
        """Not implemented yet - always raises NotImplementedError."""
        raise NotImplementedError

    def forward(self, x):
        """Not implemented yet - always raises NotImplementedError."""
        raise NotImplementedError

# golden gate rotary - Jerry Xiong, PhD student at UIUC
# https://jerryxio.ng/posts/nd-rope/