complete multi-token prediction for actions, tackle loss balancing another day
This commit is contained in:
parent
bc629d78b1
commit
b6aa19f31e
@ -124,6 +124,15 @@ def safe_cat(tensors, dim):
|
||||
|
||||
return cat(tensors, dim = dim)
|
||||
|
||||
def safe_squeeze_first(t):
|
||||
if not exists(t):
|
||||
return None
|
||||
|
||||
if t.shape[0] != 1:
|
||||
return t
|
||||
|
||||
return rearrange(t, '1 ... -> ...')
|
||||
|
||||
def gumbel_noise(t):
|
||||
noise = torch.rand_like(t)
|
||||
return -log(-log(noise))
|
||||
@ -365,7 +374,9 @@ class ActionEmbedder(Module):
|
||||
num_continuous_actions = 0,
|
||||
continuous_norm_stats: tuple[tuple[float, float], ...] | None = None,
|
||||
can_unembed = False,
|
||||
unembed_dim = None
|
||||
unembed_dim = None,
|
||||
num_unembed_preds = 1,
|
||||
squeeze_unembed_preds = True # will auto-squeeze if prediction is just 1
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@ -403,11 +414,14 @@ class ActionEmbedder(Module):
|
||||
|
||||
self.can_unembed = can_unembed
|
||||
|
||||
self.num_unembed_preds = num_unembed_preds
|
||||
self.squeeze_unembed_preds = squeeze_unembed_preds
|
||||
|
||||
if not can_unembed:
|
||||
return
|
||||
|
||||
unembed_dim = default(unembed_dim, dim)
|
||||
self.discrete_action_unembed = Parameter(torch.randn(total_discrete_actions, unembed_dim) * 1e-2)
|
||||
self.discrete_action_unembed = Parameter(torch.randn(total_discrete_actions, num_unembed_preds, unembed_dim) * 1e-2)
|
||||
|
||||
discrete_action_index = arange(total_discrete_actions)
|
||||
|
||||
@ -421,7 +435,7 @@ class ActionEmbedder(Module):
|
||||
|
||||
self.register_buffer('discrete_action_mask', discrete_action_mask, persistent = False)
|
||||
|
||||
self.continuous_action_unembed = Parameter(torch.randn(num_continuous_actions, unembed_dim, 2) * 1e-2)
|
||||
self.continuous_action_unembed = Parameter(torch.randn(num_continuous_actions, num_unembed_preds, unembed_dim, 2) * 1e-2)
|
||||
|
||||
def embed_parameters(self):
|
||||
return set([*self.discrete_action_embed.parameters(), *self.continuous_action_embed.parameters()])
|
||||
@ -454,12 +468,26 @@ class ActionEmbedder(Module):
|
||||
embeds, # (... d)
|
||||
discrete_action_types = None, # (na)
|
||||
continuous_action_types = None, # (na)
|
||||
return_split_discrete = False
|
||||
return_split_discrete = False,
|
||||
pred_head_index: int | Tensor | None = None
|
||||
|
||||
): # (... discrete_na), (... continuous_na 2)
|
||||
|
||||
device = embeds.device
|
||||
|
||||
assert self.can_unembed, 'can only unembed for predicted discrete and continuous actions if `can_unembed = True` is set on init'
|
||||
|
||||
# handle only one prediction head during inference
|
||||
|
||||
if exists(pred_head_index) and isinstance(pred_head_index, int):
|
||||
pred_head_index = tensor(pred_head_index, device = device)
|
||||
|
||||
# if pred_head_index given as a solo int, just assume we want to squeeze out the prediction head dimension
|
||||
|
||||
squeeze_one_pred_head = exists(pred_head_index) and pred_head_index.ndim == 0
|
||||
|
||||
# get action types
|
||||
|
||||
discrete_action_types, continuous_action_types = tuple(self.cast_action_types(t) for t in (discrete_action_types, continuous_action_types))
|
||||
|
||||
# discrete actions
|
||||
@ -475,7 +503,13 @@ class ActionEmbedder(Module):
|
||||
|
||||
discrete_action_unembed = discrete_action_unembed[discrete_action_mask]
|
||||
|
||||
discrete_action_logits = einsum(embeds, discrete_action_unembed, '... d, na d -> ... na')
|
||||
if exists(pred_head_index):
|
||||
discrete_action_unembed = discrete_action_unembed.index_select(1, pred_head_index)
|
||||
|
||||
discrete_action_logits = einsum(embeds, discrete_action_unembed, '... d, na mtp d -> mtp ... na')
|
||||
|
||||
if self.squeeze_unembed_preds or squeeze_one_pred_head:
|
||||
discrete_action_logits = safe_squeeze_first(discrete_action_logits)
|
||||
|
||||
# whether to split the discrete action logits by the number of actions per action type
|
||||
|
||||
@ -496,7 +530,15 @@ class ActionEmbedder(Module):
|
||||
if exists(continuous_action_types):
|
||||
continuous_action_unembed = continuous_action_unembed[continuous_action_types]
|
||||
|
||||
continuous_action_mean_log_var = einsum(embeds, continuous_action_unembed, '... d, na d two -> ... na two')
|
||||
if isinstance(pred_head_index, int):
|
||||
continuous_action_unembed = continuous_action_unembed[:, pred_head_index]
|
||||
else:
|
||||
continuous_action_unembed = continuous_action_unembed.index_select(1, pred_head_index)
|
||||
|
||||
continuous_action_mean_log_var = einsum(embeds, continuous_action_unembed, '... d, na mtp d two -> mtp ... na two')
|
||||
|
||||
if self.squeeze_unembed_preds or squeeze_one_pred_head:
|
||||
continuous_action_mean_log_var = safe_squeeze_first(continuous_action_mean_log_var)
|
||||
|
||||
return discrete_action_logits, continuous_action_mean_log_var
|
||||
|
||||
@ -506,11 +548,13 @@ class ActionEmbedder(Module):
|
||||
discrete_temperature = 1.,
|
||||
continuous_temperature = 1.,
|
||||
inverse_norm_continuous = None,
|
||||
pred_head_index: int | Tensor | None = None,
|
||||
squeeze = True,
|
||||
**kwargs
|
||||
):
|
||||
inverse_norm_continuous = default(inverse_norm_continuous, self.continuous_need_norm)
|
||||
|
||||
discrete_logits, continuous_mean_log_var = self.unembed(embed, return_split_discrete = True, **kwargs)
|
||||
discrete_logits, continuous_mean_log_var = self.unembed(embed, return_split_discrete = True, pred_head_index = pred_head_index, **kwargs)
|
||||
|
||||
sampled_discrete = sampled_continuous = None
|
||||
|
||||
@ -543,12 +587,13 @@ class ActionEmbedder(Module):
|
||||
continuous_targets = None, # (... na)
|
||||
discrete_action_types = None, # (na)
|
||||
continuous_action_types = None, # (na)
|
||||
pred_head_index: int | Tensor | None = None,
|
||||
parallel_discrete_calc = None,
|
||||
return_entropies = False
|
||||
):
|
||||
parallel_discrete_calc = default(parallel_discrete_calc, exists(discrete_targets) and discrete_targets.shape[-1] > 1)
|
||||
|
||||
discrete_action_logits, continuous_action_mean_log_var = self.unembed(embeds, discrete_action_types = discrete_action_types, continuous_action_types = continuous_action_types, return_split_discrete = True)
|
||||
discrete_action_logits, continuous_action_mean_log_var = self.unembed(embeds, pred_head_index = pred_head_index, discrete_action_types = discrete_action_types, continuous_action_types = continuous_action_types, return_split_discrete = True)
|
||||
|
||||
# discrete
|
||||
|
||||
@ -1651,7 +1696,9 @@ class DynamicsWorldModel(Module):
|
||||
num_continuous_actions = num_continuous_actions,
|
||||
continuous_norm_stats = continuous_norm_stats,
|
||||
can_unembed = True,
|
||||
unembed_dim = dim * 4
|
||||
unembed_dim = dim * 4,
|
||||
num_unembed_preds = multi_token_pred_len,
|
||||
squeeze_unembed_preds = False
|
||||
)
|
||||
|
||||
self.behavior_clone_weight = behavior_clone_weight
|
||||
@ -1790,7 +1837,7 @@ class DynamicsWorldModel(Module):
|
||||
|
||||
policy_embed = self.policy_head(agent_embed)
|
||||
|
||||
log_probs, entropies = self.action_embedder.log_probs(policy_embed, discrete_targets = discrete_actions, continuous_targets = continuous_actions, return_entropies = True)
|
||||
log_probs, entropies = self.action_embedder.log_probs(policy_embed, pred_head_index = 0, discrete_targets = discrete_actions, continuous_targets = continuous_actions, return_entropies = True)
|
||||
|
||||
# concat discrete and continuous actions into one for optimizing
|
||||
|
||||
@ -1981,7 +2028,7 @@ class DynamicsWorldModel(Module):
|
||||
|
||||
policy_embed = self.policy_head(one_agent_embed)
|
||||
|
||||
sampled_discrete_actions, sampled_continuous_actions = self.action_embedder.sample(policy_embed)
|
||||
sampled_discrete_actions, sampled_continuous_actions = self.action_embedder.sample(policy_embed, pred_head_index = 0, squeeze = True)
|
||||
|
||||
decoded_discrete_actions = safe_cat((decoded_discrete_actions, sampled_discrete_actions), dim = 1)
|
||||
decoded_continuous_actions = safe_cat((decoded_continuous_actions, sampled_continuous_actions), dim = 1)
|
||||
@ -1989,6 +2036,7 @@ class DynamicsWorldModel(Module):
|
||||
if return_log_probs_and_values:
|
||||
discrete_log_probs, continuous_log_probs = self.action_embedder.log_probs(
|
||||
policy_embed,
|
||||
pred_head_index = 0,
|
||||
discrete_targets = sampled_discrete_actions,
|
||||
continuous_targets = sampled_continuous_actions,
|
||||
)
|
||||
@ -2424,6 +2472,7 @@ class DynamicsWorldModel(Module):
|
||||
if (
|
||||
self.num_agents == 1 and
|
||||
add_autoregressive_action_loss and
|
||||
time > 1,
|
||||
(exists(discrete_actions) or exists(continuous_actions))
|
||||
):
|
||||
assert self.action_embedder.has_actions
|
||||
@ -2433,17 +2482,35 @@ class DynamicsWorldModel(Module):
|
||||
agent_tokens = rearrange(agent_tokens, 'b t 1 d -> b t d')
|
||||
policy_embed = self.policy_head(agent_tokens[:, :-1])
|
||||
|
||||
# constitute multi token prediction targets
|
||||
|
||||
discrete_action_targets = continuous_action_targets = None
|
||||
|
||||
if exists(discrete_actions):
|
||||
discrete_action_targets, discrete_mask = create_multi_token_prediction_targets(discrete_actions, self.multi_token_pred_len)
|
||||
discrete_action_targets = rearrange(discrete_action_targets, 'b t mtp ... -> mtp b t ...')
|
||||
discrete_mask = rearrange(discrete_mask, 'b t mtp -> mtp b t')
|
||||
|
||||
if exists(continuous_actions):
|
||||
continuous_action_targets, continuous_mask = create_multi_token_prediction_targets(discrete_actions, self.multi_token_pred_len)
|
||||
continuous_action_targets = rearrange(continuous_action_targets, 'b t mtp ... -> mtp b t ...')
|
||||
continuous_mask = rearrange(continuous_mask, 'b t mtp -> mtp b t')
|
||||
|
||||
discrete_log_probs, continuous_log_probs = self.action_embedder.log_probs(
|
||||
policy_embed,
|
||||
discrete_targets = discrete_actions[:, 1:] if exists(discrete_actions) else None,
|
||||
continuous_targets = continuous_actions[:, 1:] if exists(continuous_actions) else None
|
||||
discrete_targets = discrete_action_targets if exists(discrete_actions) else None,
|
||||
continuous_targets = continuous_action_targets if exists(continuous_actions) else None
|
||||
)
|
||||
|
||||
if exists(discrete_log_probs):
|
||||
behavior_clone_loss = behavior_clone_loss - discrete_log_probs.sum(dim = -1).mean()
|
||||
discrete_log_probs = discrete_log_probs.masked_fill(discrete_mask[..., None], 0.)
|
||||
|
||||
behavior_clone_loss = behavior_clone_loss - reduce(discrete_log_probs, 'mtp b t na -> b t', 'sum').mean()
|
||||
|
||||
if exists(continuous_log_probs):
|
||||
behavior_clone_loss = behavior_clone_loss - continuous_log_probs.sum(dim = -1).mean()
|
||||
continuous_log_probs = continuous_log_probs.masked_fill(continuous_mask[..., None], 0.)
|
||||
|
||||
behavior_clone_loss = behavior_clone_loss - reduce(continuous_log_probs, 'mtp b t na -> b t', 'sum').mean()
|
||||
|
||||
# gather losses
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "dreamer4"
|
||||
version = "0.0.37"
|
||||
version = "0.0.38"
|
||||
description = "Dreamer 4"
|
||||
authors = [
|
||||
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
||||
|
||||
@ -401,3 +401,25 @@ def test_mtp():
|
||||
|
||||
assert action_targets.shape == (3, 15, 3, 2)
|
||||
assert mask.shape == (3, 15, 3)
|
||||
|
||||
from dreamer4.dreamer4 import ActionEmbedder
|
||||
|
||||
embedder = ActionEmbedder(
|
||||
512,
|
||||
num_discrete_actions = (4, 4),
|
||||
num_continuous_actions = 2,
|
||||
can_unembed = True,
|
||||
num_unembed_preds = 8
|
||||
)
|
||||
|
||||
discrete_actions = torch.randint(0, 4, (2, 3, 2))
|
||||
continuous_actions = torch.randn(2, 3, 2)
|
||||
|
||||
action_embed = torch.randn(2, 16, 512)
|
||||
discrete_logits, continuous_logits = embedder.unembed(action_embed)
|
||||
|
||||
assert discrete_logits.shape == (8, 2, 16, 8)
|
||||
|
||||
discrete_logits, continuous_logits = embedder.unembed(action_embed, pred_head_index = 0)
|
||||
|
||||
assert discrete_logits.shape == (2, 16, 8)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user