sampling actions
This commit is contained in:
parent
c5e64ff4ce
commit
9c78962736
@ -90,6 +90,20 @@ def is_power_two(num):
|
||||
def log(t, eps = 1e-20):
    """Natural log of `t` with its values floored at `eps` first.

    Flooring keeps zero or negative entries from producing -inf / nan.
    """
    floored = t.clamp(min = eps)
    return floored.log()
|
||||
|
||||
def gumbel_noise(t):
    """Draw Gumbel noise shaped like `t` via the double negative log of uniform samples."""
    uniform = torch.rand_like(t)

    # -log(u) followed by a second -log(...) is the inverse-CDF transform for Gumbel(0, 1);
    # the `log` helper floors its input, so u == 0 does not yield inf
    neg_log_uniform = -log(uniform)
    return -log(neg_log_uniform)
|
||||
|
||||
def gumbel_sample(
    t,
    temperature = 1.,
    dim = -1,
    keepdim = False,
    eps = 1e-10
):
    """Sample indices along `dim` by adding Gumbel noise to temperature-scaled `t` and taking the argmax.

    `temperature` is floored at `eps` so that a temperature of 0 does not divide by zero.
    Returns the argmax indices, keeping the reduced dimension when `keepdim` is True.
    """
    safe_temperature = max(temperature, eps)

    scaled = t / safe_temperature
    perturbed = scaled + gumbel_noise(t)

    return perturbed.argmax(dim = dim, keepdim = keepdim)
|
||||
|
||||
def pack_one(t, pattern):
|
||||
packed, packed_shape = pack([t], pattern)
|
||||
|
||||
@ -428,6 +442,33 @@ class ActionEmbedder(Module):
|
||||
|
||||
return discrete_action_logits, continuous_action_mean_log_var
|
||||
|
||||
def sample(
    self,
    embed,
    discrete_temperature = 1.,
    continuous_temperature = 1.,
    **kwargs
):
    """Sample discrete and / or continuous actions from an embedding.

    `embed` is unembedded (extra `kwargs` are forwarded to `self.unembed`) into
    per-type discrete logits and continuous (mean, log variance) pairs.

    Returns a tuple `(sampled_discrete, sampled_continuous)`; an entry stays
    None when `unembed` yields nothing for that kind of action.
    """
    discrete_logits, continuous_mean_log_var = self.unembed(embed, return_split_discrete = True, **kwargs)

    sampled_discrete = None
    sampled_continuous = None

    if exists(discrete_logits):
        # one gumbel-max draw per discrete action type, joined along the last dimension
        per_type_samples = [
            gumbel_sample(logits, temperature = discrete_temperature, keepdim = True)
            for logits in discrete_logits
        ]

        sampled_discrete = cat(per_type_samples, dim = -1)

    if exists(continuous_mean_log_var):
        # last dimension holds (mean, log variance)
        mean, log_var = continuous_mean_log_var.unbind(dim = -1)
        std = (0.5 * log_var).exp()

        # gaussian draw whose spread is scaled by the continuous temperature
        noise = torch.randn_like(mean)
        sampled_continuous = mean + std * noise * continuous_temperature

    return sampled_discrete, sampled_continuous
|
||||
|
||||
def log_probs(
|
||||
self,
|
||||
embeds, # (... d)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "dreamer4"
|
||||
version = "0.0.16"
|
||||
version = "0.0.17"
|
||||
description = "Dreamer 4"
|
||||
authors = [
|
||||
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
||||
|
||||
@ -288,6 +288,13 @@ def test_action_embedder():
|
||||
assert discrete_logits.shape == (2, 3, 4)
|
||||
assert continuous_mean_log_var.shape == (2, 3, 1, 2)
|
||||
|
||||
# sample actions
|
||||
|
||||
sampled_discrete_actions, sampled_continuous_actions = embedder.sample(action_embed, discrete_action_types = 1, continuous_action_types = 0)
|
||||
|
||||
assert sampled_discrete_actions.shape == (2, 3, 1)
|
||||
assert sampled_continuous_actions.shape == (2, 3, 1)
|
||||
|
||||
# log probs
|
||||
|
||||
assert discrete_logits.shape == (2, 3, 4)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user