complete the symexp two hot proposed by Hafner from the previous versions of Dreamer, but will also bring in hl gauss

This commit is contained in:
lucidrains 2025-10-03 08:07:57 -07:00
parent 2a896ab01d
commit 046f8927d1
3 changed files with 63 additions and 4 deletions

View File

@ -86,19 +86,65 @@ class SymExpTwoHot(Module):
):
super().__init__()
values = torch.linspace(-20., 20., bins)
min_value, max_value = range
values = torch.linspace(min_value, max_value, bins)
values = values.sign() * (torch.exp(values.abs()) - 1.)
self.num_bins = bins
self.register_buffer('bin_values', values)
@property
def device(self):
return self.bin_values.device
def logits_to_scalar_value(
self,
logits # (... l)
):
raise NotImplementedError
return einsum(logits, self.bin_values, '... l, l -> ...')
def forward(self, x):
raise NotImplementedError
def forward(
self,
values
):
bin_values = self.bin_values
min_bin_value, max_bin_value = self.bin_values[0], self.bin_values[-1]
values, inverse_pack = pack_one(values, '*')
num_values = values.shape[0]
values = values.clamp(min = min_bin_value, max = max_bin_value)
indices = torch.searchsorted(self.bin_values, values)
# fetch the closest two indices (two-hot encoding)
left_indices = (indices - 1).clamp(min = 0)
right_indices = left_indices + 1
left_indices, right_indices = tuple(rearrange(t, '... -> ... 1') for t in (left_indices, right_indices))
# fetch the left and right values for the consecutive indices
left_values = self.bin_values[left_indices]
right_values = self.bin_values[right_indices]
# calculate the left and right values by the distance to the left and right
values = rearrange(values, '... -> ... 1')
total_distance = right_values - left_values
left_logit_value = (right_values - values) / total_distance
right_logit_value = 1. - left_logit_value
# set the left and right values (two-hot)
encoded = torch.zeros((num_values, self.num_bins), device = self.device)
encoded.scatter_(-1, left_indices, left_logit_value)
encoded.scatter_(-1, right_indices, right_logit_value)
return inverse_pack(encoded, '* l')
# golden gate rotary - Jerry Xiong, PhD student at UIUC
# https://jerryxio.ng/posts/nd-rope/

View File

@ -29,6 +29,7 @@ dependencies = [
"accelerate",
"einx>=0.3.0",
"einops>=0.8.1",
"hl-gauss-pytorch",
"torch>=2.4",
"x-mlps-pytorch"
]

View File

@ -24,3 +24,15 @@ def test_e2e(
flow_loss = dynamics(latents, signal_levels = signal_levels, step_sizes = step_sizes)
assert flow_loss.numel() == 1
def test_symexp_two_hot():
import torch
from dreamer4.dreamer4 import SymExpTwoHot
two_hot_encoder = SymExpTwoHot((-3., 3.), 20)
values = torch.randn((10))
encoded = two_hot_encoder(values)
recon_values = two_hot_encoder.logits_to_scalar_value(encoded)
assert torch.allclose(recon_values, values, atol = 1e-6)