diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index 9b5a693..1c9e5e1 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -90,6 +90,20 @@ def is_power_two(num): def log(t, eps = 1e-20): return t.clamp(min = eps).log() +def gumbel_noise(t): + noise = torch.rand_like(t) + return -log(-log(noise)) + +def gumbel_sample( + t, + temperature = 1., + dim = -1, + keepdim = False, + eps = 1e-10 +): + noised = (t / max(temperature, eps)) + gumbel_noise(t) + return noised.argmax(dim = dim, keepdim = keepdim) + def pack_one(t, pattern): packed, packed_shape = pack([t], pattern) @@ -428,6 +442,33 @@ class ActionEmbedder(Module): return discrete_action_logits, continuous_action_mean_log_var + def sample( + self, + embed, + discrete_temperature = 1., + continuous_temperature = 1., + **kwargs + ): + discrete_logits, continuous_mean_log_var = self.unembed(embed, return_split_discrete = True, **kwargs) + + sampled_discrete = sampled_continuous = None + + if exists(discrete_logits): + sampled_discrete = [] + + for one_discrete_logits in discrete_logits: + sampled_discrete.append(gumbel_sample(one_discrete_logits, temperature = discrete_temperature, keepdim = True)) + + sampled_discrete = cat(sampled_discrete, dim = -1) + + if exists(continuous_mean_log_var): + mean, log_var = continuous_mean_log_var.unbind(dim = -1) + std = (0.5 * log_var).exp() + + sampled_continuous = mean + std * torch.randn_like(mean) * continuous_temperature + + return sampled_discrete, sampled_continuous + def log_probs( self, embeds, # (... d) diff --git a/pyproject.toml b/pyproject.toml index 32e3c0a..8832486 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dreamer4" -version = "0.0.16" +version = "0.0.17" description = "Dreamer 4" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" } diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py index 2ef8864..aa2b39f 100644 --- a/tests/test_dreamer.py +++ b/tests/test_dreamer.py @@ -288,6 +288,13 @@ def test_action_embedder(): assert discrete_logits.shape == (2, 3, 4) assert continuous_mean_log_var.shape == (2, 3, 1, 2) + # sample actions + + sampled_discrete_actions, sampled_continuous_actions = embedder.sample(action_embed, discrete_action_types = 1, continuous_action_types = 0) + + assert sampled_discrete_actions.shape == (2, 3, 1) + assert sampled_continuous_actions.shape == (2, 3, 1) + # log probs assert discrete_logits.shape == (2, 3, 4)