add nested tensor way for getting log prob of multiple discrete actions

This commit is contained in:
lucidrains 2025-10-11 10:53:24 -07:00
parent 01bf70e18a
commit 8a73a27fc7
3 changed files with 63 additions and 9 deletions

View File

@ -8,6 +8,7 @@ from functools import partial
import torch
import torch.nn.functional as F
from torch.nested import nested_tensor
from torch.distributions import Normal
from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity
from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
@ -434,7 +435,10 @@ class ActionEmbedder(Module):
continuous_targets = None, # (... na)
discrete_action_types = None, # (na)
continuous_action_types = None, # (na)
parallel_discrete_calc = None
):
parallel_discrete_calc = default(parallel_discrete_calc, exists(discrete_targets) and discrete_targets.shape[-1] > 1)
discrete_action_logits, continuous_action_mean_log_var = self.unembed(embeds, discrete_action_types = discrete_action_types, continuous_action_types = continuous_action_types, return_split_discrete = True)
# discrete
@ -442,17 +446,57 @@ class ActionEmbedder(Module):
discrete_log_probs = None
if exists(discrete_targets):
discrete_log_probs = []
for one_discrete_action_logit, one_discrete_target in zip(discrete_action_logits, discrete_targets.unbind(dim = -1)):
if parallel_discrete_calc:
# use nested tensors
one_discrete_log_probs = one_discrete_action_logit.log_softmax(dim = -1)
one_discrete_target = rearrange(one_discrete_target, '... -> ... 1')
jagged_dims = tuple(t.shape[-1] for t in discrete_action_logits)
log_prob = one_discrete_log_probs.gather(-1, one_discrete_target)
discrete_log_probs.append(log_prob)
discrete_action_logits = cat(discrete_action_logits, dim = -1)
discrete_log_probs = cat(discrete_log_probs, dim = -1)
discrete_action_logits, inverse_pack_lead_dims = pack_one(discrete_action_logits, '* l')
batch = discrete_action_logits.shape[0]
discrete_action_logits = rearrange(discrete_action_logits, 'b l -> (b l)')
# to nested tensor
nested_logits = nested_tensor(discrete_action_logits.split(jagged_dims * batch), layout = torch.jagged, device = self.device, requires_grad = True)
log_probs = log(nested_logits.softmax(dim = -1))
# back to regular tensor
log_probs = cat(log_probs.unbind())
log_probs = rearrange(log_probs, '(b l) -> b l', b = batch)
log_probs = inverse_pack_lead_dims(log_probs)
# get indices to gather
discrete_action_types = default(discrete_action_types, self.default_discrete_action_types)
num_discrete_actions = self.num_discrete_actions[discrete_action_types]
offset = F.pad(num_discrete_actions.cumsum(dim = -1), (1, -1), value = 0)
log_prob_indices = discrete_targets + offset
# gather
discrete_log_probs = log_probs.gather(-1, log_prob_indices)
else:
discrete_log_probs = []
for one_discrete_action_logit, one_discrete_target in zip(discrete_action_logits, discrete_targets.unbind(dim = -1)):
one_discrete_log_probs = one_discrete_action_logit.log_softmax(dim = -1)
one_discrete_target = rearrange(one_discrete_target, '... -> ... 1')
log_prob = one_discrete_log_probs.gather(-1, one_discrete_target)
discrete_log_probs.append(log_prob)
discrete_log_probs = cat(discrete_log_probs, dim = -1)
# continuous

View File

@ -1,6 +1,6 @@
[project]
name = "dreamer4"
version = "0.0.14"
version = "0.0.15"
description = "Dreamer 4"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

View File

@ -296,8 +296,18 @@ def test_action_embedder():
discrete_log_probs, continuous_log_probs = embedder.log_probs(
action_embed,
discrete_targets = discrete_actions,
continuous_targets = continuous_actions
continuous_targets = continuous_actions,
parallel_discrete_calc = False
)
assert discrete_log_probs.shape == (2, 3, 2)
assert continuous_log_probs.shape == (2, 3, 2)
parallel_discrete_log_probs, _ = embedder.log_probs(
action_embed,
discrete_targets = discrete_actions,
continuous_targets = continuous_actions,
parallel_discrete_calc = True
)
assert torch.allclose(discrete_log_probs, parallel_discrete_log_probs, atol = 1e-5)