Apply inverse normalization to continuous actions when sampling
This commit is contained in:
parent
0ee475d2df
commit
bc629d78b1
@ -505,8 +505,11 @@ class ActionEmbedder(Module):
|
||||
embed,
|
||||
discrete_temperature = 1.,
|
||||
continuous_temperature = 1.,
|
||||
inverse_norm_continuous = None,
|
||||
**kwargs
|
||||
):
|
||||
inverse_norm_continuous = default(inverse_norm_continuous, self.continuous_need_norm)
|
||||
|
||||
discrete_logits, continuous_mean_log_var = self.unembed(embed, return_split_discrete = True, **kwargs)
|
||||
|
||||
sampled_discrete = sampled_continuous = None
|
||||
@ -525,6 +528,12 @@ class ActionEmbedder(Module):
|
||||
|
||||
sampled_continuous = mean + std * torch.randn_like(mean) * continuous_temperature
|
||||
|
||||
# maybe inverse norm
|
||||
|
||||
if inverse_norm_continuous:
|
||||
norm_mean, norm_std = self.continuous_norm_stats.unbind(dim = -1)
|
||||
sampled_continuous = (sampled_continuous * norm_std) + norm_mean
|
||||
|
||||
return sampled_discrete, sampled_continuous
|
||||
|
||||
def log_probs(
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "dreamer4"
|
||||
version = "0.0.36"
|
||||
version = "0.0.37"
|
||||
description = "Dreamer 4"
|
||||
authors = [
|
||||
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user