From 8c88a33d3ba1f9694097ebb1f4521ac5f336af4b Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Sat, 18 Oct 2025 08:33:06 -0700
Subject: [PATCH] complete multi token prediction for the reward head

---
 dreamer4/dreamer4.py | 36 ++++++++++++++++++++++++++++++------
 pyproject.toml       |  2 +-
 2 files changed, 31 insertions(+), 7 deletions(-)

diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index a12d86a..4d4d7dc 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -20,6 +20,7 @@ from torchvision.models import VGG16_Weights
 from torch.optim import Optimizer
 
 from adam_atan2_pytorch import MuonAdamAtan2
+from x_mlps_pytorch.ensemble import Ensemble
 from x_mlps_pytorch.normed_mlp import create_mlp
 
 from hyper_connections import get_init_and_expand_reduce_stream_functions
@@ -40,6 +41,7 @@ from assoc_scan import AssocScan
 # g - groups of query heads to key heads (gqa)
 # vc - video channels
 # vh, vw - video height and width
+# mtp - multi token prediction length
 
 import einx
 from einx import add, multiply
@@ -175,7 +177,7 @@ def softclamp(t, value = 50.):
 
 def create_multi_token_prediction_targets(
     t, # (b t ...)
-    steps_future
+    steps_future,
 ):
     # (b t-1 steps ...), (b t-1 steps) - targets and the mask, where mask is False for padding
 
@@ -193,7 +195,9 @@ def create_multi_token_prediction_targets(
     indices[~mask] = 0
 
     mask = repeat(mask, 't steps -> b t steps', b = batch)
 
-    return t[batch_arange, indices], mask
+    out = t[batch_arange, indices]
+
+    return out, mask
 
 # loss related
@@ -1515,6 +1519,7 @@ class DynamicsWorldModel(Module):
         num_continuous_actions = 0,
         continuous_norm_stats = None,
         reward_loss_weight = 0.1,
+        multi_token_pred_len = 8,
         value_head_mlp_depth = 3,
         policy_head_mlp_depth = 3,
         behavior_clone_weight = 0.1,
@@ -1642,6 +1647,10 @@ class DynamicsWorldModel(Module):
 
         self.behavior_clone_weight = behavior_clone_weight
 
+        # multi token prediction length
+
+        self.multi_token_pred_len = multi_token_pred_len
+
         # each agent token will have the reward embedding of the previous time step - but could eventually just give reward its own token
 
         self.add_reward_embed_to_agent_token = add_reward_embed_to_agent_token
@@ -1653,11 +1662,16 @@ class DynamicsWorldModel(Module):
             learned_embedding = add_reward_embed_to_agent_token
         )
 
-        self.to_reward_pred = Sequential(
+        to_reward_pred = Sequential(
             RMSNorm(dim),
             LinearNoBias(dim, self.reward_encoder.num_bins)
         )
 
+        self.to_reward_pred = Ensemble(
+            to_reward_pred,
+            multi_token_pred_len
+        )
+
         self.reward_loss_weight = reward_loss_weight
 
         # value head
@@ -1944,7 +1958,7 @@ class DynamicsWorldModel(Module):
                 if return_rewards_per_frame:
                     one_agent_embed = agent_embed[:, -1:, agent_index]
-                    reward_logits = self.to_reward_pred(one_agent_embed)
+                    reward_logits = self.to_reward_pred.forward_one(one_agent_embed, id = 0)
 
                     pred_reward = self.reward_encoder.bins_to_scalar_value(reward_logits, normalize = True)
 
                     decoded_rewards = cat((decoded_rewards, pred_reward), dim = 1)
@@ -2380,9 +2394,19 @@ class DynamicsWorldModel(Module):
             if rewards.ndim == 2: # (b t)
                 encoded_agent_tokens = reduce(encoded_agent_tokens, 'b t g d -> b t d', 'mean')
 
-            reward_pred = self.to_reward_pred(encoded_agent_tokens)
+            reward_pred = self.to_reward_pred(encoded_agent_tokens[:, :-1])
 
-            reward_loss = F.cross_entropy(reward_pred, two_hot_encoding)
+            reward_pred = rearrange(reward_pred, 'mtp b t l -> b l t mtp')
+
+            reward_targets, reward_loss_mask = create_multi_token_prediction_targets(two_hot_encoding, self.multi_token_pred_len)
+
+            reward_targets = rearrange(reward_targets, 'b t mtp l -> b l t mtp')
+
+            reward_losses = F.cross_entropy(reward_pred, reward_targets, reduction = 'none')
+
+            reward_losses = reward_losses.masked_fill(~reward_loss_mask, 0.) # mask is False at padding
+
+            reward_loss = reward_losses.sum(dim = -1).mean() # they sum across the prediction steps - eq(9)
 
         # maybe autoregressive action loss
 
diff --git a/pyproject.toml b/pyproject.toml
index eac1e34..d48ab68 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.34"
+version = "0.0.35"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
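
Notes on the change. The diff shows only the tail of create_multi_token_prediction_targets, so the following is a minimal standalone sketch of the contract it documents: input of shape (b, t, ...), targets of shape (b, t-1, steps, ...), plus a mask that is False at padding. The name make_mtp_targets and the initial index construction are assumptions, not the repository's code.

import torch

def make_mtp_targets(t, steps_future):
    # t: (b, seq, ...) -> targets: (b, seq - 1, steps, ...), mask: (b, seq - 1, steps)
    batch, seq = t.shape[0], t.shape[1]

    # source position i gets targets at positions i + 1 .. i + steps_future
    pos = torch.arange(seq - 1)
    offsets = torch.arange(1, steps_future + 1)
    indices = pos[:, None] + offsets[None, :]          # (seq - 1, steps)

    # future positions past the end of the sequence are padding
    mask = indices < seq
    indices = indices.masked_fill(~mask, 0)            # clamp padding to a safe index

    batch_arange = torch.arange(batch)[:, None, None]  # (b, 1, 1), broadcasts against indices
    out = t[batch_arange, indices]                     # (b, seq - 1, steps, ...)

    return out, mask[None].expand(batch, -1, -1)

# toy check on 5 timesteps with 3 future steps
targets, mask = make_mtp_targets(torch.arange(5.)[None], steps_future = 3)
# targets[0] -> [[1, 2, 3], [2, 3, 4], [3, 4, 0], [4, 0, 0]]
# mask[0]    -> [[T, T, T], [T, T, T], [T, T, F], [T, F, F]]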
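The Sequential reward head is wrapped in Ensemble from x_mlps_pytorch, giving one independent norm-plus-projection head per future step: training calls the whole ensemble, while rollout queries only the next-step head via forward_one(..., id = 0). That library's internals are not shown in this patch; a hypothetical stand-in with the same calling convention could look like this (SimpleEnsemble is not the real class):

import copy
import torch
from torch import nn

class SimpleEnsemble(nn.Module):
    # n independent copies of a module; calling the ensemble stacks every member's
    # output on a new leading dimension, forward_one runs a single member
    # (a real ensemble would re-initialize each copy rather than share weights at init)
    def __init__(self, module, n):
        super().__init__()
        self.models = nn.ModuleList([copy.deepcopy(module) for _ in range(n)])

    def forward(self, x):
        return torch.stack([m(x) for m in self.models])  # (n, *single_output_shape)

    def forward_one(self, x, id = 0):
        return self.models[id](x)

Under this convention, self.to_reward_pred(agent_embed) yields logits of shape (mtp, b, t, num_bins) during training, while the sampling path stays as cheap as a single head.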
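The new loss lines then align the ensemble's predictions with the shifted targets and sum the masked cross entropy across prediction steps, per eq (9) cited in the code comment. A toy-shape sketch, reusing the make_mtp_targets sketch above; the softmax "two-hot" tensor is a fabricated stand-in used only to get valid probability targets:

import torch
import torch.nn.functional as F
from einops import rearrange

b, t, mtp, num_bins = 2, 10, 8, 41

# ensemble logits over all but the last frame; random stand-in for the two-hot reward encoding
reward_pred = torch.randn(mtp, b, t - 1, num_bins)
two_hot = torch.randn(b, t, num_bins).softmax(dim = -1)

reward_targets, loss_mask = make_mtp_targets(two_hot, steps_future = mtp)  # (b, t-1, mtp, bins), (b, t-1, mtp)

# F.cross_entropy expects the class dimension right after batch
pred = rearrange(reward_pred, 'mtp b t l -> b l t mtp')
tgt = rearrange(reward_targets, 'b t mtp l -> b l t mtp')

losses = F.cross_entropy(pred, tgt, reduction = 'none')  # (b, t-1, mtp)
losses = losses.masked_fill(~loss_mask, 0.)              # zero out padded future steps
reward_loss = losses.sum(dim = -1).mean()                # sum over prediction steps (eq 9), mean over the rest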