complete multi-token prediction for actions, tackle loss balancing another day

This commit is contained in:
lucidrains 2025-10-18 10:23:14 -07:00
parent bc629d78b1
commit b6aa19f31e
3 changed files with 105 additions and 16 deletions

View File

@ -124,6 +124,15 @@ def safe_cat(tensors, dim):
return cat(tensors, dim = dim)
def safe_squeeze_first(t):
if not exists(t):
return None
if t.shape[0] != 1:
return t
return rearrange(t, '1 ... -> ...')
def gumbel_noise(t):
noise = torch.rand_like(t)
return -log(-log(noise))
@ -365,7 +374,9 @@ class ActionEmbedder(Module):
num_continuous_actions = 0,
continuous_norm_stats: tuple[tuple[float, float], ...] | None = None,
can_unembed = False,
unembed_dim = None
unembed_dim = None,
num_unembed_preds = 1,
squeeze_unembed_preds = True # will auto-squeeze if prediction is just 1
):
super().__init__()
@ -403,11 +414,14 @@ class ActionEmbedder(Module):
self.can_unembed = can_unembed
self.num_unembed_preds = num_unembed_preds
self.squeeze_unembed_preds = squeeze_unembed_preds
if not can_unembed:
return
unembed_dim = default(unembed_dim, dim)
self.discrete_action_unembed = Parameter(torch.randn(total_discrete_actions, unembed_dim) * 1e-2)
self.discrete_action_unembed = Parameter(torch.randn(total_discrete_actions, num_unembed_preds, unembed_dim) * 1e-2)
discrete_action_index = arange(total_discrete_actions)
@ -421,7 +435,7 @@ class ActionEmbedder(Module):
self.register_buffer('discrete_action_mask', discrete_action_mask, persistent = False)
self.continuous_action_unembed = Parameter(torch.randn(num_continuous_actions, unembed_dim, 2) * 1e-2)
self.continuous_action_unembed = Parameter(torch.randn(num_continuous_actions, num_unembed_preds, unembed_dim, 2) * 1e-2)
def embed_parameters(self):
return set([*self.discrete_action_embed.parameters(), *self.continuous_action_embed.parameters()])
@ -454,12 +468,26 @@ class ActionEmbedder(Module):
embeds, # (... d)
discrete_action_types = None, # (na)
continuous_action_types = None, # (na)
return_split_discrete = False
return_split_discrete = False,
pred_head_index: int | Tensor | None = None
): # (... discrete_na), (... continuous_na 2)
device = embeds.device
assert self.can_unembed, 'can only unembed for predicted discrete and continuous actions if `can_unembed = True` is set on init'
# handle only one prediction head during inference
if exists(pred_head_index) and isinstance(pred_head_index, int):
pred_head_index = tensor(pred_head_index, device = device)
# if pred_head_index given as a solo int, just assume we want to squeeze out the prediction head dimension
squeeze_one_pred_head = exists(pred_head_index) and pred_head_index.ndim == 0
# get action types
discrete_action_types, continuous_action_types = tuple(self.cast_action_types(t) for t in (discrete_action_types, continuous_action_types))
# discrete actions
@ -475,7 +503,13 @@ class ActionEmbedder(Module):
discrete_action_unembed = discrete_action_unembed[discrete_action_mask]
discrete_action_logits = einsum(embeds, discrete_action_unembed, '... d, na d -> ... na')
if exists(pred_head_index):
discrete_action_unembed = discrete_action_unembed.index_select(1, pred_head_index)
discrete_action_logits = einsum(embeds, discrete_action_unembed, '... d, na mtp d -> mtp ... na')
if self.squeeze_unembed_preds or squeeze_one_pred_head:
discrete_action_logits = safe_squeeze_first(discrete_action_logits)
# whether to split the discrete action logits by the number of actions per action type
@ -496,7 +530,15 @@ class ActionEmbedder(Module):
if exists(continuous_action_types):
continuous_action_unembed = continuous_action_unembed[continuous_action_types]
continuous_action_mean_log_var = einsum(embeds, continuous_action_unembed, '... d, na d two -> ... na two')
if isinstance(pred_head_index, int):
continuous_action_unembed = continuous_action_unembed[:, pred_head_index]
else:
continuous_action_unembed = continuous_action_unembed.index_select(1, pred_head_index)
continuous_action_mean_log_var = einsum(embeds, continuous_action_unembed, '... d, na mtp d two -> mtp ... na two')
if self.squeeze_unembed_preds or squeeze_one_pred_head:
continuous_action_mean_log_var = safe_squeeze_first(continuous_action_mean_log_var)
return discrete_action_logits, continuous_action_mean_log_var
@ -506,11 +548,13 @@ class ActionEmbedder(Module):
discrete_temperature = 1.,
continuous_temperature = 1.,
inverse_norm_continuous = None,
pred_head_index: int | Tensor | None = None,
squeeze = True,
**kwargs
):
inverse_norm_continuous = default(inverse_norm_continuous, self.continuous_need_norm)
discrete_logits, continuous_mean_log_var = self.unembed(embed, return_split_discrete = True, **kwargs)
discrete_logits, continuous_mean_log_var = self.unembed(embed, return_split_discrete = True, pred_head_index = pred_head_index, **kwargs)
sampled_discrete = sampled_continuous = None
@ -543,12 +587,13 @@ class ActionEmbedder(Module):
continuous_targets = None, # (... na)
discrete_action_types = None, # (na)
continuous_action_types = None, # (na)
pred_head_index: int | Tensor | None = None,
parallel_discrete_calc = None,
return_entropies = False
):
parallel_discrete_calc = default(parallel_discrete_calc, exists(discrete_targets) and discrete_targets.shape[-1] > 1)
discrete_action_logits, continuous_action_mean_log_var = self.unembed(embeds, discrete_action_types = discrete_action_types, continuous_action_types = continuous_action_types, return_split_discrete = True)
discrete_action_logits, continuous_action_mean_log_var = self.unembed(embeds, pred_head_index = pred_head_index, discrete_action_types = discrete_action_types, continuous_action_types = continuous_action_types, return_split_discrete = True)
# discrete
@ -1651,7 +1696,9 @@ class DynamicsWorldModel(Module):
num_continuous_actions = num_continuous_actions,
continuous_norm_stats = continuous_norm_stats,
can_unembed = True,
unembed_dim = dim * 4
unembed_dim = dim * 4,
num_unembed_preds = multi_token_pred_len,
squeeze_unembed_preds = False
)
self.behavior_clone_weight = behavior_clone_weight
@ -1790,7 +1837,7 @@ class DynamicsWorldModel(Module):
policy_embed = self.policy_head(agent_embed)
log_probs, entropies = self.action_embedder.log_probs(policy_embed, discrete_targets = discrete_actions, continuous_targets = continuous_actions, return_entropies = True)
log_probs, entropies = self.action_embedder.log_probs(policy_embed, pred_head_index = 0, discrete_targets = discrete_actions, continuous_targets = continuous_actions, return_entropies = True)
# concat discrete and continuous actions into one for optimizing
@ -1981,7 +2028,7 @@ class DynamicsWorldModel(Module):
policy_embed = self.policy_head(one_agent_embed)
sampled_discrete_actions, sampled_continuous_actions = self.action_embedder.sample(policy_embed)
sampled_discrete_actions, sampled_continuous_actions = self.action_embedder.sample(policy_embed, pred_head_index = 0, squeeze = True)
decoded_discrete_actions = safe_cat((decoded_discrete_actions, sampled_discrete_actions), dim = 1)
decoded_continuous_actions = safe_cat((decoded_continuous_actions, sampled_continuous_actions), dim = 1)
@ -1989,6 +2036,7 @@ class DynamicsWorldModel(Module):
if return_log_probs_and_values:
discrete_log_probs, continuous_log_probs = self.action_embedder.log_probs(
policy_embed,
pred_head_index = 0,
discrete_targets = sampled_discrete_actions,
continuous_targets = sampled_continuous_actions,
)
@ -2424,6 +2472,7 @@ class DynamicsWorldModel(Module):
if (
self.num_agents == 1 and
add_autoregressive_action_loss and
time > 1,
(exists(discrete_actions) or exists(continuous_actions))
):
assert self.action_embedder.has_actions
@ -2433,17 +2482,35 @@ class DynamicsWorldModel(Module):
agent_tokens = rearrange(agent_tokens, 'b t 1 d -> b t d')
policy_embed = self.policy_head(agent_tokens[:, :-1])
# constitute multi token prediction targets
discrete_action_targets = continuous_action_targets = None
if exists(discrete_actions):
discrete_action_targets, discrete_mask = create_multi_token_prediction_targets(discrete_actions, self.multi_token_pred_len)
discrete_action_targets = rearrange(discrete_action_targets, 'b t mtp ... -> mtp b t ...')
discrete_mask = rearrange(discrete_mask, 'b t mtp -> mtp b t')
if exists(continuous_actions):
continuous_action_targets, continuous_mask = create_multi_token_prediction_targets(discrete_actions, self.multi_token_pred_len)
continuous_action_targets = rearrange(continuous_action_targets, 'b t mtp ... -> mtp b t ...')
continuous_mask = rearrange(continuous_mask, 'b t mtp -> mtp b t')
discrete_log_probs, continuous_log_probs = self.action_embedder.log_probs(
policy_embed,
discrete_targets = discrete_actions[:, 1:] if exists(discrete_actions) else None,
continuous_targets = continuous_actions[:, 1:] if exists(continuous_actions) else None
discrete_targets = discrete_action_targets if exists(discrete_actions) else None,
continuous_targets = continuous_action_targets if exists(continuous_actions) else None
)
if exists(discrete_log_probs):
behavior_clone_loss = behavior_clone_loss - discrete_log_probs.sum(dim = -1).mean()
discrete_log_probs = discrete_log_probs.masked_fill(discrete_mask[..., None], 0.)
behavior_clone_loss = behavior_clone_loss - reduce(discrete_log_probs, 'mtp b t na -> b t', 'sum').mean()
if exists(continuous_log_probs):
behavior_clone_loss = behavior_clone_loss - continuous_log_probs.sum(dim = -1).mean()
continuous_log_probs = continuous_log_probs.masked_fill(continuous_mask[..., None], 0.)
behavior_clone_loss = behavior_clone_loss - reduce(continuous_log_probs, 'mtp b t na -> b t', 'sum').mean()
# gather losses

View File

@ -1,6 +1,6 @@
[project]
name = "dreamer4"
version = "0.0.37"
version = "0.0.38"
description = "Dreamer 4"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

View File

@ -401,3 +401,25 @@ def test_mtp():
assert action_targets.shape == (3, 15, 3, 2)
assert mask.shape == (3, 15, 3)
from dreamer4.dreamer4 import ActionEmbedder
embedder = ActionEmbedder(
512,
num_discrete_actions = (4, 4),
num_continuous_actions = 2,
can_unembed = True,
num_unembed_preds = 8
)
discrete_actions = torch.randint(0, 4, (2, 3, 2))
continuous_actions = torch.randn(2, 3, 2)
action_embed = torch.randn(2, 16, 512)
discrete_logits, continuous_logits = embedder.unembed(action_embed)
assert discrete_logits.shape == (8, 2, 16, 8)
discrete_logits, continuous_logits = embedder.unembed(action_embed, pred_head_index = 0)
assert discrete_logits.shape == (2, 16, 8)