sum the kl div loss across number of actions by default for action embedder .kl_div

This commit is contained in:
lucidrains 2025-10-29 10:46:42 -07:00
parent a358a44a53
commit 586379f2c8
3 changed files with 27 additions and 15 deletions

View File

@ -72,20 +72,22 @@ TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips'))
WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'rewards', 'discrete_actions', 'continuous_actions'))
MaybeTensor = Tensor | None
@dataclass
class Experience:
latents: Tensor
video: Tensor | None = None
proprio: Tensor | None = None
agent_embed: Tensor | None = None
video: MaybeTensor = None
proprio: MaybeTensor = None
agent_embed: MaybeTensor = None
rewards: Tensor | None = None
actions: tuple[Tensor, Tensor] | None = None
log_probs: tuple[Tensor, Tensor] | None = None
old_action_unembeds: tuple[Tensor, Tensor] | None = None
values: Tensor | None = None
actions: tuple[MaybeTensor, MaybeTensor] | None = None
log_probs: tuple[MaybeTensor, MaybeTensor] | None = None
old_action_unembeds: tuple[MaybeTensor, MaybeTensor] | None = None
values: MaybeTensor = None
step_size: int | None = None
lens: Tensor | None = None
is_truncated: Tensor | None = None
lens: MaybeTensor = None
is_truncated: MaybeTensor = None
agent_index: int = 0
is_from_world_model: bool = True
@ -850,9 +852,10 @@ class ActionEmbedder(Module):
def kl_div(
self,
src: tuple[Tensor | None, Tensor | None],
tgt: tuple[Tensor | None, Tensor | None]
) -> tuple[Tensor | None, Tensor | None]:
src: tuple[MaybeTensor, MaybeTensor],
tgt: tuple[MaybeTensor, MaybeTensor],
reduce_across_num_actions = True
) -> tuple[MaybeTensor, MaybeTensor]:
src_discrete, src_continuous = src
tgt_discrete, tgt_continuous = tgt
@ -894,6 +897,15 @@ class ActionEmbedder(Module):
continuous_kl_div = kl.kl_divergence(src_normal, tgt_normal)
# maybe reduce
if reduce_across_num_actions:
if exists(discrete_kl_div):
discrete_kl_div = discrete_kl_div.sum(dim = -1)
if exists(continuous_kl_div):
continuous_kl_div = continuous_kl_div.sum(dim = -1)
return discrete_kl_div, continuous_kl_div
def forward(

View File

@ -1,6 +1,6 @@
[project]
name = "dreamer4"
version = "0.0.94"
version = "0.0.95"
description = "Dreamer 4"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

View File

@ -352,8 +352,8 @@ def test_action_embedder():
discrete_kl_div, continuous_kl_div = embedder.kl_div((discrete_logits, continuous_mean_log_var), (discrete_logits_tgt, continuous_mean_log_var_tgt))
assert discrete_kl_div.shape == (2, 3, 2)
assert continuous_kl_div.shape == (2, 3, 2)
assert discrete_kl_div.shape == (2, 3)
assert continuous_kl_div.shape == (2, 3)
# return discrete split by number of actions