From 046f8927d128a869013e241272499becdbf92fbd Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Fri, 3 Oct 2025 08:07:57 -0700
Subject: [PATCH] complete the symexp two hot proposed by Hafner from the
 previous versions of Dreamer, but will also bring in hl gauss

---
 dreamer4/dreamer4.py  | 65 +++++++++++++++++++++++++++++++++++++++----
 pyproject.toml        |  1 +
 tests/test_dreamer.py | 12 ++++++++++
 3 files changed, 74 insertions(+), 4 deletions(-)

diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index 30029c2..17bc894 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -86,19 +86,76 @@ class SymExpTwoHot(Module):
     ):
         super().__init__()
 
-        values = torch.linspace(-20., 20., bins)
+        # bin values are linearly spaced over `range`, then mapped through symexp: sign(x) * (exp(|x|) - 1)
+
+        min_value, max_value = range
+        values = torch.linspace(min_value, max_value, bins)
         values = values.sign() * (torch.exp(values.abs()) - 1.)
 
+        self.num_bins = bins
         self.register_buffer('bin_values', values)
 
+    @property
+    def device(self):
+        return self.bin_values.device
+
     def logits_to_scalar_value(
         self,
         logits # (... l)
     ):
-        raise NotImplementedError
+        # contracts the last (bin) dimension of the input against the bin values
+        # NOTE: no softmax is applied here, the input is used directly as the per-bin weights
+
+        return einsum(logits, self.bin_values, '... l, l -> ...')
 
-    def forward(self, x):
-        raise NotImplementedError
+    def forward(
+        self,
+        values
+    ):
+        # encodes each value as weights on its two neighboring bins, one row of `num_bins` weights per value
+
+        bin_values = self.bin_values
+        min_bin_value, max_bin_value = bin_values[0], bin_values[-1]
+
+        values, inverse_pack = pack_one(values, '*')
+        num_values = values.shape[0]
+
+        # values outside the bin range saturate at the first / last bin value
+
+        values = values.clamp(min = min_bin_value, max = max_bin_value)
+
+        indices = torch.searchsorted(bin_values, values)
+
+        # fetch the closest two indices (two-hot encoding)
+        # clamping at 0 keeps the pair in bounds when the value sits on the first bin
+
+        left_indices = (indices - 1).clamp(min = 0)
+        right_indices = left_indices + 1
+
+        left_indices, right_indices = tuple(rearrange(t, '... -> ... 1') for t in (left_indices, right_indices))
+
+        # fetch the left and right values for the consecutive indices
+
+        left_values = bin_values[left_indices]
+        right_values = bin_values[right_indices]
+
+        # weigh the two bins by linear interpolation, so that their weighted sum recovers the value
+
+        values = rearrange(values, '... -> ... 1')
+        total_distance = right_values - left_values
+
+        left_logit_value = (right_values - values) / total_distance
+        right_logit_value = 1. - left_logit_value
+
+        # set the left and right values (two-hot)
+        # scatter_ requires the target and source dtypes to match, so allocate with the dtype of the weights
+
+        encoded = torch.zeros((num_values, self.num_bins), device = self.device, dtype = left_logit_value.dtype)
+
+        encoded.scatter_(-1, left_indices, left_logit_value)
+        encoded.scatter_(-1, right_indices, right_logit_value)
+
+        return inverse_pack(encoded, '* l')
 
 # golden gate rotary - Jerry Xiong, PhD student at UIUC
 # https://jerryxio.ng/posts/nd-rope/
diff --git a/pyproject.toml b/pyproject.toml
index d914450..5b56333 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,6 +29,7 @@ dependencies = [
     "accelerate",
     "einx>=0.3.0",
     "einops>=0.8.1",
+    "hl-gauss-pytorch",
     "torch>=2.4",
     "x-mlps-pytorch"
 ]
diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py
index d64ac69..3a7c2a8 100644
--- a/tests/test_dreamer.py
+++ b/tests/test_dreamer.py
@@ -24,3 +24,15 @@ def test_e2e(
     flow_loss = dynamics(latents, signal_levels = signal_levels, step_sizes = step_sizes)
 
     assert flow_loss.numel() == 1
+
+def test_symexp_two_hot():
+    import torch
+    from dreamer4.dreamer4 import SymExpTwoHot
+
+    two_hot_encoder = SymExpTwoHot((-3., 3.), 20)
+    values = torch.randn((10))
+
+    encoded = two_hot_encoder(values)
+    recon_values = two_hot_encoder.logits_to_scalar_value(encoded)
+
+    assert torch.allclose(recon_values, values, atol = 1e-6)