From 586379f2c8e6a8672c64ca12b1354bb7cfa6f417 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Wed, 29 Oct 2025 10:46:42 -0700
Subject: [PATCH] sum the kl div loss across number of actions by default for
 action embedder .kl_div

---
 dreamer4/dreamer4.py  | 36 ++++++++++++++++++++++++------------
 pyproject.toml        |  2 +-
 tests/test_dreamer.py |  4 ++--
 3 files changed, 27 insertions(+), 15 deletions(-)

diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index 8e9b8f0..f3f260b 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -72,20 +72,22 @@ TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips'))
 
 WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'rewards', 'discrete_actions', 'continuous_actions'))
 
+MaybeTensor = Tensor | None
+
 @dataclass
 class Experience:
     latents: Tensor
-    video: Tensor | None = None
-    proprio: Tensor | None = None
-    agent_embed: Tensor | None = None
+    video: MaybeTensor = None
+    proprio: MaybeTensor = None
+    agent_embed: MaybeTensor = None
     rewards: Tensor | None = None
-    actions: tuple[Tensor, Tensor] | None = None
-    log_probs: tuple[Tensor, Tensor] | None = None
-    old_action_unembeds: tuple[Tensor, Tensor] | None = None
-    values: Tensor | None = None
+    actions: tuple[MaybeTensor, MaybeTensor] | None = None
+    log_probs: tuple[MaybeTensor, MaybeTensor] | None = None
+    old_action_unembeds: tuple[MaybeTensor, MaybeTensor] | None = None
+    values: MaybeTensor = None
     step_size: int | None = None
-    lens: Tensor | None = None
-    is_truncated: Tensor | None = None
+    lens: MaybeTensor = None
+    is_truncated: MaybeTensor = None
     agent_index: int = 0
 
     is_from_world_model: bool = True
@@ -850,9 +852,10 @@ class ActionEmbedder(Module):
 
     def kl_div(
         self,
-        src: tuple[Tensor | None, Tensor | None],
-        tgt: tuple[Tensor | None, Tensor | None]
-    ) -> tuple[Tensor | None, Tensor | None]:
+        src: tuple[MaybeTensor, MaybeTensor],
+        tgt: tuple[MaybeTensor, MaybeTensor],
+        reduce_across_num_actions = True
+    ) -> tuple[MaybeTensor, MaybeTensor]:
 
         src_discrete, src_continuous = src
         tgt_discrete, tgt_continuous = tgt
@@ -894,6 +897,15 @@
 
         continuous_kl_div = kl.kl_divergence(src_normal, tgt_normal)
 
+        # maybe reduce
+
+        if reduce_across_num_actions:
+            if exists(discrete_kl_div):
+                discrete_kl_div = discrete_kl_div.sum(dim = -1)
+
+            if exists(continuous_kl_div):
+                continuous_kl_div = continuous_kl_div.sum(dim = -1)
+
         return discrete_kl_div, continuous_kl_div
 
     def forward(
diff --git a/pyproject.toml b/pyproject.toml
index 57fce7e..e68d8af 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.94"
+version = "0.0.95"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py
index be8ce14..b9f4de1 100644
--- a/tests/test_dreamer.py
+++ b/tests/test_dreamer.py
@@ -352,8 +352,8 @@ def test_action_embedder():
 
     discrete_kl_div, continuous_kl_div = embedder.kl_div((discrete_logits, continuous_mean_log_var), (discrete_logits_tgt, continuous_mean_log_var_tgt))
 
-    assert discrete_kl_div.shape == (2, 3, 2)
-    assert continuous_kl_div.shape == (2, 3, 2)
+    assert discrete_kl_div.shape == (2, 3)
+    assert continuous_kl_div.shape == (2, 3)
 
     # return discrete split by number of actions
 
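
Usage note: below is a minimal standalone sketch of the reduction this patch makes the default. It does not call into dreamer4 itself; the shapes, the trailing num_actions = 2 dimension, and all variable names are illustrative assumptions chosen to match the updated test asserts (per-action KLs of shape (2, 3, 2) summed down to (2, 3)).

import torch
from torch.distributions import Categorical, Normal, kl

batch, seq, num_actions, num_bins = 2, 3, 2, 5

# per-action discrete KL between two Categorical heads -> (batch, seq, num_actions)
src_logits = torch.randn(batch, seq, num_actions, num_bins)
tgt_logits = torch.randn(batch, seq, num_actions, num_bins)

discrete_kl_div = kl.kl_divergence(
    Categorical(logits = src_logits),
    Categorical(logits = tgt_logits)
)

# per-action continuous KL between two diagonal Normals -> (batch, seq, num_actions)
src_mean, src_std = torch.randn(batch, seq, num_actions), torch.rand(batch, seq, num_actions) + 0.1
tgt_mean, tgt_std = torch.randn(batch, seq, num_actions), torch.rand(batch, seq, num_actions) + 0.1

continuous_kl_div = kl.kl_divergence(
    Normal(src_mean, src_std),
    Normal(tgt_mean, tgt_std)
)

# the new default (reduce_across_num_actions = True): sum the per-action KLs
# across the trailing num_actions dimension

discrete_kl_div = discrete_kl_div.sum(dim = -1)
continuous_kl_div = continuous_kl_div.sum(dim = -1)

assert discrete_kl_div.shape == (batch, seq)    # (2, 3), as in the updated test
assert continuous_kl_div.shape == (batch, seq)  # (2, 3)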