complete the symexp two hot proposed by Hafner from the previous versions of Dreamer, but will also bring in hl gauss
This commit is contained in:
parent
2a896ab01d
commit
046f8927d1
@ -86,19 +86,65 @@ class SymExpTwoHot(Module):
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
values = torch.linspace(-20., 20., bins)
|
||||
min_value, max_value = range
|
||||
values = torch.linspace(min_value, max_value, bins)
|
||||
values = values.sign() * (torch.exp(values.abs()) - 1.)
|
||||
|
||||
self.num_bins = bins
|
||||
self.register_buffer('bin_values', values)
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return self.bin_values.device
|
||||
|
||||
def logits_to_scalar_value(
|
||||
self,
|
||||
logits # (... l)
|
||||
):
|
||||
raise NotImplementedError
|
||||
return einsum(logits, self.bin_values, '... l, l -> ...')
|
||||
|
||||
def forward(self, x):
|
||||
raise NotImplementedError
|
||||
def forward(
|
||||
self,
|
||||
values
|
||||
):
|
||||
bin_values = self.bin_values
|
||||
min_bin_value, max_bin_value = self.bin_values[0], self.bin_values[-1]
|
||||
|
||||
values, inverse_pack = pack_one(values, '*')
|
||||
num_values = values.shape[0]
|
||||
|
||||
values = values.clamp(min = min_bin_value, max = max_bin_value)
|
||||
|
||||
indices = torch.searchsorted(self.bin_values, values)
|
||||
|
||||
# fetch the closest two indices (two-hot encoding)
|
||||
|
||||
left_indices = (indices - 1).clamp(min = 0)
|
||||
right_indices = left_indices + 1
|
||||
|
||||
left_indices, right_indices = tuple(rearrange(t, '... -> ... 1') for t in (left_indices, right_indices))
|
||||
|
||||
# fetch the left and right values for the consecutive indices
|
||||
|
||||
left_values = self.bin_values[left_indices]
|
||||
right_values = self.bin_values[right_indices]
|
||||
|
||||
# calculate the left and right values by the distance to the left and right
|
||||
|
||||
values = rearrange(values, '... -> ... 1')
|
||||
total_distance = right_values - left_values
|
||||
|
||||
left_logit_value = (right_values - values) / total_distance
|
||||
right_logit_value = 1. - left_logit_value
|
||||
|
||||
# set the left and right values (two-hot)
|
||||
|
||||
encoded = torch.zeros((num_values, self.num_bins), device = self.device)
|
||||
|
||||
encoded.scatter_(-1, left_indices, left_logit_value)
|
||||
encoded.scatter_(-1, right_indices, right_logit_value)
|
||||
|
||||
return inverse_pack(encoded, '* l')
|
||||
|
||||
# golden gate rotary - Jerry Xiong, PhD student at UIUC
|
||||
# https://jerryxio.ng/posts/nd-rope/
|
||||
|
||||
@ -29,6 +29,7 @@ dependencies = [
|
||||
"accelerate",
|
||||
"einx>=0.3.0",
|
||||
"einops>=0.8.1",
|
||||
"hl-gauss-pytorch",
|
||||
"torch>=2.4",
|
||||
"x-mlps-pytorch"
|
||||
]
|
||||
|
||||
@ -24,3 +24,15 @@ def test_e2e(
|
||||
|
||||
flow_loss = dynamics(latents, signal_levels = signal_levels, step_sizes = step_sizes)
|
||||
assert flow_loss.numel() == 1
|
||||
|
||||
def test_symexp_two_hot():
|
||||
import torch
|
||||
from dreamer4.dreamer4 import SymExpTwoHot
|
||||
|
||||
two_hot_encoder = SymExpTwoHot((-3., 3.), 20)
|
||||
values = torch.randn((10))
|
||||
|
||||
encoded = two_hot_encoder(values)
|
||||
recon_values = two_hot_encoder.logits_to_scalar_value(encoded)
|
||||
|
||||
assert torch.allclose(recon_values, values, atol = 1e-6)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user