take care of the loss normalization mentioned at the end of the first paragraph of section 3

parent 79a1b1c46e
commit 374667d8a9
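The normalization added here keeps an exponential moving average of each squared loss and divides every loss by its running root-mean-square before weighting and summing. In the notation of the new `LossNormalizer` below (β is `beta`, ε is `eps`, v is the `exp_avg_sq` buffer):

```latex
\hat{L}_i = \frac{L_i}{\max\!\left(\sqrt{v_i},\ \epsilon\right)},
\qquad \text{then, in training mode,} \qquad
v_i \leftarrow \beta\, v_i + (1 - \beta)\, L_i^2
```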
@@ -66,7 +66,7 @@ LinearNoBias = partial(Linear, bias = False)
 
 TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips'))
 
-WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'reward', 'behavior_clone'))
+WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'rewards', 'discrete_actions', 'continuous_actions'))
 
 @dataclass
 class Experience:
@@ -210,6 +210,48 @@ def create_multi_token_prediction_targets(
 
 # loss related
 
+class LossNormalizer(Module):
+    # the authors mentioned the need for loss normalization in the dynamics transformer
+
+    def __init__(
+        self,
+        num_losses: int,
+        beta = 0.95,
+        eps = 1e-6
+    ):
+        super().__init__()
+        self.register_buffer('exp_avg_sq', torch.ones(num_losses))
+
+        self.beta = beta
+        self.eps = eps
+
+    def forward(
+        self,
+        losses: Tensor | list[Tensor] | dict[str, Tensor],
+        update_ema = None
+    ):
+        exp_avg_sq, beta = self.exp_avg_sq, self.beta
+        update_ema = default(update_ema, self.training)
+
+        # get the rms value - as mentioned at the end of section 3 in the paper
+
+        rms = exp_avg_sq.sqrt()
+
+        if update_ema:
+            decay = 1. - beta
+
+            # update the ema
+
+            exp_avg_sq.lerp_(losses.detach().square(), decay)
+
+        # then normalize
+
+        assert losses.numel() == rms.numel()
+
+        normed_losses = losses / rms.clamp(min = self.eps)
+
+        return normed_losses
+
 class LPIPSLoss(Module):
     def __init__(
         self,
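A minimal usage sketch of the class above (not part of the commit; the loss values are stand-ins, and the import path follows the new test at the bottom of this diff):

```python
import torch
from dreamer4.dreamer4 import LossNormalizer

# one running RMS per loss entry - here, say, one per multi-token prediction step
normalizer = LossNormalizer(num_losses = 8, beta = 0.95)

reward_losses = torch.rand(8)        # stand-in losses at arbitrary scale

# rms is read from the buffer before the EMA update, so the very first call
# divides by the initial buffer of ones and returns the losses unchanged
normed = normalizer(reward_losses)

# subsequent calls divide by the running RMS accumulated so far
normed = normalizer(torch.rand(8))

total = normed.sum()                 # summed across prediction steps, as in eq (9)
```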
@@ -1570,11 +1612,13 @@ class DynamicsWorldModel(Module):
         num_discrete_actions: int | tuple[int, ...] = 0,
         num_continuous_actions = 0,
         continuous_norm_stats = None,
-        reward_loss_weight = 0.1,
+        reward_loss_weight = 1.,
         multi_token_pred_len = 8,
         value_head_mlp_depth = 3,
         policy_head_mlp_depth = 3,
-        behavior_clone_weight = 0.1,
+        latent_flow_loss_weight = 1.,
+        discrete_action_loss_weight = 1.,
+        continuous_action_loss_weight = 1.,
         num_latent_genes = 0, # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037
         num_residual_streams = 1,
         gae_discount_factor = 0.997,
@@ -1699,7 +1743,8 @@ class DynamicsWorldModel(Module):
             squeeze_unembed_preds = False
         )
 
-        self.behavior_clone_weight = behavior_clone_weight
+        self.discrete_action_loss_weight = discrete_action_loss_weight
+        self.continuous_action_loss_weight = continuous_action_loss_weight
 
         # multi token prediction length
 
@@ -1726,8 +1771,6 @@ class DynamicsWorldModel(Module):
             multi_token_pred_len
         )
 
-        self.reward_loss_weight = reward_loss_weight
-
         # value head
 
         self.value_head = create_mlp(
@@ -1761,7 +1804,15 @@ class DynamicsWorldModel(Module):
         self.value_clip = value_clip
         self.policy_entropy_weight = value_clip
 
-        # zero
+        # loss related
+
+        self.flow_loss_normalizer = LossNormalizer(1)
+        self.reward_loss_normalizer = LossNormalizer(multi_token_pred_len)
+        self.discrete_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_discrete_actions > 0 else None
+        self.continuous_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_continuous_actions > 0 else None
+
+        self.latent_flow_loss_weight = latent_flow_loss_weight
+        self.reward_loss_weight = reward_loss_weight
 
         self.register_buffer('zero', tensor(0.), persistent = False)
@@ -2461,11 +2512,12 @@ class DynamicsWorldModel(Module):
 
             reward_losses = reward_losses.masked_fill(reward_loss_mask, 0.)
 
-            reward_loss = reward_losses.sum(dim = -1).mean() # they sum across the prediction steps - eq(9)
+            reward_loss = reduce(reward_losses, '... mtp -> mtp', 'mean') # they sum across the prediction steps (mtp dimension) - eq(9)
 
         # maybe autoregressive action loss
 
-        behavior_clone_loss = self.zero
+        discrete_action_loss = self.zero
+        continuous_action_loss = self.zero
 
         if (
             self.num_agents == 1 and
@@ -2503,26 +2555,41 @@ class DynamicsWorldModel(Module):
             if exists(discrete_log_probs):
                 discrete_log_probs = discrete_log_probs.masked_fill(~discrete_mask[..., None], 0.)
 
-                behavior_clone_loss = behavior_clone_loss - reduce(discrete_log_probs, 'mtp b t na -> b t', 'sum').mean()
+                discrete_action_loss = reduce(-discrete_log_probs, 'mtp b t na -> mtp', 'mean')
 
             if exists(continuous_log_probs):
                 continuous_log_probs = continuous_log_probs.masked_fill(~continuous_mask[..., None], 0.)
 
-                behavior_clone_loss = behavior_clone_loss - reduce(continuous_log_probs, 'mtp b t na -> b t', 'sum').mean()
+                continuous_action_loss = reduce(-continuous_log_probs, 'mtp b t na -> mtp', 'mean')
 
-        # gather losses
+        # handle loss normalization
+
+        losses = WorldModelLosses(flow_loss, reward_loss, discrete_action_loss, continuous_action_loss)
+
+        if exists(self.flow_loss_normalizer):
+            flow_loss = self.flow_loss_normalizer(flow_loss)
+
+        if exists(rewards) and exists(self.reward_loss_normalizer):
+            reward_loss = self.reward_loss_normalizer(reward_loss)
+
+        if exists(discrete_actions) and exists(self.discrete_actions_loss_normalizer):
+            discrete_action_loss = self.discrete_actions_loss_normalizer(discrete_action_loss)
+
+        if exists(continuous_actions) and exists(self.continuous_actions_loss_normalizer):
+            continuous_action_loss = self.continuous_actions_loss_normalizer(continuous_action_loss)
+
+        # gather losses - they sum across the multi token prediction steps for rewards and actions - eq (9)
 
         total_loss = (
-            flow_loss +
-            reward_loss * self.reward_loss_weight +
-            behavior_clone_loss * self.behavior_clone_weight
+            flow_loss * self.latent_flow_loss_weight +
+            reward_loss.sum() * self.reward_loss_weight +
+            discrete_action_loss.sum() * self.discrete_action_loss_weight +
+            continuous_action_loss.sum() * self.continuous_action_loss_weight
         )
 
         if not return_all_losses:
             return total_loss
 
-        losses = WorldModelLosses(flow_loss, reward_loss, behavior_clone_loss)
-
         return total_loss, losses
 
 # dreamer
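For reference, a shape-level sketch (illustrative only, with stand-in tensors and the multi-token prediction length hard-coded) of what the reworked reward loss path above now does: per-step means are kept along the mtp dimension, normalized per prediction step, then summed across steps before weighting:

```python
import torch
from dreamer4.dreamer4 import LossNormalizer

mtp = 8                                          # stand-in for multi_token_pred_len

reward_losses = torch.rand(4, 16, mtp)           # stand-in (batch, time, mtp) per-step losses
reward_loss = reward_losses.mean(dim = (0, 1))   # reduce '... mtp -> mtp' - keep per-step means

reward_loss = LossNormalizer(mtp)(reward_loss)   # one running RMS per prediction step

reward_loss_weight = 1.
total = reward_loss.sum() * reward_loss_weight   # sum across the mtp steps - eq (9) - then weight
```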
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.40"
+version = "0.0.41"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -423,3 +423,16 @@ def test_mtp():
     discrete_logits, continuous_logits = embedder.unembed(action_embed, pred_head_index = 0)
 
     assert discrete_logits.shape == (2, 16, 8)
+
+def test_loss_normalizer():
+    from torch import cat
+    from dreamer4.dreamer4 import LossNormalizer
+
+    loss_normalizer = LossNormalizer(4, beta = 0.)
+
+    losses = torch.rand(4)
+
+    _ = loss_normalizer(losses)
+    normed_losses = loss_normalizer(losses)
+
+    assert (normed_losses == 1.).all()
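Note on the test: with `beta = 0.` the EMA keeps no history, so after the first call the `exp_avg_sq` buffer holds exactly the squared losses; the second call then divides each loss by its own absolute value, which is why every normalized entry comes out as 1.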