complete multi-token prediction for actions, tackle loss balancing another day
parent bc629d78b1
commit b6aa19f31e
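For context: the hunks below train one prediction head per future action via a `create_multi_token_prediction_targets` helper that already lives in the repo (it is called in the diff and exercised in `test_mtp`, but its body is not part of this commit). Below is a minimal sketch of what such a helper could look like, assuming head `j` at timestep `i` is trained on the action at `i + j`, that out-of-range slots are zero-padded, and that the returned mask is True on those padded slots (matching how it is fed to `masked_fill` in the loss further down). The exact offset and mask conventions in the repo may differ.

```python
import torch
import torch.nn.functional as F

def create_multi_token_prediction_targets(actions, pred_len):
    # actions: (b, t, ...) -> targets: (b, t, pred_len, ...), mask: (b, t, pred_len)
    # mask assumed True where a target falls past the end of the sequence,
    # so it can be fed straight into masked_fill(..., 0.) as in the loss below
    b, t = actions.shape[:2]

    targets, invalid = [], []

    for offset in range(pred_len):
        shifted = actions[:, offset:]                            # (b, t - offset, ...)
        pad_spec = (0, 0) * (actions.ndim - 2) + (0, offset)     # pad the time dim back to length t
        targets.append(F.pad(shifted, pad_spec, value = 0))
        invalid.append(torch.arange(t, device = actions.device) >= (t - offset))

    targets = torch.stack(targets, dim = 2)                      # (b, t, pred_len, ...)
    mask = torch.stack(invalid, dim = -1).expand(b, t, pred_len) # (b, t, pred_len)
    return targets, mask

# shapes line up with the existing test_mtp assertions
actions = torch.randint(0, 4, (3, 15, 2))
targets, mask = create_multi_token_prediction_targets(actions, 3)
assert targets.shape == (3, 15, 3, 2) and mask.shape == (3, 15, 3)
```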
@@ -124,6 +124,15 @@ def safe_cat(tensors, dim):
     return cat(tensors, dim = dim)
 
+def safe_squeeze_first(t):
+    if not exists(t):
+        return None
+
+    if t.shape[0] != 1:
+        return t
+
+    return rearrange(t, '1 ... -> ...')
+
 def gumbel_noise(t):
     noise = torch.rand_like(t)
     return -log(-log(noise))
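For reference, the expected behaviour of the new helper (illustration only, importing from `dreamer4.dreamer4` as the tests do):

```python
import torch
from dreamer4.dreamer4 import safe_squeeze_first

assert safe_squeeze_first(None) is None                            # None passes through
assert safe_squeeze_first(torch.randn(3, 4)).shape == (3, 4)       # leading dim != 1, untouched
assert safe_squeeze_first(torch.randn(1, 3, 4)).shape == (3, 4)    # leading singleton squeezed out
```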
@@ -365,7 +374,9 @@ class ActionEmbedder(Module):
         num_continuous_actions = 0,
         continuous_norm_stats: tuple[tuple[float, float], ...] | None = None,
         can_unembed = False,
-        unembed_dim = None
+        unembed_dim = None,
+        num_unembed_preds = 1,
+        squeeze_unembed_preds = True # will auto-squeeze if prediction is just 1
     ):
         super().__init__()
@@ -403,11 +414,14 @@ class ActionEmbedder(Module):
 
         self.can_unembed = can_unembed
 
+        self.num_unembed_preds = num_unembed_preds
+        self.squeeze_unembed_preds = squeeze_unembed_preds
+
         if not can_unembed:
             return
 
         unembed_dim = default(unembed_dim, dim)
-        self.discrete_action_unembed = Parameter(torch.randn(total_discrete_actions, unembed_dim) * 1e-2)
+        self.discrete_action_unembed = Parameter(torch.randn(total_discrete_actions, num_unembed_preds, unembed_dim) * 1e-2)
 
         discrete_action_index = arange(total_discrete_actions)
@@ -421,7 +435,7 @@ class ActionEmbedder(Module):
 
         self.register_buffer('discrete_action_mask', discrete_action_mask, persistent = False)
 
-        self.continuous_action_unembed = Parameter(torch.randn(num_continuous_actions, unembed_dim, 2) * 1e-2)
+        self.continuous_action_unembed = Parameter(torch.randn(num_continuous_actions, num_unembed_preds, unembed_dim, 2) * 1e-2)
 
     def embed_parameters(self):
         return set([*self.discrete_action_embed.parameters(), *self.continuous_action_embed.parameters()])
@@ -454,12 +468,26 @@ class ActionEmbedder(Module):
         embeds,                         # (... d)
         discrete_action_types = None,   # (na)
         continuous_action_types = None, # (na)
-        return_split_discrete = False
+        return_split_discrete = False,
+        pred_head_index: int | Tensor | None = None
     ): # (... discrete_na), (... continuous_na 2)
 
+        device = embeds.device
+
         assert self.can_unembed, 'can only unembed for predicted discrete and continuous actions if `can_unembed = True` is set on init'
 
+        # handle only one prediction head during inference
+
+        if exists(pred_head_index) and isinstance(pred_head_index, int):
+            pred_head_index = tensor(pred_head_index, device = device)
+
+        # if pred_head_index given as a solo int, just assume we want to squeeze out the prediction head dimension
+
+        squeeze_one_pred_head = exists(pred_head_index) and pred_head_index.ndim == 0
+
+        # get action types
+
         discrete_action_types, continuous_action_types = tuple(self.cast_action_types(t) for t in (discrete_action_types, continuous_action_types))
 
         # discrete actions
@@ -475,7 +503,13 @@ class ActionEmbedder(Module):
 
             discrete_action_unembed = discrete_action_unembed[discrete_action_mask]
 
-        discrete_action_logits = einsum(embeds, discrete_action_unembed, '... d, na d -> ... na')
+        if exists(pred_head_index):
+            discrete_action_unembed = discrete_action_unembed.index_select(1, pred_head_index)
+
+        discrete_action_logits = einsum(embeds, discrete_action_unembed, '... d, na mtp d -> mtp ... na')
+
+        if self.squeeze_unembed_preds or squeeze_one_pred_head:
+            discrete_action_logits = safe_squeeze_first(discrete_action_logits)
 
         # whether to split the discrete action logits by the number of actions per action type
@@ -496,7 +530,13 @@ class ActionEmbedder(Module):
         if exists(continuous_action_types):
             continuous_action_unembed = continuous_action_unembed[continuous_action_types]
 
-        continuous_action_mean_log_var = einsum(embeds, continuous_action_unembed, '... d, na d two -> ... na two')
+        if exists(pred_head_index):
+            continuous_action_unembed = continuous_action_unembed.index_select(1, pred_head_index)
+
+        continuous_action_mean_log_var = einsum(embeds, continuous_action_unembed, '... d, na mtp d two -> mtp ... na two')
+
+        if self.squeeze_unembed_preds or squeeze_one_pred_head:
+            continuous_action_mean_log_var = safe_squeeze_first(continuous_action_mean_log_var)
 
         return discrete_action_logits, continuous_action_mean_log_var
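The unembed weights now carry a prediction-head (`mtp`) axis, so the einsum emits one set of logits per head, with the head dimension leading so `safe_squeeze_first` can drop it when there is only one. A toy shape check with made-up sizes (illustration only, using `einops.einsum` as the diff does):

```python
import torch
from einops import einsum

embeds = torch.randn(2, 16, 512)    # (... d)
unembed = torch.randn(8, 4, 512)    # (na mtp d): 8 total discrete actions, 4 prediction heads

logits = einsum(embeds, unembed, '... d, na mtp d -> mtp ... na')
assert logits.shape == (4, 2, 16, 8)

# selecting a single head keeps a size-1 leading dim, which safe_squeeze_first then removes
one_head = unembed.index_select(1, torch.tensor([0]))
one_logits = einsum(embeds, one_head, '... d, na mtp d -> mtp ... na')
assert one_logits.shape == (1, 2, 16, 8)
```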
@@ -506,11 +548,13 @@ class ActionEmbedder(Module):
         discrete_temperature = 1.,
         continuous_temperature = 1.,
         inverse_norm_continuous = None,
+        pred_head_index: int | Tensor | None = None,
+        squeeze = True,
         **kwargs
     ):
         inverse_norm_continuous = default(inverse_norm_continuous, self.continuous_need_norm)
 
-        discrete_logits, continuous_mean_log_var = self.unembed(embed, return_split_discrete = True, **kwargs)
+        discrete_logits, continuous_mean_log_var = self.unembed(embed, return_split_discrete = True, pred_head_index = pred_head_index, **kwargs)
 
         sampled_discrete = sampled_continuous = None
@@ -543,12 +587,13 @@ class ActionEmbedder(Module):
         continuous_targets = None,      # (... na)
         discrete_action_types = None,   # (na)
         continuous_action_types = None, # (na)
+        pred_head_index: int | Tensor | None = None,
         parallel_discrete_calc = None,
         return_entropies = False
     ):
         parallel_discrete_calc = default(parallel_discrete_calc, exists(discrete_targets) and discrete_targets.shape[-1] > 1)
 
-        discrete_action_logits, continuous_action_mean_log_var = self.unembed(embeds, discrete_action_types = discrete_action_types, continuous_action_types = continuous_action_types, return_split_discrete = True)
+        discrete_action_logits, continuous_action_mean_log_var = self.unembed(embeds, pred_head_index = pred_head_index, discrete_action_types = discrete_action_types, continuous_action_types = continuous_action_types, return_split_discrete = True)
 
         # discrete
@@ -1651,7 +1696,9 @@ class DynamicsWorldModel(Module):
             num_continuous_actions = num_continuous_actions,
            continuous_norm_stats = continuous_norm_stats,
             can_unembed = True,
-            unembed_dim = dim * 4
+            unembed_dim = dim * 4,
+            num_unembed_preds = multi_token_pred_len,
+            squeeze_unembed_preds = False
         )
 
         self.behavior_clone_weight = behavior_clone_weight
@@ -1790,7 +1837,7 @@ class DynamicsWorldModel(Module):
 
         policy_embed = self.policy_head(agent_embed)
 
-        log_probs, entropies = self.action_embedder.log_probs(policy_embed, discrete_targets = discrete_actions, continuous_targets = continuous_actions, return_entropies = True)
+        log_probs, entropies = self.action_embedder.log_probs(policy_embed, pred_head_index = 0, discrete_targets = discrete_actions, continuous_targets = continuous_actions, return_entropies = True)
 
         # concat discrete and continuous actions into one for optimizing
@@ -1981,7 +2028,7 @@ class DynamicsWorldModel(Module):
 
             policy_embed = self.policy_head(one_agent_embed)
 
-            sampled_discrete_actions, sampled_continuous_actions = self.action_embedder.sample(policy_embed)
+            sampled_discrete_actions, sampled_continuous_actions = self.action_embedder.sample(policy_embed, pred_head_index = 0, squeeze = True)
 
             decoded_discrete_actions = safe_cat((decoded_discrete_actions, sampled_discrete_actions), dim = 1)
             decoded_continuous_actions = safe_cat((decoded_continuous_actions, sampled_continuous_actions), dim = 1)
@@ -1989,6 +2036,7 @@ class DynamicsWorldModel(Module):
            if return_log_probs_and_values:
                 discrete_log_probs, continuous_log_probs = self.action_embedder.log_probs(
                     policy_embed,
+                    pred_head_index = 0,
                     discrete_targets = sampled_discrete_actions,
                     continuous_targets = sampled_continuous_actions,
                 )
@@ -2424,6 +2472,7 @@ class DynamicsWorldModel(Module):
         if (
             self.num_agents == 1 and
             add_autoregressive_action_loss and
+            time > 1 and
             (exists(discrete_actions) or exists(continuous_actions))
         ):
             assert self.action_embedder.has_actions
@@ -2433,17 +2482,35 @@ class DynamicsWorldModel(Module):
             agent_tokens = rearrange(agent_tokens, 'b t 1 d -> b t d')
             policy_embed = self.policy_head(agent_tokens[:, :-1])
 
+            # constitute multi token prediction targets
+
+            discrete_action_targets = continuous_action_targets = None
+
+            if exists(discrete_actions):
+                discrete_action_targets, discrete_mask = create_multi_token_prediction_targets(discrete_actions, self.multi_token_pred_len)
+                discrete_action_targets = rearrange(discrete_action_targets, 'b t mtp ... -> mtp b t ...')
+                discrete_mask = rearrange(discrete_mask, 'b t mtp -> mtp b t')
+
+            if exists(continuous_actions):
+                continuous_action_targets, continuous_mask = create_multi_token_prediction_targets(continuous_actions, self.multi_token_pred_len)
+                continuous_action_targets = rearrange(continuous_action_targets, 'b t mtp ... -> mtp b t ...')
+                continuous_mask = rearrange(continuous_mask, 'b t mtp -> mtp b t')
+
             discrete_log_probs, continuous_log_probs = self.action_embedder.log_probs(
                 policy_embed,
-                discrete_targets = discrete_actions[:, 1:] if exists(discrete_actions) else None,
-                continuous_targets = continuous_actions[:, 1:] if exists(continuous_actions) else None
+                discrete_targets = discrete_action_targets if exists(discrete_actions) else None,
+                continuous_targets = continuous_action_targets if exists(continuous_actions) else None
             )
 
             if exists(discrete_log_probs):
-                behavior_clone_loss = behavior_clone_loss - discrete_log_probs.sum(dim = -1).mean()
+                discrete_log_probs = discrete_log_probs.masked_fill(discrete_mask[..., None], 0.)
+
+                behavior_clone_loss = behavior_clone_loss - reduce(discrete_log_probs, 'mtp b t na -> b t', 'sum').mean()
 
             if exists(continuous_log_probs):
-                behavior_clone_loss = behavior_clone_loss - continuous_log_probs.sum(dim = -1).mean()
+                continuous_log_probs = continuous_log_probs.masked_fill(continuous_mask[..., None], 0.)
+
+                behavior_clone_loss = behavior_clone_loss - reduce(continuous_log_probs, 'mtp b t na -> b t', 'sum').mean()
 
             # gather losses
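The new behavior-clone term reduces over the extra head dimension: assuming, as the `masked_fill` above implies, that the mask is True for padded targets, those log probs are zeroed before summing over the action and head axes and averaging over batch and time. A toy sketch with made-up shapes (illustration only):

```python
import torch
from einops import reduce

mtp, b, t, na = 4, 2, 7, 3

discrete_log_probs = torch.randn(mtp, b, t, na)
discrete_mask = torch.rand(mtp, b, t) > 0.8          # True where the target was padding

# zero out padded positions, then sum log probs over heads and actions, average over batch and time
discrete_log_probs = discrete_log_probs.masked_fill(discrete_mask[..., None], 0.)
behavior_clone_loss = -reduce(discrete_log_probs, 'mtp b t na -> b t', 'sum').mean()
```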
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.37"
+version = "0.0.38"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -401,3 +401,25 @@ def test_mtp():
 
     assert action_targets.shape == (3, 15, 3, 2)
     assert mask.shape == (3, 15, 3)
+
+    from dreamer4.dreamer4 import ActionEmbedder
+
+    embedder = ActionEmbedder(
+        512,
+        num_discrete_actions = (4, 4),
+        num_continuous_actions = 2,
+        can_unembed = True,
+        num_unembed_preds = 8
+    )
+
+    discrete_actions = torch.randint(0, 4, (2, 3, 2))
+    continuous_actions = torch.randn(2, 3, 2)
+
+    action_embed = torch.randn(2, 16, 512)
+    discrete_logits, continuous_logits = embedder.unembed(action_embed)
+
+    assert discrete_logits.shape == (8, 2, 16, 8)
+
+    discrete_logits, continuous_logits = embedder.unembed(action_embed, pred_head_index = 0)
+
+    assert discrete_logits.shape == (2, 16, 8)