inverse norm for continuous actions when sampling

This commit is contained in:
lucidrains 2025-10-18 08:55:04 -07:00
parent 0ee475d2df
commit bc629d78b1
2 changed files with 10 additions and 1 deletion

View File

@@ -505,8 +505,11 @@ class ActionEmbedder(Module):
embed, embed,
discrete_temperature = 1., discrete_temperature = 1.,
continuous_temperature = 1., continuous_temperature = 1.,
inverse_norm_continuous = None,
**kwargs **kwargs
): ):
inverse_norm_continuous = default(inverse_norm_continuous, self.continuous_need_norm)
discrete_logits, continuous_mean_log_var = self.unembed(embed, return_split_discrete = True, **kwargs) discrete_logits, continuous_mean_log_var = self.unembed(embed, return_split_discrete = True, **kwargs)
sampled_discrete = sampled_continuous = None sampled_discrete = sampled_continuous = None
@@ -525,6 +528,12 @@ class ActionEmbedder(Module):
sampled_continuous = mean + std * torch.randn_like(mean) * continuous_temperature sampled_continuous = mean + std * torch.randn_like(mean) * continuous_temperature
# maybe inverse norm
if inverse_norm_continuous:
norm_mean, norm_std = self.continuous_norm_stats.unbind(dim = -1)
sampled_continuous = (sampled_continuous * norm_std) + norm_mean
return sampled_discrete, sampled_continuous return sampled_discrete, sampled_continuous
def log_probs( def log_probs(

View File

@@ -1,6 +1,6 @@
[project] [project]
name = "dreamer4" name = "dreamer4"
version = "0.0.36" version = "0.0.37"
description = "Dreamer 4" description = "Dreamer 4"
authors = [ authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" } { name = "Phil Wang", email = "lucidrains@gmail.com" }