sampling actions

This commit is contained in:
lucidrains 2025-10-12 11:27:12 -07:00
parent c5e64ff4ce
commit 9c78962736
3 changed files with 49 additions and 1 deletions

View File

@ -90,6 +90,20 @@ def is_power_two(num):
def log(t, eps = 1e-20):
    """Return the elementwise natural log of `t`, flooring the input at `eps` first."""
    # floor before the log so zeros and negatives cannot produce -inf / nan
    floored = t.clamp(min = eps)
    return floored.log()
def gumbel_noise(t):
    """Draw Gumbel-distributed noise shaped like `t` via the inverse-CDF transform."""
    uniform = torch.rand_like(t)
    # -log(-log(u)) with u ~ U[0, 1); `log` floors its input so neither step hits log(0)
    neg_log_uniform = -log(uniform)
    return -log(neg_log_uniform)
def gumbel_sample(
    t,
    temperature = 1.,
    dim = -1,
    keepdim = False,
    eps = 1e-10
):
    """Sample indices from logits `t` along `dim` using the Gumbel-max trick.

    Args:
        t: tensor of unnormalized logits.
        temperature: sampling temperature. Values <= 0 select the argmax
            deterministically (greedy decoding, no noise is drawn).
        dim: dimension holding the categories to choose between.
        keepdim: whether the reduced dimension is kept with size 1.
        eps: lower bound applied to a positive temperature before dividing.

    Returns:
        Tensor of selected indices along `dim`.
    """
    # Greedy decoding is handled explicitly: dividing the logits by `eps` to
    # emulate temperature 0 overflows to +/-inf for large (or half precision)
    # logits, after which argmax can no longer tell the entries apart.
    if temperature <= 0:
        return t.argmax(dim = dim, keepdim = keepdim)

    noised = (t / max(temperature, eps)) + gumbel_noise(t)
    return noised.argmax(dim = dim, keepdim = keepdim)
def pack_one(t, pattern):
packed, packed_shape = pack([t], pattern)
@ -428,6 +442,33 @@ class ActionEmbedder(Module):
return discrete_action_logits, continuous_action_mean_log_var
def sample(
    self,
    embed,
    discrete_temperature = 1.,
    continuous_temperature = 1.,
    **kwargs
):
    """Sample discrete and continuous actions from an action embedding.

    `embed` is unembedded (extra `kwargs` are forwarded to `self.unembed`) into
    per-type discrete logits and continuous (mean, log variance) pairs. Either
    returned sample is None when the corresponding head output does not exist.

    Returns:
        Tuple of (sampled_discrete, sampled_continuous).
    """
    discrete_logits, continuous_mean_log_var = self.unembed(embed, return_split_discrete = True, **kwargs)

    sampled_discrete = None
    sampled_continuous = None

    if exists(discrete_logits):
        # one categorical draw per discrete action type, joined on a trailing dimension
        per_type_samples = [
            gumbel_sample(logits, temperature = discrete_temperature, keepdim = True)
            for logits in discrete_logits
        ]
        sampled_discrete = cat(per_type_samples, dim = -1)

    if exists(continuous_mean_log_var):
        # last dimension holds (mean, log variance)
        mean, log_var = continuous_mean_log_var.unbind(dim = -1)
        std = (0.5 * log_var).exp()

        # reparameterized gaussian draw, with the noise scaled by the temperature
        noise = torch.randn_like(mean)
        sampled_continuous = mean + std * noise * continuous_temperature

    return sampled_discrete, sampled_continuous
def log_probs(
self,
embeds, # (... d)

View File

@ -1,6 +1,6 @@
[project]
name = "dreamer4"
version = "0.0.16"
version = "0.0.17"
description = "Dreamer 4"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

View File

@ -288,6 +288,13 @@ def test_action_embedder():
assert discrete_logits.shape == (2, 3, 4)
assert continuous_mean_log_var.shape == (2, 3, 1, 2)
# sample actions
sampled_discrete_actions, sampled_continuous_actions = embedder.sample(action_embed, discrete_action_types = 1, continuous_action_types = 0)
assert sampled_discrete_actions.shape == (2, 3, 1)
assert sampled_continuous_actions.shape == (2, 3, 1)
# log probs
assert discrete_logits.shape == (2, 3, 4)