add nested tensor way for getting log prob of multiple discrete actions

parent 01bf70e18a
commit 8a73a27fc7
@@ -8,6 +8,7 @@ from functools import partial
 import torch
 import torch.nn.functional as F
+from torch.nested import nested_tensor

 from torch.distributions import Normal
 from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity
 from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
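For context, a minimal sketch of what the newly imported `nested_tensor` with `layout = torch.jagged` provides, since the new parallel path below builds on it (illustrative shapes, not from this commit; assumes a PyTorch version whose jagged NestedTensor supports softmax over the ragged dimension, as the commit itself relies on):

    import torch
    from torch.nested import nested_tensor

    # three logit vectors of different lengths, standing in for three
    # discrete action heads with 4, 2 and 5 classes
    logits = [torch.randn(4), torch.randn(2), torch.randn(5)]

    # the jagged layout stores one flat buffer plus offsets, so a single
    # softmax call normalizes each ragged component independently
    nt = nested_tensor(logits, layout = torch.jagged)
    probs = nt.softmax(dim = -1)

    # each component is its own probability distribution
    for p in probs.unbind():
        assert torch.allclose(p.sum(), torch.tensor(1.))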
@@ -434,7 +435,10 @@ class ActionEmbedder(Module):
         continuous_targets = None,      # (... na)
         discrete_action_types = None,   # (na)
         continuous_action_types = None, # (na)
+        parallel_discrete_calc = None
     ):
+        parallel_discrete_calc = default(parallel_discrete_calc, exists(discrete_targets) and discrete_targets.shape[-1] > 1)
+
         discrete_action_logits, continuous_action_mean_log_var = self.unembed(embeds, discrete_action_types = discrete_action_types, continuous_action_types = continuous_action_types, return_split_discrete = True)

         # discrete
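The new `parallel_discrete_calc` flag therefore auto-enables whenever discrete targets are present and there is more than one discrete action per step. `exists` and `default` are presumably the usual one-line helpers defined elsewhere in the repo (assumed definitions):

    # assumed helper definitions, the usual lucidrains one-liners
    def exists(v):
        return v is not None

    def default(v, d):
        return v if exists(v) else d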
@@ -442,6 +446,46 @@ class ActionEmbedder(Module):
         discrete_log_probs = None

         if exists(discrete_targets):

+            if parallel_discrete_calc:
+                # use nested tensors
+
+                jagged_dims = tuple(t.shape[-1] for t in discrete_action_logits)
+
+                discrete_action_logits = cat(discrete_action_logits, dim = -1)
+
+                discrete_action_logits, inverse_pack_lead_dims = pack_one(discrete_action_logits, '* l')
+                batch = discrete_action_logits.shape[0]
+
+                discrete_action_logits = rearrange(discrete_action_logits, 'b l -> (b l)')
+
+                # to nested tensor
+
+                nested_logits = nested_tensor(discrete_action_logits.split(jagged_dims * batch), layout = torch.jagged, device = self.device, requires_grad = True)
+
+                log_probs = log(nested_logits.softmax(dim = -1))
+
+                # back to regular tensor
+
+                log_probs = cat(log_probs.unbind())
+                log_probs = rearrange(log_probs, '(b l) -> b l', b = batch)
+
+                log_probs = inverse_pack_lead_dims(log_probs)
+
+                # get indices to gather
+
+                discrete_action_types = default(discrete_action_types, self.default_discrete_action_types)
+
+                num_discrete_actions = self.num_discrete_actions[discrete_action_types]
+
+                offset = F.pad(num_discrete_actions.cumsum(dim = -1), (1, -1), value = 0)
+                log_prob_indices = discrete_targets + offset
+
+                # gather
+
+                discrete_log_probs = log_probs.gather(-1, log_prob_indices)
+
+            else:
+                discrete_log_probs = []
+
+                for one_discrete_action_logit, one_discrete_target in zip(discrete_action_logits, discrete_targets.unbind(dim = -1)):
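To see why a single `gather` at the end recovers the per-action log probs, here is a minimal sketch of the offset indexing in isolation, using an ordinary per-head `log_softmax` in place of the nested-tensor softmax (all names, shapes and values are illustrative, not from the repo):

    import torch
    import torch.nn.functional as F

    batch = 2
    num_actions = torch.tensor([4, 2, 5])  # classes per discrete action head

    logits = [torch.randn(batch, n) for n in num_actions.tolist()]
    targets = torch.stack([torch.randint(0, n, (batch,)) for n in num_actions.tolist()], dim = -1)

    # reference: per-head log softmax + gather, as in the sequential branch
    ref = torch.stack([
        F.log_softmax(l, dim = -1).gather(-1, t.unsqueeze(-1)).squeeze(-1)
        for l, t in zip(logits, targets.unbind(dim = -1))
    ], dim = -1)

    # parallel: per-head log probs concatenated, then one gather
    log_probs = torch.cat([F.log_softmax(l, dim = -1) for l in logits], dim = -1)  # (batch, 4 + 2 + 5)

    # cumsum padded on the left gives each head's starting column: (0, 4, 6)
    offset = F.pad(num_actions.cumsum(dim = -1), (1, -1), value = 0)
    indices = targets + offset  # shift each target into its head's slice

    out = log_probs.gather(-1, indices)
    assert torch.allclose(ref, out, atol = 1e-6)

The nested tensor is what makes the normalization itself parallel: the concatenated logits cannot be softmaxed jointly, since each head must normalize over only its own classes, and the jagged layout does exactly that in one call before the logs are gathered.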
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.14"
+version = "0.0.15"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -296,8 +296,18 @@ def test_action_embedder():
     discrete_log_probs, continuous_log_probs = embedder.log_probs(
         action_embed,
         discrete_targets = discrete_actions,
-        continuous_targets = continuous_actions
+        continuous_targets = continuous_actions,
+        parallel_discrete_calc = False
     )

     assert discrete_log_probs.shape == (2, 3, 2)
     assert continuous_log_probs.shape == (2, 3, 2)

+    parallel_discrete_log_probs, _ = embedder.log_probs(
+        action_embed,
+        discrete_targets = discrete_actions,
+        continuous_targets = continuous_actions,
+        parallel_discrete_calc = True
+    )
+
+    assert torch.allclose(discrete_log_probs, parallel_discrete_log_probs, atol = 1e-5)
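The equivalence assert uses a small tolerance because the two branches reach the same values through different kernels (a `log` of a nested-tensor softmax versus per-head computation), so exact bitwise equality is not guaranteed.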