diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index a141007..674e7cb 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -505,8 +505,11 @@ class ActionEmbedder(Module): embed, discrete_temperature = 1., continuous_temperature = 1., + inverse_norm_continuous = None, **kwargs ): + inverse_norm_continuous = default(inverse_norm_continuous, self.continuous_need_norm) + discrete_logits, continuous_mean_log_var = self.unembed(embed, return_split_discrete = True, **kwargs) sampled_discrete = sampled_continuous = None @@ -525,6 +528,12 @@ class ActionEmbedder(Module): sampled_continuous = mean + std * torch.randn_like(mean) * continuous_temperature + # maybe inverse norm + + if inverse_norm_continuous: + norm_mean, norm_std = self.continuous_norm_stats.unbind(dim = -1) + sampled_continuous = (sampled_continuous * norm_std) + norm_mean + return sampled_discrete, sampled_continuous def log_probs( diff --git a/pyproject.toml b/pyproject.toml index af698ba..bc67a5f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dreamer4" -version = "0.0.36" +version = "0.0.37" description = "Dreamer 4" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" }