complete multi token prediction for the reward head

This commit is contained in:
lucidrains 2025-10-18 08:33:06 -07:00
parent 911a1a8434
commit 8c88a33d3b
2 changed files with 31 additions and 7 deletions

View File

@ -20,6 +20,7 @@ from torchvision.models import VGG16_Weights
from torch.optim import Optimizer
from adam_atan2_pytorch import MuonAdamAtan2
from x_mlps_pytorch.ensemble import Ensemble
from x_mlps_pytorch.normed_mlp import create_mlp
from hyper_connections import get_init_and_expand_reduce_stream_functions
@ -40,6 +41,7 @@ from assoc_scan import AssocScan
# g - groups of query heads to key heads (gqa)
# vc - video channels
# vh, vw - video height and width
# mtp - multi token prediction length
import einx
from einx import add, multiply
@ -175,7 +177,7 @@ def softclamp(t, value = 50.):
def create_multi_token_prediction_targets(
t, # (b t ...)
steps_future
steps_future,
): # (b t-1 steps ...), (b t-1 steps) - targets and the mask, where mask is False for padding
@ -193,7 +195,9 @@ def create_multi_token_prediction_targets(
indices[~mask] = 0
mask = repeat(mask, 't steps -> b t steps', b = batch)
return t[batch_arange, indices], mask
out = t[batch_arange, indices]
return out, mask
# loss related
@ -1515,6 +1519,7 @@ class DynamicsWorldModel(Module):
num_continuous_actions = 0,
continuous_norm_stats = None,
reward_loss_weight = 0.1,
multi_token_pred_len = 8,
value_head_mlp_depth = 3,
policy_head_mlp_depth = 3,
behavior_clone_weight = 0.1,
@ -1642,6 +1647,10 @@ class DynamicsWorldModel(Module):
self.behavior_clone_weight = behavior_clone_weight
# multi token prediction length
self.multi_token_pred_len = multi_token_pred_len
# each agent token will have the reward embedding of the previous time step - but could eventually just give reward its own token
self.add_reward_embed_to_agent_token = add_reward_embed_to_agent_token
@ -1653,11 +1662,16 @@ class DynamicsWorldModel(Module):
learned_embedding = add_reward_embed_to_agent_token
)
self.to_reward_pred = Sequential(
to_reward_pred = Sequential(
RMSNorm(dim),
LinearNoBias(dim, self.reward_encoder.num_bins)
)
self.to_reward_pred = Ensemble(
to_reward_pred,
multi_token_pred_len
)
self.reward_loss_weight = reward_loss_weight
# value head
@ -1944,7 +1958,7 @@ class DynamicsWorldModel(Module):
if return_rewards_per_frame:
one_agent_embed = agent_embed[:, -1:, agent_index]
reward_logits = self.to_reward_pred(one_agent_embed)
reward_logits = self.to_reward_pred.forward_one(one_agent_embed, id = 0)
pred_reward = self.reward_encoder.bins_to_scalar_value(reward_logits, normalize = True)
decoded_rewards = cat((decoded_rewards, pred_reward), dim = 1)
@ -2380,9 +2394,19 @@ class DynamicsWorldModel(Module):
if rewards.ndim == 2: # (b t)
encoded_agent_tokens = reduce(encoded_agent_tokens, 'b t g d -> b t d', 'mean')
reward_pred = self.to_reward_pred(encoded_agent_tokens)
reward_pred = self.to_reward_pred(encoded_agent_tokens[:, :-1])
reward_loss = F.cross_entropy(reward_pred, two_hot_encoding)
reward_pred = rearrange(reward_pred, 'mtp b t l -> b l t mtp')
reward_targets, reward_loss_mask = create_multi_token_prediction_targets(two_hot_encoding, self.multi_token_pred_len)
reward_targets = rearrange(reward_targets, 'b t mtp l -> b l t mtp')
reward_losses = F.cross_entropy(reward_pred, reward_targets, reduction = 'none')
reward_losses = reward_losses.masked_fill(reward_loss_mask, 0.)
reward_loss = reward_losses.sum(dim = -1).mean() # they sum across the prediction steps - eq(9)
# maybe autoregressive action loss

View File

@ -1,6 +1,6 @@
[project]
name = "dreamer4"
version = "0.0.34"
version = "0.0.35"
description = "Dreamer 4"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }