complete multi token prediction for the reward head
This commit is contained in:
parent
911a1a8434
commit
8c88a33d3b
@ -20,6 +20,7 @@ from torchvision.models import VGG16_Weights
|
|||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
from adam_atan2_pytorch import MuonAdamAtan2
|
from adam_atan2_pytorch import MuonAdamAtan2
|
||||||
|
|
||||||
|
from x_mlps_pytorch.ensemble import Ensemble
|
||||||
from x_mlps_pytorch.normed_mlp import create_mlp
|
from x_mlps_pytorch.normed_mlp import create_mlp
|
||||||
|
|
||||||
from hyper_connections import get_init_and_expand_reduce_stream_functions
|
from hyper_connections import get_init_and_expand_reduce_stream_functions
|
||||||
@ -40,6 +41,7 @@ from assoc_scan import AssocScan
|
|||||||
# g - groups of query heads to key heads (gqa)
|
# g - groups of query heads to key heads (gqa)
|
||||||
# vc - video channels
|
# vc - video channels
|
||||||
# vh, vw - video height and width
|
# vh, vw - video height and width
|
||||||
|
# mtp - multi token prediction length
|
||||||
|
|
||||||
import einx
|
import einx
|
||||||
from einx import add, multiply
|
from einx import add, multiply
|
||||||
@ -175,7 +177,7 @@ def softclamp(t, value = 50.):
|
|||||||
|
|
||||||
def create_multi_token_prediction_targets(
|
def create_multi_token_prediction_targets(
|
||||||
t, # (b t ...)
|
t, # (b t ...)
|
||||||
steps_future
|
steps_future,
|
||||||
|
|
||||||
): # (b t-1 steps ...), (b t-1 steps) - targets and the mask, where mask is False for padding
|
): # (b t-1 steps ...), (b t-1 steps) - targets and the mask, where mask is False for padding
|
||||||
|
|
||||||
@ -193,7 +195,9 @@ def create_multi_token_prediction_targets(
|
|||||||
indices[~mask] = 0
|
indices[~mask] = 0
|
||||||
mask = repeat(mask, 't steps -> b t steps', b = batch)
|
mask = repeat(mask, 't steps -> b t steps', b = batch)
|
||||||
|
|
||||||
return t[batch_arange, indices], mask
|
out = t[batch_arange, indices]
|
||||||
|
|
||||||
|
return out, mask
|
||||||
|
|
||||||
# loss related
|
# loss related
|
||||||
|
|
||||||
@ -1515,6 +1519,7 @@ class DynamicsWorldModel(Module):
|
|||||||
num_continuous_actions = 0,
|
num_continuous_actions = 0,
|
||||||
continuous_norm_stats = None,
|
continuous_norm_stats = None,
|
||||||
reward_loss_weight = 0.1,
|
reward_loss_weight = 0.1,
|
||||||
|
multi_token_pred_len = 8,
|
||||||
value_head_mlp_depth = 3,
|
value_head_mlp_depth = 3,
|
||||||
policy_head_mlp_depth = 3,
|
policy_head_mlp_depth = 3,
|
||||||
behavior_clone_weight = 0.1,
|
behavior_clone_weight = 0.1,
|
||||||
@ -1642,6 +1647,10 @@ class DynamicsWorldModel(Module):
|
|||||||
|
|
||||||
self.behavior_clone_weight = behavior_clone_weight
|
self.behavior_clone_weight = behavior_clone_weight
|
||||||
|
|
||||||
|
# multi token prediction length
|
||||||
|
|
||||||
|
self.multi_token_pred_len = multi_token_pred_len
|
||||||
|
|
||||||
# each agent token will have the reward embedding of the previous time step - but could eventually just give reward its own token
|
# each agent token will have the reward embedding of the previous time step - but could eventually just give reward its own token
|
||||||
|
|
||||||
self.add_reward_embed_to_agent_token = add_reward_embed_to_agent_token
|
self.add_reward_embed_to_agent_token = add_reward_embed_to_agent_token
|
||||||
@ -1653,11 +1662,16 @@ class DynamicsWorldModel(Module):
|
|||||||
learned_embedding = add_reward_embed_to_agent_token
|
learned_embedding = add_reward_embed_to_agent_token
|
||||||
)
|
)
|
||||||
|
|
||||||
self.to_reward_pred = Sequential(
|
to_reward_pred = Sequential(
|
||||||
RMSNorm(dim),
|
RMSNorm(dim),
|
||||||
LinearNoBias(dim, self.reward_encoder.num_bins)
|
LinearNoBias(dim, self.reward_encoder.num_bins)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.to_reward_pred = Ensemble(
|
||||||
|
to_reward_pred,
|
||||||
|
multi_token_pred_len
|
||||||
|
)
|
||||||
|
|
||||||
self.reward_loss_weight = reward_loss_weight
|
self.reward_loss_weight = reward_loss_weight
|
||||||
|
|
||||||
# value head
|
# value head
|
||||||
@ -1944,7 +1958,7 @@ class DynamicsWorldModel(Module):
|
|||||||
if return_rewards_per_frame:
|
if return_rewards_per_frame:
|
||||||
one_agent_embed = agent_embed[:, -1:, agent_index]
|
one_agent_embed = agent_embed[:, -1:, agent_index]
|
||||||
|
|
||||||
reward_logits = self.to_reward_pred(one_agent_embed)
|
reward_logits = self.to_reward_pred.forward_one(one_agent_embed, id = 0)
|
||||||
pred_reward = self.reward_encoder.bins_to_scalar_value(reward_logits, normalize = True)
|
pred_reward = self.reward_encoder.bins_to_scalar_value(reward_logits, normalize = True)
|
||||||
|
|
||||||
decoded_rewards = cat((decoded_rewards, pred_reward), dim = 1)
|
decoded_rewards = cat((decoded_rewards, pred_reward), dim = 1)
|
||||||
@ -2380,9 +2394,19 @@ class DynamicsWorldModel(Module):
|
|||||||
if rewards.ndim == 2: # (b t)
|
if rewards.ndim == 2: # (b t)
|
||||||
encoded_agent_tokens = reduce(encoded_agent_tokens, 'b t g d -> b t d', 'mean')
|
encoded_agent_tokens = reduce(encoded_agent_tokens, 'b t g d -> b t d', 'mean')
|
||||||
|
|
||||||
reward_pred = self.to_reward_pred(encoded_agent_tokens)
|
reward_pred = self.to_reward_pred(encoded_agent_tokens[:, :-1])
|
||||||
|
|
||||||
reward_loss = F.cross_entropy(reward_pred, two_hot_encoding)
|
reward_pred = rearrange(reward_pred, 'mtp b t l -> b l t mtp')
|
||||||
|
|
||||||
|
reward_targets, reward_loss_mask = create_multi_token_prediction_targets(two_hot_encoding, self.multi_token_pred_len)
|
||||||
|
|
||||||
|
reward_targets = rearrange(reward_targets, 'b t mtp l -> b l t mtp')
|
||||||
|
|
||||||
|
reward_losses = F.cross_entropy(reward_pred, reward_targets, reduction = 'none')
|
||||||
|
|
||||||
|
reward_losses = reward_losses.masked_fill(reward_loss_mask, 0.)
|
||||||
|
|
||||||
|
reward_loss = reward_losses.sum(dim = -1).mean() # they sum across the prediction steps - eq(9)
|
||||||
|
|
||||||
# maybe autoregressive action loss
|
# maybe autoregressive action loss
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "dreamer4"
|
name = "dreamer4"
|
||||||
version = "0.0.34"
|
version = "0.0.35"
|
||||||
description = "Dreamer 4"
|
description = "Dreamer 4"
|
||||||
authors = [
|
authors = [
|
||||||
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user