Add a nested-tensor code path for computing the log probs of multiple discrete actions in parallel

This commit is contained in:
lucidrains 2025-10-11 10:53:24 -07:00
parent 01bf70e18a
commit 8a73a27fc7
3 changed files with 63 additions and 9 deletions

View File

@ -8,6 +8,7 @@ from functools import partial
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch.nested import nested_tensor
from torch.distributions import Normal from torch.distributions import Normal
from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity
from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
@ -434,7 +435,10 @@ class ActionEmbedder(Module):
continuous_targets = None, # (... na) continuous_targets = None, # (... na)
discrete_action_types = None, # (na) discrete_action_types = None, # (na)
continuous_action_types = None, # (na) continuous_action_types = None, # (na)
parallel_discrete_calc = None
): ):
parallel_discrete_calc = default(parallel_discrete_calc, exists(discrete_targets) and discrete_targets.shape[-1] > 1)
discrete_action_logits, continuous_action_mean_log_var = self.unembed(embeds, discrete_action_types = discrete_action_types, continuous_action_types = continuous_action_types, return_split_discrete = True) discrete_action_logits, continuous_action_mean_log_var = self.unembed(embeds, discrete_action_types = discrete_action_types, continuous_action_types = continuous_action_types, return_split_discrete = True)
# discrete # discrete
@ -442,6 +446,46 @@ class ActionEmbedder(Module):
discrete_log_probs = None discrete_log_probs = None
if exists(discrete_targets): if exists(discrete_targets):
if parallel_discrete_calc:
# use nested tensors
jagged_dims = tuple(t.shape[-1] for t in discrete_action_logits)
discrete_action_logits = cat(discrete_action_logits, dim = -1)
discrete_action_logits, inverse_pack_lead_dims = pack_one(discrete_action_logits, '* l')
batch = discrete_action_logits.shape[0]
discrete_action_logits = rearrange(discrete_action_logits, 'b l -> (b l)')
# to nested tensor
nested_logits = nested_tensor(discrete_action_logits.split(jagged_dims * batch), layout = torch.jagged, device = self.device, requires_grad = True)
log_probs = log(nested_logits.softmax(dim = -1))
# back to regular tensor
log_probs = cat(log_probs.unbind())
log_probs = rearrange(log_probs, '(b l) -> b l', b = batch)
log_probs = inverse_pack_lead_dims(log_probs)
# get indices to gather
discrete_action_types = default(discrete_action_types, self.default_discrete_action_types)
num_discrete_actions = self.num_discrete_actions[discrete_action_types]
offset = F.pad(num_discrete_actions.cumsum(dim = -1), (1, -1), value = 0)
log_prob_indices = discrete_targets + offset
# gather
discrete_log_probs = log_probs.gather(-1, log_prob_indices)
else:
discrete_log_probs = [] discrete_log_probs = []
for one_discrete_action_logit, one_discrete_target in zip(discrete_action_logits, discrete_targets.unbind(dim = -1)): for one_discrete_action_logit, one_discrete_target in zip(discrete_action_logits, discrete_targets.unbind(dim = -1)):

View File

@ -1,6 +1,6 @@
[project] [project]
name = "dreamer4" name = "dreamer4"
version = "0.0.14" version = "0.0.15"
description = "Dreamer 4" description = "Dreamer 4"
authors = [ authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" } { name = "Phil Wang", email = "lucidrains@gmail.com" }

View File

@ -296,8 +296,18 @@ def test_action_embedder():
discrete_log_probs, continuous_log_probs = embedder.log_probs( discrete_log_probs, continuous_log_probs = embedder.log_probs(
action_embed, action_embed,
discrete_targets = discrete_actions, discrete_targets = discrete_actions,
continuous_targets = continuous_actions continuous_targets = continuous_actions,
parallel_discrete_calc = False
) )
assert discrete_log_probs.shape == (2, 3, 2) assert discrete_log_probs.shape == (2, 3, 2)
assert continuous_log_probs.shape == (2, 3, 2) assert continuous_log_probs.shape == (2, 3, 2)
parallel_discrete_log_probs, _ = embedder.log_probs(
action_embed,
discrete_targets = discrete_actions,
continuous_targets = continuous_actions,
parallel_discrete_calc = True
)
assert torch.allclose(discrete_log_probs, parallel_discrete_log_probs, atol = 1e-5)