complete multi token prediction for the reward head
This commit is contained in:
parent
911a1a8434
commit
8c88a33d3b
@ -20,6 +20,7 @@ from torchvision.models import VGG16_Weights
|
||||
from torch.optim import Optimizer
|
||||
from adam_atan2_pytorch import MuonAdamAtan2
|
||||
|
||||
from x_mlps_pytorch.ensemble import Ensemble
|
||||
from x_mlps_pytorch.normed_mlp import create_mlp
|
||||
|
||||
from hyper_connections import get_init_and_expand_reduce_stream_functions
|
||||
@ -40,6 +41,7 @@ from assoc_scan import AssocScan
|
||||
# g - groups of query heads to key heads (gqa)
|
||||
# vc - video channels
|
||||
# vh, vw - video height and width
|
||||
# mtp - multi token prediction length
|
||||
|
||||
import einx
|
||||
from einx import add, multiply
|
||||
@ -175,7 +177,7 @@ def softclamp(t, value = 50.):
|
||||
|
||||
def create_multi_token_prediction_targets(
|
||||
t, # (b t ...)
|
||||
steps_future
|
||||
steps_future,
|
||||
|
||||
): # (b t-1 steps ...), (b t-1 steps) - targets and the mask, where mask is False for padding
|
||||
|
||||
@ -193,7 +195,9 @@ def create_multi_token_prediction_targets(
|
||||
indices[~mask] = 0
|
||||
mask = repeat(mask, 't steps -> b t steps', b = batch)
|
||||
|
||||
return t[batch_arange, indices], mask
|
||||
out = t[batch_arange, indices]
|
||||
|
||||
return out, mask
|
||||
|
||||
# loss related
|
||||
|
||||
@ -1515,6 +1519,7 @@ class DynamicsWorldModel(Module):
|
||||
num_continuous_actions = 0,
|
||||
continuous_norm_stats = None,
|
||||
reward_loss_weight = 0.1,
|
||||
multi_token_pred_len = 8,
|
||||
value_head_mlp_depth = 3,
|
||||
policy_head_mlp_depth = 3,
|
||||
behavior_clone_weight = 0.1,
|
||||
@ -1642,6 +1647,10 @@ class DynamicsWorldModel(Module):
|
||||
|
||||
self.behavior_clone_weight = behavior_clone_weight
|
||||
|
||||
# multi token prediction length
|
||||
|
||||
self.multi_token_pred_len = multi_token_pred_len
|
||||
|
||||
# each agent token will have the reward embedding of the previous time step - but could eventually just give reward its own token
|
||||
|
||||
self.add_reward_embed_to_agent_token = add_reward_embed_to_agent_token
|
||||
@ -1653,11 +1662,16 @@ class DynamicsWorldModel(Module):
|
||||
learned_embedding = add_reward_embed_to_agent_token
|
||||
)
|
||||
|
||||
self.to_reward_pred = Sequential(
|
||||
to_reward_pred = Sequential(
|
||||
RMSNorm(dim),
|
||||
LinearNoBias(dim, self.reward_encoder.num_bins)
|
||||
)
|
||||
|
||||
self.to_reward_pred = Ensemble(
|
||||
to_reward_pred,
|
||||
multi_token_pred_len
|
||||
)
|
||||
|
||||
self.reward_loss_weight = reward_loss_weight
|
||||
|
||||
# value head
|
||||
@ -1944,7 +1958,7 @@ class DynamicsWorldModel(Module):
|
||||
if return_rewards_per_frame:
|
||||
one_agent_embed = agent_embed[:, -1:, agent_index]
|
||||
|
||||
reward_logits = self.to_reward_pred(one_agent_embed)
|
||||
reward_logits = self.to_reward_pred.forward_one(one_agent_embed, id = 0)
|
||||
pred_reward = self.reward_encoder.bins_to_scalar_value(reward_logits, normalize = True)
|
||||
|
||||
decoded_rewards = cat((decoded_rewards, pred_reward), dim = 1)
|
||||
@ -2380,9 +2394,19 @@ class DynamicsWorldModel(Module):
|
||||
if rewards.ndim == 2: # (b t)
|
||||
encoded_agent_tokens = reduce(encoded_agent_tokens, 'b t g d -> b t d', 'mean')
|
||||
|
||||
reward_pred = self.to_reward_pred(encoded_agent_tokens)
|
||||
reward_pred = self.to_reward_pred(encoded_agent_tokens[:, :-1])
|
||||
|
||||
reward_loss = F.cross_entropy(reward_pred, two_hot_encoding)
|
||||
reward_pred = rearrange(reward_pred, 'mtp b t l -> b l t mtp')
|
||||
|
||||
reward_targets, reward_loss_mask = create_multi_token_prediction_targets(two_hot_encoding, self.multi_token_pred_len)
|
||||
|
||||
reward_targets = rearrange(reward_targets, 'b t mtp l -> b l t mtp')
|
||||
|
||||
reward_losses = F.cross_entropy(reward_pred, reward_targets, reduction = 'none')
|
||||
|
||||
reward_losses = reward_losses.masked_fill(reward_loss_mask, 0.)
|
||||
|
||||
reward_loss = reward_losses.sum(dim = -1).mean() # they sum across the prediction steps - eq(9)
|
||||
|
||||
# maybe autoregressive action loss
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "dreamer4"
|
||||
version = "0.0.34"
|
||||
version = "0.0.35"
|
||||
description = "Dreamer 4"
|
||||
authors = [
|
||||
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user