complete the symexp two hot proposed by Hafner from the previous versions of Dreamer, but will also bring in hl gauss
This commit is contained in:
parent
2a896ab01d
commit
046f8927d1
@ -86,19 +86,65 @@ class SymExpTwoHot(Module):
|
|||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
values = torch.linspace(-20., 20., bins)
|
min_value, max_value = range
|
||||||
|
values = torch.linspace(min_value, max_value, bins)
|
||||||
values = values.sign() * (torch.exp(values.abs()) - 1.)
|
values = values.sign() * (torch.exp(values.abs()) - 1.)
|
||||||
|
|
||||||
|
self.num_bins = bins
|
||||||
self.register_buffer('bin_values', values)
|
self.register_buffer('bin_values', values)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self):
|
||||||
|
return self.bin_values.device
|
||||||
|
|
||||||
def logits_to_scalar_value(
|
def logits_to_scalar_value(
|
||||||
self,
|
self,
|
||||||
logits # (... l)
|
logits # (... l)
|
||||||
):
|
):
|
||||||
raise NotImplementedError
|
return einsum(logits, self.bin_values, '... l, l -> ...')
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(
|
||||||
raise NotImplementedError
|
self,
|
||||||
|
values
|
||||||
|
):
|
||||||
|
bin_values = self.bin_values
|
||||||
|
min_bin_value, max_bin_value = self.bin_values[0], self.bin_values[-1]
|
||||||
|
|
||||||
|
values, inverse_pack = pack_one(values, '*')
|
||||||
|
num_values = values.shape[0]
|
||||||
|
|
||||||
|
values = values.clamp(min = min_bin_value, max = max_bin_value)
|
||||||
|
|
||||||
|
indices = torch.searchsorted(self.bin_values, values)
|
||||||
|
|
||||||
|
# fetch the closest two indices (two-hot encoding)
|
||||||
|
|
||||||
|
left_indices = (indices - 1).clamp(min = 0)
|
||||||
|
right_indices = left_indices + 1
|
||||||
|
|
||||||
|
left_indices, right_indices = tuple(rearrange(t, '... -> ... 1') for t in (left_indices, right_indices))
|
||||||
|
|
||||||
|
# fetch the left and right values for the consecutive indices
|
||||||
|
|
||||||
|
left_values = self.bin_values[left_indices]
|
||||||
|
right_values = self.bin_values[right_indices]
|
||||||
|
|
||||||
|
# calculate the left and right values by the distance to the left and right
|
||||||
|
|
||||||
|
values = rearrange(values, '... -> ... 1')
|
||||||
|
total_distance = right_values - left_values
|
||||||
|
|
||||||
|
left_logit_value = (right_values - values) / total_distance
|
||||||
|
right_logit_value = 1. - left_logit_value
|
||||||
|
|
||||||
|
# set the left and right values (two-hot)
|
||||||
|
|
||||||
|
encoded = torch.zeros((num_values, self.num_bins), device = self.device)
|
||||||
|
|
||||||
|
encoded.scatter_(-1, left_indices, left_logit_value)
|
||||||
|
encoded.scatter_(-1, right_indices, right_logit_value)
|
||||||
|
|
||||||
|
return inverse_pack(encoded, '* l')
|
||||||
|
|
||||||
# golden gate rotary - Jerry Xiong, PhD student at UIUC
|
# golden gate rotary - Jerry Xiong, PhD student at UIUC
|
||||||
# https://jerryxio.ng/posts/nd-rope/
|
# https://jerryxio.ng/posts/nd-rope/
|
||||||
|
|||||||
@ -29,6 +29,7 @@ dependencies = [
|
|||||||
"accelerate",
|
"accelerate",
|
||||||
"einx>=0.3.0",
|
"einx>=0.3.0",
|
||||||
"einops>=0.8.1",
|
"einops>=0.8.1",
|
||||||
|
"hl-gauss-pytorch",
|
||||||
"torch>=2.4",
|
"torch>=2.4",
|
||||||
"x-mlps-pytorch"
|
"x-mlps-pytorch"
|
||||||
]
|
]
|
||||||
|
|||||||
@ -24,3 +24,15 @@ def test_e2e(
|
|||||||
|
|
||||||
flow_loss = dynamics(latents, signal_levels = signal_levels, step_sizes = step_sizes)
|
flow_loss = dynamics(latents, signal_levels = signal_levels, step_sizes = step_sizes)
|
||||||
assert flow_loss.numel() == 1
|
assert flow_loss.numel() == 1
|
||||||
|
|
||||||
|
def test_symexp_two_hot():
|
||||||
|
import torch
|
||||||
|
from dreamer4.dreamer4 import SymExpTwoHot
|
||||||
|
|
||||||
|
two_hot_encoder = SymExpTwoHot((-3., 3.), 20)
|
||||||
|
values = torch.randn((10))
|
||||||
|
|
||||||
|
encoded = two_hot_encoder(values)
|
||||||
|
recon_values = two_hot_encoder.logits_to_scalar_value(encoded)
|
||||||
|
|
||||||
|
assert torch.allclose(recon_values, values, atol = 1e-6)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user