take care of the loss normalization mentioned at the end of the first paragraph of section 3
parent 79a1b1c46e
commit 374667d8a9
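The change below implements the loss normalization referenced in the commit message: each loss term is divided by the square root of an exponential moving average of its own squared value, so terms with very different magnitudes enter the total loss at a comparable scale. A minimal sketch of that rule, with illustrative names and values (the actual implementation is the LossNormalizer class added in the diff):

    import torch

    beta = 0.95                         # ema decay, matching the default used in the diff below
    exp_avg_sq = torch.ones(3)          # running ema of squared losses, one slot per loss term

    def normalize(losses, eps = 1e-6):
        # rms tracked from previous steps
        rms = exp_avg_sq.sqrt()
        # fold the current squared losses into the ema (in place)
        exp_avg_sq.lerp_(losses.detach().square(), 1. - beta)
        # scale each term toward unit magnitude
        return losses / rms.clamp(min = eps)

    normed = normalize(torch.tensor([0.002, 1.5, 340.0]))   # hypothetical losses on very different scales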
@@ -66,7 +66,7 @@ LinearNoBias = partial(Linear, bias = False)

 TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips'))

-WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'reward', 'behavior_clone'))
+WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'rewards', 'discrete_actions', 'continuous_actions'))

 @dataclass
 class Experience:
@@ -210,6 +210,48 @@ def create_multi_token_prediction_targets(

+# loss related
+
+class LossNormalizer(Module):
+
+    # the authors mentioned the need for loss normalization in the dynamics transformer
+
+    def __init__(
+        self,
+        num_losses: int,
+        beta = 0.95,
+        eps = 1e-6
+    ):
+        super().__init__()
+        self.register_buffer('exp_avg_sq', torch.ones(num_losses))
+        self.beta = beta
+        self.eps = eps
+
+    def forward(
+        self,
+        losses: Tensor | list[Tensor] | dict[str, Tensor],
+        update_ema = None
+    ):
+        exp_avg_sq, beta = self.exp_avg_sq, self.beta
+        update_ema = default(update_ema, self.training)
+
+        # get the rms value - as mentioned at the end of section 3 in the paper
+
+        rms = exp_avg_sq.sqrt()
+
+        if update_ema:
+            decay = 1. - beta
+
+            # update the ema
+
+            exp_avg_sq.lerp_(losses.detach().square(), decay)
+
+        # then normalize
+
+        assert losses.numel() == rms.numel()
+
+        normed_losses = losses / rms.clamp(min = self.eps)
+
+        return normed_losses
+
 class LPIPSLoss(Module):
     def __init__(
         self,
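A brief usage sketch for the class above, outside the world model and with hypothetical loss values:

    import torch
    from dreamer4.dreamer4 import LossNormalizer

    normalizer = LossNormalizer(num_losses = 3)

    # three loss terms on very different scales, stacked into one tensor
    losses = torch.tensor([0.002, 1.5, 340.0])

    normed = normalizer(losses)   # each entry divided by its running rms (all ones on the very first call)
    total = normed.sum()          # comparable contributions before any static weighting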
@@ -1570,11 +1612,13 @@ class DynamicsWorldModel(Module):
         num_discrete_actions: int | tuple[int, ...] = 0,
         num_continuous_actions = 0,
         continuous_norm_stats = None,
-        reward_loss_weight = 0.1,
+        reward_loss_weight = 1.,
         multi_token_pred_len = 8,
         value_head_mlp_depth = 3,
         policy_head_mlp_depth = 3,
-        behavior_clone_weight = 0.1,
+        latent_flow_loss_weight = 1.,
+        discrete_action_loss_weight = 1.,
+        continuous_action_loss_weight = 1.,
         num_latent_genes = 0, # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037
         num_residual_streams = 1,
         gae_discount_factor = 0.997,
@@ -1699,7 +1743,8 @@ class DynamicsWorldModel(Module):
             squeeze_unembed_preds = False
         )

-        self.behavior_clone_weight = behavior_clone_weight
+        self.discrete_action_loss_weight = discrete_action_loss_weight
+        self.continuous_action_loss_weight = continuous_action_loss_weight

         # multi token prediction length

@@ -1726,8 +1771,6 @@ class DynamicsWorldModel(Module):
             multi_token_pred_len
         )

-        self.reward_loss_weight = reward_loss_weight
-
         # value head

         self.value_head = create_mlp(
@@ -1761,7 +1804,15 @@ class DynamicsWorldModel(Module):
         self.value_clip = value_clip
         self.policy_entropy_weight = value_clip

-        # zero
-
+        # loss related
+
+        self.flow_loss_normalizer = LossNormalizer(1)
+        self.reward_loss_normalizer = LossNormalizer(multi_token_pred_len)
+        self.discrete_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_discrete_actions > 0 else None
+        self.continuous_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_discrete_actions > 0 else None
+
+        self.latent_flow_loss_weight = latent_flow_loss_weight
+        self.reward_loss_weight = reward_loss_weight
+
         self.register_buffer('zero', tensor(0.), persistent = False)

@@ -2461,11 +2512,12 @@ class DynamicsWorldModel(Module):

             reward_losses = reward_losses.masked_fill(reward_loss_mask, 0.)

-            reward_loss = reward_losses.sum(dim = -1).mean() # they sum across the prediction steps - eq(9)
+            reward_loss = reduce(reward_losses, '... mtp -> mtp', 'mean') # they sum across the prediction steps (mtp dimension) - eq(9)

         # maybe autoregressive action loss

-        behavior_clone_loss = self.zero
+        discrete_action_loss = self.zero
+        continuous_action_loss = self.zero

         if (
             self.num_agents == 1 and
@@ -2503,26 +2555,41 @@ class DynamicsWorldModel(Module):
             if exists(discrete_log_probs):
                 discrete_log_probs = discrete_log_probs.masked_fill(~discrete_mask[..., None], 0.)

-                behavior_clone_loss = behavior_clone_loss - reduce(discrete_log_probs, 'mtp b t na -> b t', 'sum').mean()
+                discrete_action_loss = reduce(-discrete_log_probs, 'mtp b t na -> mtp', 'mean')

             if exists(continuous_log_probs):
                 continuous_log_probs = continuous_log_probs.masked_fill(~continuous_mask[..., None], 0.)

-                behavior_clone_loss = behavior_clone_loss - reduce(continuous_log_probs, 'mtp b t na -> b t', 'sum').mean()
+                continuous_action_loss = reduce(-continuous_log_probs, 'mtp b t na -> mtp', 'mean')

-        # gather losses
+        # handle loss normalization
+
+        losses = WorldModelLosses(flow_loss, reward_loss, discrete_action_loss, continuous_action_loss)
+
+        if exists(self.flow_loss_normalizer):
+            flow_loss = self.flow_loss_normalizer(flow_loss)
+
+        if exists(rewards) and exists(self.reward_loss_normalizer):
+            reward_loss = self.reward_loss_normalizer(reward_loss)
+
+        if exists(discrete_actions) and exists(self.discrete_actions_loss_normalizer):
+            discrete_action_loss = self.discrete_actions_loss_normalizer(discrete_action_loss)
+
+        if exists(continuous_actions) and exists(self.continuous_actions_loss_normalizer):
+            continuous_action_loss = self.continuous_actions_loss_normalizer(continuous_action_loss)
+
+        # gather losses - they sum across the multi token prediction steps for rewards and actions - eq (9)

         total_loss = (
-            flow_loss +
-            reward_loss * self.reward_loss_weight +
-            behavior_clone_loss * self.behavior_clone_weight
+            flow_loss * self.latent_flow_loss_weight +
+            reward_loss.sum() * self.reward_loss_weight +
+            discrete_action_loss.sum() * self.discrete_action_loss_weight +
+            continuous_action_loss.sum() * self.continuous_action_loss_weight
         )

         if not return_all_losses:
             return total_loss

-        losses = WorldModelLosses(flow_loss, reward_loss, behavior_clone_loss)
-
         return total_loss, losses

     # dreamer

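One detail worth calling out in the hunk above: reward_loss and the action losses are now kept as vectors with one entry per multi token prediction step (the reduce to 'mtp'), so the LossNormalizer can track a separate rms per step before the final .sum() applies the summation of eq (9). A small shape sketch under assumed batch / time / mtp sizes:

    import torch
    from einops import reduce

    # hypothetical per-element reward losses: (batch, time, mtp steps)
    reward_losses = torch.rand(2, 16, 8)

    reward_loss = reduce(reward_losses, '... mtp -> mtp', 'mean')   # shape (8,), one value per prediction step
    assert reward_loss.shape == (8,)

    # after per-step normalization, the steps are summed into the scalar term of eq (9)
    total_reward_term = reward_loss.sum()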
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.40"
+version = "0.0.41"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -423,3 +423,16 @@ def test_mtp():
     discrete_logits, continuous_logits = embedder.unembed(action_embed, pred_head_index = 0)

     assert discrete_logits.shape == (2, 16, 8)
+
+def test_loss_normalizer():
+    from torch import cat
+    from dreamer4.dreamer4 import LossNormalizer
+
+    loss_normalizer = LossNormalizer(4, beta = 0.)
+
+    losses = torch.rand(4)
+
+    _ = loss_normalizer(losses)
+    normed_losses = loss_normalizer(losses)
+
+    assert (normed_losses == 1.).all()
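As a sanity check on the new test: with beta = 0 the ema decay is 1 - beta = 1, so the first call overwrites exp_avg_sq with losses ** 2; on the second call rms equals |losses|, and losses / rms is exactly 1 for every positive, above-eps entry, which is what the final assertion verifies.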