From 8a73a27fc725fafc156a035e3070c48c7854b9b7 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Sat, 11 Oct 2025 10:53:24 -0700
Subject: [PATCH] add nested tensor way for getting log prob of multiple
 discrete actions

---
 dreamer4/dreamer4.py  | 58 +++++++++++++++++++++++++++++++++++++------
 pyproject.toml        |  2 +-
 tests/test_dreamer.py | 12 ++++++++-
 3 files changed, 63 insertions(+), 9 deletions(-)

diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index fb598bf..6ea10ea 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -8,6 +8,7 @@ from functools import partial
 
 import torch
 import torch.nn.functional as F
+from torch.nested import nested_tensor
 from torch.distributions import Normal
 from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity
 from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
@@ -434,7 +435,10 @@ class ActionEmbedder(Module):
         continuous_targets = None,      # (... na)
         discrete_action_types = None,   # (na)
         continuous_action_types = None, # (na)
+        parallel_discrete_calc = None
     ):
+        parallel_discrete_calc = default(parallel_discrete_calc, exists(discrete_targets) and discrete_targets.shape[-1] > 1)
+
         discrete_action_logits, continuous_action_mean_log_var = self.unembed(embeds, discrete_action_types = discrete_action_types, continuous_action_types = continuous_action_types, return_split_discrete = True)
 
         # discrete
@@ -442,17 +446,57 @@ class ActionEmbedder(Module):
         discrete_log_probs = None
 
         if exists(discrete_targets):
-            discrete_log_probs = []
 
-            for one_discrete_action_logit, one_discrete_target in zip(discrete_action_logits, discrete_targets.unbind(dim = -1)):
+            if parallel_discrete_calc:
+                # use nested tensors
 
-                one_discrete_log_probs = one_discrete_action_logit.log_softmax(dim = -1)
-                one_discrete_target = rearrange(one_discrete_target, '... -> ... 1')
+                jagged_dims = tuple(t.shape[-1] for t in discrete_action_logits)
 
-                log_prob = one_discrete_log_probs.gather(-1, one_discrete_target)
-                discrete_log_probs.append(log_prob)
+                discrete_action_logits = cat(discrete_action_logits, dim = -1)
 
-            discrete_log_probs = cat(discrete_log_probs, dim = -1)
+                discrete_action_logits, inverse_pack_lead_dims = pack_one(discrete_action_logits, '* l')
+                batch = discrete_action_logits.shape[0]
+
+                discrete_action_logits = rearrange(discrete_action_logits, 'b l -> (b l)')
+
+                # to nested tensor
+
+                nested_logits = nested_tensor(discrete_action_logits.split(jagged_dims * batch), layout = torch.jagged, device = self.device, requires_grad = True)
+
+                log_probs = log(nested_logits.softmax(dim = -1))
+
+                # back to regular tensor
+
+                log_probs = cat(log_probs.unbind())
+                log_probs = rearrange(log_probs, '(b l) -> b l', b = batch)
+
+                log_probs = inverse_pack_lead_dims(log_probs)
+
+                # get indices to gather
+
+                discrete_action_types = default(discrete_action_types, self.default_discrete_action_types)
+
+                num_discrete_actions = self.num_discrete_actions[discrete_action_types]
+
+                offset = F.pad(num_discrete_actions.cumsum(dim = -1), (1, -1), value = 0)
+                log_prob_indices = discrete_targets + offset
+
+                # gather
+
+                discrete_log_probs = log_probs.gather(-1, log_prob_indices)
+
+            else:
+                discrete_log_probs = []
+
+                for one_discrete_action_logit, one_discrete_target in zip(discrete_action_logits, discrete_targets.unbind(dim = -1)):
+
+                    one_discrete_log_probs = one_discrete_action_logit.log_softmax(dim = -1)
+                    one_discrete_target = rearrange(one_discrete_target, '... -> ... 1')
+
+                    log_prob = one_discrete_log_probs.gather(-1, one_discrete_target)
+                    discrete_log_probs.append(log_prob)
+
+                discrete_log_probs = cat(discrete_log_probs, dim = -1)
 
         # continuous
diff --git a/pyproject.toml b/pyproject.toml
index 05d7c0a..f073f37 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.14"
+version = "0.0.15"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py
index d8f40f8..2ef8864 100644
--- a/tests/test_dreamer.py
+++ b/tests/test_dreamer.py
@@ -296,8 +296,18 @@ def test_action_embedder():
     discrete_log_probs, continuous_log_probs = embedder.log_probs(
         action_embed,
         discrete_targets = discrete_actions,
-        continuous_targets = continuous_actions
+        continuous_targets = continuous_actions,
+        parallel_discrete_calc = False
     )
 
     assert discrete_log_probs.shape == (2, 3, 2)
     assert continuous_log_probs.shape == (2, 3, 2)
+
+    parallel_discrete_log_probs, _ = embedder.log_probs(
+        action_embed,
+        discrete_targets = discrete_actions,
+        continuous_targets = continuous_actions,
+        parallel_discrete_calc = True
+    )
+
+    assert torch.allclose(discrete_log_probs, parallel_discrete_log_probs, atol = 1e-5)
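
---

Note (appended for reference, not part of the patch): the parallel branch above packs the per-action-type logits into a single jagged nested tensor so one softmax covers every (batch element, action type) pair, then gathers the target log probs using cumulative offsets into the concatenated logit dimension. Below is a minimal standalone sketch of the same trick. The names (arities, per_action_logits, targets) are illustrative and not from the repo, and softmax over the ragged dim of a jagged-layout nested tensor assumes a recent PyTorch.

import torch
import torch.nn.functional as F
from torch.nested import nested_tensor

# three discrete action types with different numbers of choices
arities = (4, 2, 3)
batch = 8

per_action_logits = [torch.randn(batch, n) for n in arities]
targets = torch.stack([torch.randint(0, n, (batch,)) for n in arities], dim = -1) # (batch, 3)

# flatten all logits to one long vector, then split into (batch * num action types) jagged rows
flat = torch.cat(per_action_logits, dim = -1).reshape(-1)
jagged = nested_tensor(list(flat.split(arities * batch)), layout = torch.jagged)

# one softmax per jagged row, i.e. per (batch element, action type) pair
log_probs = jagged.softmax(dim = -1).log()

# back to a dense (batch, total logits) tensor
log_probs = torch.cat(log_probs.unbind()).reshape(batch, -1)

# cumulative offsets shift each action type's target index into the concatenated logit dim
offsets = F.pad(torch.tensor(arities).cumsum(dim = -1), (1, -1), value = 0)
gathered = log_probs.gather(-1, targets + offsets) # (batch, 3)

# forward values agree with the per-action-type loop the patch keeps as the fallback path
looped = torch.cat([
    one_logits.log_softmax(dim = -1).gather(-1, one_targets.unsqueeze(-1))
    for one_logits, one_targets in zip(per_action_logits, targets.unbind(dim = -1))
], dim = -1)

assert torch.allclose(gathered, looped, atol = 1e-5)

Like the new test in the patch, the sketch only checks that the jagged path and the loop agree numerically; the patch itself additionally routes gradients through the nested tensor.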