From b6aa19f31e3703e40fbdc9c6a1ce0ab2a49caac5 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Sat, 18 Oct 2025 10:23:14 -0700
Subject: [PATCH] complete multi-token prediction for actions, tackle loss
 balancing another day

---
 dreamer4/dreamer4.py  | 97 ++++++++++++++++++++++++++++++++++++-------
 pyproject.toml        |  2 +-
 tests/test_dreamer.py | 22 ++++++++++
 3 files changed, 105 insertions(+), 16 deletions(-)

diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index 674e7cb..0d5f30c 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -124,6 +124,15 @@ def safe_cat(tensors, dim):
     return cat(tensors, dim = dim)
 
+def safe_squeeze_first(t):
+    if not exists(t):
+        return None
+
+    if t.shape[0] != 1:
+        return t
+
+    return rearrange(t, '1 ... -> ...')
+
 def gumbel_noise(t):
     noise = torch.rand_like(t)
     return -log(-log(noise))
@@ -365,7 +374,9 @@ class ActionEmbedder(Module):
         num_continuous_actions = 0,
         continuous_norm_stats: tuple[tuple[float, float], ...] | None = None,
         can_unembed = False,
-        unembed_dim = None
+        unembed_dim = None,
+        num_unembed_preds = 1,
+        squeeze_unembed_preds = True # will auto-squeeze if prediction is just 1
     ):
         super().__init__()
@@ -403,11 +414,14 @@ class ActionEmbedder(Module):
         self.can_unembed = can_unembed
 
+        self.num_unembed_preds = num_unembed_preds
+        self.squeeze_unembed_preds = squeeze_unembed_preds
+
         if not can_unembed:
             return
 
         unembed_dim = default(unembed_dim, dim)
 
-        self.discrete_action_unembed = Parameter(torch.randn(total_discrete_actions, unembed_dim) * 1e-2)
+        self.discrete_action_unembed = Parameter(torch.randn(total_discrete_actions, num_unembed_preds, unembed_dim) * 1e-2)
 
         discrete_action_index = arange(total_discrete_actions)
@@ -421,7 +435,7 @@ class ActionEmbedder(Module):
 
         self.register_buffer('discrete_action_mask', discrete_action_mask, persistent = False)
 
-        self.continuous_action_unembed = Parameter(torch.randn(num_continuous_actions, unembed_dim, 2) * 1e-2)
+        self.continuous_action_unembed = Parameter(torch.randn(num_continuous_actions, num_unembed_preds, unembed_dim, 2) * 1e-2)
 
     def embed_parameters(self):
         return set([*self.discrete_action_embed.parameters(), *self.continuous_action_embed.parameters()])
@@ -454,12 +468,26 @@ class ActionEmbedder(Module):
         embeds, # (... d)
         discrete_action_types = None, # (na)
         continuous_action_types = None, # (na)
-        return_split_discrete = False
+        return_split_discrete = False,
+        pred_head_index: int | Tensor | None = None
     ): # (... discrete_na), (... continuous_na 2)
 
+        device = embeds.device
+
         assert self.can_unembed, 'can only unembed for predicted discrete and continuous actions if `can_unembed = True` is set on init'
 
+        # handle only one prediction head during inference
+
+        if exists(pred_head_index) and isinstance(pred_head_index, int):
+            pred_head_index = tensor(pred_head_index, device = device)
+
+        # if pred_head_index given as a solo int, just assume we want to squeeze out the prediction head dimension
+
+        squeeze_one_pred_head = exists(pred_head_index) and pred_head_index.ndim == 0
+
+        # get action types
+
         discrete_action_types, continuous_action_types = tuple(self.cast_action_types(t) for t in (discrete_action_types, continuous_action_types))
 
         # discrete actions
@@ -475,7 +503,13 @@ class ActionEmbedder(Module):
 
             discrete_action_unembed = discrete_action_unembed[discrete_action_mask]
 
-        discrete_action_logits = einsum(embeds, discrete_action_unembed, '... d, na d -> ... na')
+        if exists(pred_head_index):
+            discrete_action_unembed = discrete_action_unembed.index_select(1, pred_head_index)
+
+        discrete_action_logits = einsum(embeds, discrete_action_unembed, '... d, na mtp d -> mtp ... na')
+
+        if self.squeeze_unembed_preds or squeeze_one_pred_head:
+            discrete_action_logits = safe_squeeze_first(discrete_action_logits)
 
         # whether to split the discrete action logits by the number of actions per action type
@@ -496,7 +530,15 @@ class ActionEmbedder(Module):
         if exists(continuous_action_types):
             continuous_action_unembed = continuous_action_unembed[continuous_action_types]
 
-        continuous_action_mean_log_var = einsum(embeds, continuous_action_unembed, '... d, na d two -> ... na two')
+        if exists(pred_head_index):
+            # pred_head_index was already cast to a tensor above, so index_select covers both int and tensor inputs
+            continuous_action_unembed = continuous_action_unembed.index_select(1, pred_head_index)
+
+        continuous_action_mean_log_var = einsum(embeds, continuous_action_unembed, '... d, na mtp d two -> mtp ... na two')
+
+        if self.squeeze_unembed_preds or squeeze_one_pred_head:
+            continuous_action_mean_log_var = safe_squeeze_first(continuous_action_mean_log_var)
 
         return discrete_action_logits, continuous_action_mean_log_var
 
@@ -506,11 +548,13 @@ class ActionEmbedder(Module):
         discrete_temperature = 1.,
         continuous_temperature = 1.,
         inverse_norm_continuous = None,
+        pred_head_index: int | Tensor | None = None,
+        squeeze = True,
         **kwargs
     ):
         inverse_norm_continuous = default(inverse_norm_continuous, self.continuous_need_norm)
 
-        discrete_logits, continuous_mean_log_var = self.unembed(embed, return_split_discrete = True, **kwargs)
+        discrete_logits, continuous_mean_log_var = self.unembed(embed, return_split_discrete = True, pred_head_index = pred_head_index, **kwargs)
 
         sampled_discrete = sampled_continuous = None
 
@@ -543,12 +587,13 @@ class ActionEmbedder(Module):
         continuous_targets = None, # (... na)
         discrete_action_types = None, # (na)
         continuous_action_types = None, # (na)
+        pred_head_index: int | Tensor | None = None,
         parallel_discrete_calc = None,
         return_entropies = False
     ):
         parallel_discrete_calc = default(parallel_discrete_calc, exists(discrete_targets) and discrete_targets.shape[-1] > 1)
 
-        discrete_action_logits, continuous_action_mean_log_var = self.unembed(embeds, discrete_action_types = discrete_action_types, continuous_action_types = continuous_action_types, return_split_discrete = True)
+        discrete_action_logits, continuous_action_mean_log_var = self.unembed(embeds, pred_head_index = pred_head_index, discrete_action_types = discrete_action_types, continuous_action_types = continuous_action_types, return_split_discrete = True)
 
         # discrete
 
@@ -1651,7 +1696,9 @@ class DynamicsWorldModel(Module):
             num_continuous_actions = num_continuous_actions,
             continuous_norm_stats = continuous_norm_stats,
             can_unembed = True,
-            unembed_dim = dim * 4
+            unembed_dim = dim * 4,
+            num_unembed_preds = multi_token_pred_len,
+            squeeze_unembed_preds = False
         )
 
         self.behavior_clone_weight = behavior_clone_weight
@@ -1790,7 +1837,7 @@ class DynamicsWorldModel(Module):
 
         policy_embed = self.policy_head(agent_embed)
 
-        log_probs, entropies = self.action_embedder.log_probs(policy_embed, discrete_targets = discrete_actions, continuous_targets = continuous_actions, return_entropies = True)
+        log_probs, entropies = self.action_embedder.log_probs(policy_embed, pred_head_index = 0, discrete_targets = discrete_actions, continuous_targets = continuous_actions, return_entropies = True)
 
         # concat discrete and continuous actions into one for optimizing
@@ -1981,7 +2028,7 @@ class DynamicsWorldModel(Module):
 
             policy_embed = self.policy_head(one_agent_embed)
 
-            sampled_discrete_actions, sampled_continuous_actions = self.action_embedder.sample(policy_embed)
+            sampled_discrete_actions, sampled_continuous_actions = self.action_embedder.sample(policy_embed, pred_head_index = 0, squeeze = True)
 
             decoded_discrete_actions = safe_cat((decoded_discrete_actions, sampled_discrete_actions), dim = 1)
             decoded_continuous_actions = safe_cat((decoded_continuous_actions, sampled_continuous_actions), dim = 1)
@@ -1989,6 +2036,7 @@ class DynamicsWorldModel(Module):
             if return_log_probs_and_values:
                 discrete_log_probs, continuous_log_probs = self.action_embedder.log_probs(
                     policy_embed,
+                    pred_head_index = 0,
                     discrete_targets = sampled_discrete_actions,
                     continuous_targets = sampled_continuous_actions,
                 )
@@ -2424,6 +2472,7 @@ class DynamicsWorldModel(Module):
         if (
             self.num_agents == 1 and
             add_autoregressive_action_loss and
+            time > 1 and
             (exists(discrete_actions) or exists(continuous_actions))
         ):
             assert self.action_embedder.has_actions
@@ -2433,17 +2482,35 @@ class DynamicsWorldModel(Module):
             agent_tokens = rearrange(agent_tokens, 'b t 1 d -> b t d')
 
             policy_embed = self.policy_head(agent_tokens[:, :-1])
 
+            # constitute multi token prediction targets
+
+            discrete_action_targets = continuous_action_targets = None
+
+            if exists(discrete_actions):
+                discrete_action_targets, discrete_mask = create_multi_token_prediction_targets(discrete_actions, self.multi_token_pred_len)
+                discrete_action_targets = rearrange(discrete_action_targets, 'b t mtp ... -> mtp b t ...')
+                discrete_mask = rearrange(discrete_mask, 'b t mtp -> mtp b t')
+
+            if exists(continuous_actions):
+                continuous_action_targets, continuous_mask = create_multi_token_prediction_targets(continuous_actions, self.multi_token_pred_len)
+                continuous_action_targets = rearrange(continuous_action_targets, 'b t mtp ... -> mtp b t ...')
+                continuous_mask = rearrange(continuous_mask, 'b t mtp -> mtp b t')
+
             discrete_log_probs, continuous_log_probs = self.action_embedder.log_probs(
                 policy_embed,
-                discrete_targets = discrete_actions[:, 1:] if exists(discrete_actions) else None,
-                continuous_targets = continuous_actions[:, 1:] if exists(continuous_actions) else None
+                discrete_targets = discrete_action_targets if exists(discrete_actions) else None,
+                continuous_targets = continuous_action_targets if exists(continuous_actions) else None
             )
 
             if exists(discrete_log_probs):
-                behavior_clone_loss = behavior_clone_loss - discrete_log_probs.sum(dim = -1).mean()
+                discrete_log_probs = discrete_log_probs.masked_fill(discrete_mask[..., None], 0.)
+
+                behavior_clone_loss = behavior_clone_loss - reduce(discrete_log_probs, 'mtp b t na -> b t', 'sum').mean()
 
             if exists(continuous_log_probs):
-                behavior_clone_loss = behavior_clone_loss - continuous_log_probs.sum(dim = -1).mean()
+                continuous_log_probs = continuous_log_probs.masked_fill(continuous_mask[..., None], 0.)
+
+                behavior_clone_loss = behavior_clone_loss - reduce(continuous_log_probs, 'mtp b t na -> b t', 'sum').mean()
 
         # gather losses
diff --git a/pyproject.toml b/pyproject.toml
index bc67a5f..9e4571c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.37"
+version = "0.0.38"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py
index 943dcbe..e1c996c 100644
--- a/tests/test_dreamer.py
+++ b/tests/test_dreamer.py
@@ -401,3 +401,25 @@ def test_mtp():
 
     assert action_targets.shape == (3, 15, 3, 2)
     assert mask.shape == (3, 15, 3)
+
+    from dreamer4.dreamer4 import ActionEmbedder
+
+    embedder = ActionEmbedder(
+        512,
+        num_discrete_actions = (4, 4),
+        num_continuous_actions = 2,
+        can_unembed = True,
+        num_unembed_preds = 8
+    )
+
+    discrete_actions = torch.randint(0, 4, (2, 3, 2))
+    continuous_actions = torch.randn(2, 3, 2)
+
+    action_embed = torch.randn(2, 16, 512)
+    discrete_logits, continuous_logits = embedder.unembed(action_embed)
+
+    assert discrete_logits.shape == (8, 2, 16, 8)
+
+    discrete_logits, continuous_logits = embedder.unembed(action_embed, pred_head_index = 0)
+
+    assert discrete_logits.shape == (2, 16, 8)
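
As a rough usage sketch of the new prediction heads (not part of the patch; it assumes the patch above is applied, with shapes following the new test in tests/test_dreamer.py): during training all heads are unembedded at once with a leading mtp dimension, while at inference a single head is selected with pred_head_index, which squeezes that dimension away.

    import torch
    from dreamer4.dreamer4 import ActionEmbedder

    embedder = ActionEmbedder(
        512,
        num_discrete_actions = (4, 4),   # two discrete action types with 4 choices each
        num_continuous_actions = 2,
        can_unembed = True,
        num_unembed_preds = 4,           # predict the next 4 actions per position
        squeeze_unembed_preds = False
    )

    policy_embed = torch.randn(2, 16, 512)

    # training - logits for all prediction heads, (mtp, batch, time, total discrete actions)
    discrete_logits, continuous_mean_log_var = embedder.unembed(policy_embed)
    assert discrete_logits.shape == (4, 2, 16, 8)

    # inference - keep only the next-action head, the mtp dimension is squeezed out
    sampled_discrete, sampled_continuous = embedder.sample(policy_embed, pred_head_index = 0, squeeze = True)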
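
The behavior cloning hunk also leans on create_multi_token_prediction_targets, which is not shown in this diff. Below is a minimal sketch of what such a helper might do, assuming the returned mask is True where a target runs past the end of the sequence (consistent with how masked_fill is used above). Only the shapes and that mask convention are taken from the patch; the function name and its internals are hypothetical. For actions of shape (b, t, na) it yields targets of shape (b, t - 1, mtp, na) and a mask of shape (b, t - 1, mtp), matching the assertions in test_mtp.

    import torch

    def multi_token_prediction_targets_sketch(actions, mtp_len):
        # actions: (b, t, na) -> targets: (b, t - 1, mtp, na), mask: (b, t - 1, mtp)
        b, t = actions.shape[0], actions.shape[1]

        # for each source position i in 0 .. t - 2, the targets are the actions at i + 1 .. i + mtp_len
        positions = torch.arange(t - 1)
        offsets = torch.arange(mtp_len) + 1
        index = positions[:, None] + offsets[None, :]      # (t - 1, mtp)

        out_of_range = index >= t                          # True where the target would run past the sequence
        index = index.clamp(max = t - 1)

        targets = actions[:, index]                        # (b, t - 1, mtp, na)
        mask = out_of_range[None].expand(b, -1, -1)        # (b, t - 1, mtp)

        return targets, mask

Under that convention, the masked_fill(mask, 0.) calls in the loss zero out the log probabilities of targets that do not exist, so prediction heads that look past the end of a trajectory do not contribute to the behavior cloning loss.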