complete multi token prediction for the reward head

lucidrains 2025-10-18 08:33:06 -07:00
parent 911a1a8434
commit 8c88a33d3b
2 changed files with 31 additions and 7 deletions

View File

@@ -20,6 +20,7 @@ from torchvision.models import VGG16_Weights
 from torch.optim import Optimizer
 from adam_atan2_pytorch import MuonAdamAtan2
+from x_mlps_pytorch.ensemble import Ensemble
 from x_mlps_pytorch.normed_mlp import create_mlp
 from hyper_connections import get_init_and_expand_reduce_stream_functions
@@ -40,6 +41,7 @@ from assoc_scan import AssocScan
 # g - groups of query heads to key heads (gqa)
 # vc - video channels
 # vh, vw - video height and width
+# mtp - multi token prediction length

 import einx
 from einx import add, multiply
@@ -175,7 +177,7 @@ def softclamp(t, value = 50.):

 def create_multi_token_prediction_targets(
     t,            # (b t ...)
-    steps_future
+    steps_future,
 ): # (b t-1 steps ...), (b t-1 steps) - targets and the mask, where mask is False for padding
@@ -193,7 +195,9 @@ def create_multi_token_prediction_targets(
     indices[~mask] = 0
     mask = repeat(mask, 't steps -> b t steps', b = batch)

-    return t[batch_arange, indices], mask
+    out = t[batch_arange, indices]
+
+    return out, mask

 # loss related
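
A self-contained sketch of what this helper does may make the change above easier to follow. Everything outside the names that appear in the diff (the loop-free index construction, the toy shapes in the usage lines) is an assumption about the implementation, not the repository's exact code:

```python
# hypothetical sketch of the target-building helper, assuming the
# (b t ...) -> (b t-1 steps ...) shapes documented in the diff

import torch
from einops import repeat

def create_multi_token_prediction_targets(t, steps_future):
    batch, seq_len = t.shape[0], t.shape[1]

    # for every source position, index the next `steps_future` positions
    pos = torch.arange(seq_len - 1)[:, None]              # (t-1, 1)
    offsets = torch.arange(steps_future)[None, :] + 1     # (1, steps)
    indices = pos + offsets                               # (t-1, steps)

    # mask is False where the future step falls past the end of the sequence
    mask = indices < seq_len
    indices[~mask] = 0                                    # padded entries point at a valid dummy index

    mask = repeat(mask, 't steps -> b t steps', b = batch)

    batch_arange = torch.arange(batch)[:, None, None]     # (b, 1, 1)
    out = t[batch_arange, indices]                        # (b, t-1, steps, ...)
    return out, mask

# usage: two-hot targets of shape (b, t, num_bins) become (b, t-1, steps, num_bins)
targets, mask = create_multi_token_prediction_targets(torch.randn(2, 10, 5), steps_future = 3)
print(targets.shape, mask.shape)  # torch.Size([2, 9, 3, 5]) torch.Size([2, 9, 3])
```
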
@@ -1515,6 +1519,7 @@ class DynamicsWorldModel(Module):
         num_continuous_actions = 0,
         continuous_norm_stats = None,
         reward_loss_weight = 0.1,
+        multi_token_pred_len = 8,
         value_head_mlp_depth = 3,
         policy_head_mlp_depth = 3,
         behavior_clone_weight = 0.1,
@@ -1642,6 +1647,10 @@ class DynamicsWorldModel(Module):
         self.behavior_clone_weight = behavior_clone_weight

+        # multi token prediction length
+
+        self.multi_token_pred_len = multi_token_pred_len
+
         # each agent token will have the reward embedding of the previous time step - but could eventually just give reward its own token
         self.add_reward_embed_to_agent_token = add_reward_embed_to_agent_token
@@ -1653,11 +1662,16 @@ class DynamicsWorldModel(Module):
             learned_embedding = add_reward_embed_to_agent_token
         )

-        self.to_reward_pred = Sequential(
+        to_reward_pred = Sequential(
             RMSNorm(dim),
             LinearNoBias(dim, self.reward_encoder.num_bins)
         )

+        self.to_reward_pred = Ensemble(
+            to_reward_pred,
+            multi_token_pred_len
+        )
+
         self.reward_loss_weight = reward_loss_weight

         # value head
@@ -1944,7 +1958,7 @@ class DynamicsWorldModel(Module):
            if return_rewards_per_frame:
                one_agent_embed = agent_embed[:, -1:, agent_index]

-                reward_logits = self.to_reward_pred(one_agent_embed)
+                reward_logits = self.to_reward_pred.forward_one(one_agent_embed, id = 0)
                pred_reward = self.reward_encoder.bins_to_scalar_value(reward_logits, normalize = True)

                decoded_rewards = cat((decoded_rewards, pred_reward), dim = 1)
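
Taken together, the two hunks above wrap the single reward head in an ensemble of `multi_token_pred_len` heads: training runs all of them (the leading `mtp` dimension in the loss hunk further down), while rollout calls `forward_one(..., id = 0)` to evaluate only the next-step head. The `Ensemble` from `x_mlps_pytorch` is not shown in this commit; the toy class below is only an assumption that mirrors the behavior the diff relies on, not that library's implementation:

```python
# toy stand-in for how the diff uses Ensemble: `forward` stacks every member's output
# along a new leading dim, `forward_one` runs a single member (assumed behavior)

import copy
import torch
from torch import nn

class ToyEnsemble(nn.Module):
    def __init__(self, net: nn.Module, num_members: int):
        super().__init__()
        self.nets = nn.ModuleList([copy.deepcopy(net) for _ in range(num_members)])

    def forward(self, x):
        # (mtp, b, t, logits) - one set of reward logits per future step
        return torch.stack([net(x) for net in self.nets], dim = 0)

    def forward_one(self, x, id = 0):
        # only the head for offset `id` - the commit uses id = 0 (next step) when rolling out
        return self.nets[id](x)

# LayerNorm and 41 bins are placeholders for the RMSNorm reward head in the diff
reward_head = nn.Sequential(nn.LayerNorm(16), nn.Linear(16, 41, bias = False))
to_reward_pred = ToyEnsemble(reward_head, num_members = 8)

agent_embed = torch.randn(2, 10, 16)
print(to_reward_pred(agent_embed).shape)              # torch.Size([8, 2, 10, 41])
print(to_reward_pred.forward_one(agent_embed).shape)  # torch.Size([2, 10, 41])
```
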
@@ -2380,9 +2394,19 @@ class DynamicsWorldModel(Module):
            if rewards.ndim == 2: # (b t)
                encoded_agent_tokens = reduce(encoded_agent_tokens, 'b t g d -> b t d', 'mean')

-            reward_pred = self.to_reward_pred(encoded_agent_tokens)
-            reward_loss = F.cross_entropy(reward_pred, two_hot_encoding)
+            reward_pred = self.to_reward_pred(encoded_agent_tokens[:, :-1])
+            reward_pred = rearrange(reward_pred, 'mtp b t l -> b l t mtp')
+
+            reward_targets, reward_loss_mask = create_multi_token_prediction_targets(two_hot_encoding, self.multi_token_pred_len)
+            reward_targets = rearrange(reward_targets, 'b t mtp l -> b l t mtp')
+
+            reward_losses = F.cross_entropy(reward_pred, reward_targets, reduction = 'none')
+            reward_losses = reward_losses.masked_fill(reward_loss_mask, 0.)
+
+            reward_loss = reward_losses.sum(dim = -1).mean() # they sum across the prediction steps - eq(9)

        # maybe autoregressive action loss
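
For reference, a standalone sketch of the masked multi-token reward loss introduced above, reusing the `create_multi_token_prediction_targets` sketch from earlier in this section. The shapes and the two-hot stand-in are assumptions; per the helper's docstring the mask is `False` on padded future steps, so this sketch zeroes those entries via `~mask` before summing over prediction steps and averaging (the eq (9) reduction the inline comment refers to):

```python
# standalone shape check of the masked multi-token reward loss (an assumption built
# from the hunks above, not a drop-in for the model's forward)

import torch
import torch.nn.functional as F
from einops import rearrange

b, t, num_bins, mtp = 2, 10, 41, 8

reward_pred = torch.randn(mtp, b, t - 1, num_bins)                 # ensemble output on tokens[:, :-1]
two_hot_encoding = torch.softmax(torch.randn(b, t, num_bins), -1)  # stand-in for the two-hot reward targets

reward_targets, mask = create_multi_token_prediction_targets(two_hot_encoding, mtp)  # (b, t-1, mtp, bins), (b, t-1, mtp)

# move the bin dim to where F.cross_entropy expects the class dim: (b, bins, t-1, mtp)
reward_pred = rearrange(reward_pred, 'mtp b t l -> b l t mtp')
reward_targets = rearrange(reward_targets, 'b t mtp l -> b l t mtp')

reward_losses = F.cross_entropy(reward_pred, reward_targets, reduction = 'none')  # (b, t-1, mtp)

# mask is False on padded future steps, so zero those entries before reducing
reward_losses = reward_losses.masked_fill(~mask, 0.)

# sum across prediction steps, then average over batch and time
reward_loss = reward_losses.sum(dim = -1).mean()
print(reward_loss.shape)  # torch.Size([])
```
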

View File

@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.34"
+version = "0.0.35"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }