add nested tensor way for getting log prob of multiple discrete actions
parent 01bf70e18a
commit 8a73a27fc7
@@ -8,6 +8,7 @@ from functools import partial
 import torch
 import torch.nn.functional as F
+from torch.nested import nested_tensor
 from torch.distributions import Normal
 from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity
 from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace
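The one new import is the heart of this change: torch.nested.nested_tensor with layout = torch.jagged packs variable-length rows into a single tensor, so every discrete action head can be normalized in one softmax call even when the heads have different class counts. A minimal sketch of that behavior (the head sizes 3 and 5 are made up for illustration, and this needs a PyTorch recent enough to support softmax over the jagged dim, the same requirement the commit itself takes on):

import torch
from torch.nested import nested_tensor

# two discrete action heads with different class counts -> a jagged last dim
head_a = torch.randn(3)
head_b = torch.randn(5)

nt = nested_tensor([head_a, head_b], layout = torch.jagged)

# one call normalizes each ragged row independently
probs = nt.softmax(dim = -1)

# each row sums to 1, exactly as if each head were softmaxed on its own
for row in probs.unbind():
    assert torch.allclose(row.sum(), torch.tensor(1.))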
@@ -434,7 +435,10 @@ class ActionEmbedder(Module):
         continuous_targets = None,      # (... na)
         discrete_action_types = None,   # (na)
         continuous_action_types = None, # (na)
+        parallel_discrete_calc = None
     ):
+        parallel_discrete_calc = default(parallel_discrete_calc, exists(discrete_targets) and discrete_targets.shape[-1] > 1)
+
         discrete_action_logits, continuous_action_mean_log_var = self.unembed(embeds, discrete_action_types = discrete_action_types, continuous_action_types = continuous_action_types, return_split_discrete = True)

         # discrete
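When parallel_discrete_calc is left as None, the new default switches the parallel path on exactly when more than one discrete action head is being scored (discrete_targets.shape[-1] > 1), so single-head callers keep the plain loop. For readers outside the repo, exists and default here are the usual lucidrains-style helpers; a sketch of their assumed semantics:

# assumed semantics of the repo's helpers, not part of this diff
def exists(v):
    return v is not None

def default(v, d):
    return v if exists(v) else d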
@@ -442,17 +446,57 @@ class ActionEmbedder(Module):
         discrete_log_probs = None

         if exists(discrete_targets):
-            discrete_log_probs = []
-
-            for one_discrete_action_logit, one_discrete_target in zip(discrete_action_logits, discrete_targets.unbind(dim = -1)):
-
-                one_discrete_log_probs = one_discrete_action_logit.log_softmax(dim = -1)
-                one_discrete_target = rearrange(one_discrete_target, '... -> ... 1')
-
-                log_prob = one_discrete_log_probs.gather(-1, one_discrete_target)
-                discrete_log_probs.append(log_prob)
-
-            discrete_log_probs = cat(discrete_log_probs, dim = -1)
+            if parallel_discrete_calc:
+                # use nested tensors
+
+                jagged_dims = tuple(t.shape[-1] for t in discrete_action_logits)
+
+                discrete_action_logits = cat(discrete_action_logits, dim = -1)
+
+                discrete_action_logits, inverse_pack_lead_dims = pack_one(discrete_action_logits, '* l')
+                batch = discrete_action_logits.shape[0]
+
+                discrete_action_logits = rearrange(discrete_action_logits, 'b l -> (b l)')
+
+                # to nested tensor
+
+                nested_logits = nested_tensor(discrete_action_logits.split(jagged_dims * batch), layout = torch.jagged, device = self.device, requires_grad = True)
+
+                log_probs = log(nested_logits.softmax(dim = -1))
+
+                # back to regular tensor
+
+                log_probs = cat(log_probs.unbind())
+                log_probs = rearrange(log_probs, '(b l) -> b l', b = batch)
+
+                log_probs = inverse_pack_lead_dims(log_probs)
+
+                # get indices to gather
+
+                discrete_action_types = default(discrete_action_types, self.default_discrete_action_types)
+
+                num_discrete_actions = self.num_discrete_actions[discrete_action_types]
+
+                offset = F.pad(num_discrete_actions.cumsum(dim = -1), (1, -1), value = 0)
+                log_prob_indices = discrete_targets + offset
+
+                # gather
+
+                discrete_log_probs = log_probs.gather(-1, log_prob_indices)
+
+            else:
+                discrete_log_probs = []
+
+                for one_discrete_action_logit, one_discrete_target in zip(discrete_action_logits, discrete_targets.unbind(dim = -1)):
+
+                    one_discrete_log_probs = one_discrete_action_logit.log_softmax(dim = -1)
+                    one_discrete_target = rearrange(one_discrete_target, '... -> ... 1')
+
+                    log_prob = one_discrete_log_probs.gather(-1, one_discrete_target)
+                    discrete_log_probs.append(log_prob)
+
+                discrete_log_probs = cat(discrete_log_probs, dim = -1)

         # continuous
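Read top to bottom, the parallel branch: records each head's class count (jagged_dims), concatenates and flattens all logits, splits them into one jagged piece per (batch row, head), softmaxes once through a jagged nested tensor, stitches the result back into a (batch, total_classes) tensor, and finally gathers each target after offsetting it into the concatenated class dimension; F.pad(cumsum, (1, -1)) turns per-head sizes into start offsets, e.g. sizes (3, 5, 2) become offsets (0, 3, 8). A standalone sketch with made-up shapes, checked against the per-head loop:

import torch
from torch.nested import nested_tensor

torch.manual_seed(0)

batch = 4
head_sizes = (3, 5, 2)  # made-up class counts, one per discrete action head

logits = [torch.randn(batch, c) for c in head_sizes]
targets = torch.stack([torch.randint(0, c, (batch,)) for c in head_sizes], dim = -1)  # (batch, heads)

# reference: per-head loop, as in the non-parallel branch
ref = torch.cat([
    head.log_softmax(dim = -1).gather(-1, t[:, None])
    for head, t in zip(logits, targets.unbind(dim = -1))
], dim = -1)

# parallel: one jagged softmax over every (batch row, head) piece
flat = torch.cat(logits, dim = -1).reshape(-1)                  # (batch * sum(head_sizes),)
nt = nested_tensor(list(flat.split(head_sizes * batch)), layout = torch.jagged)

probs = nt.softmax(dim = -1)
log_probs = torch.cat(probs.unbind()).log().reshape(batch, -1)  # (batch, sum(head_sizes))

# offset each head's target into the concatenated class dim: sizes (3, 5, 2) -> offsets (0, 3, 8)
offsets = torch.nn.functional.pad(torch.tensor(head_sizes).cumsum(dim = -1), (1, -1), value = 0)
par = log_probs.gather(-1, targets + offsets)

assert torch.allclose(ref, par, atol = 1e-5)

One caveat worth flagging: torch.nested.nested_tensor copies its inputs into a new leaf (which is why the diff passes requires_grad = True), while torch.nested.as_nested_tensor is the variant that preserves the autograd history of the tensors it wraps.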
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.14"
+version = "0.0.15"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -296,8 +296,18 @@ def test_action_embedder():
     discrete_log_probs, continuous_log_probs = embedder.log_probs(
         action_embed,
         discrete_targets = discrete_actions,
-        continuous_targets = continuous_actions
+        continuous_targets = continuous_actions,
+        parallel_discrete_calc = False
     )

     assert discrete_log_probs.shape == (2, 3, 2)
     assert continuous_log_probs.shape == (2, 3, 2)

+    parallel_discrete_log_probs, _ = embedder.log_probs(
+        action_embed,
+        discrete_targets = discrete_actions,
+        continuous_targets = continuous_actions,
+        parallel_discrete_calc = True
+    )
+
+    assert torch.allclose(discrete_log_probs, parallel_discrete_log_probs, atol = 1e-5)
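The updated test uses the loop path (parallel_discrete_calc = False) as the reference and checks the nested-tensor path against it; the atol = 1e-5 slack accounts for the two paths reducing the softmax in different orders, so bitwise equality is not expected.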