From d0ffc6bfed905024858c437234f6facdfd484afa Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Thu, 23 Oct 2025 16:24:29 -0700
Subject: [PATCH] with or without signed advantage

---
 dreamer4/dreamer4.py  | 11 +++++++++--
 pyproject.toml        |  2 +-
 tests/test_dreamer.py |  8 ++++++--
 3 files changed, 16 insertions(+), 5 deletions(-)

diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index c1dd4aa..02e9822 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -2094,7 +2094,9 @@ class DynamicsWorldModel(Module):
         experience: Experience,
         policy_optim: Optimizer | None = None,
         value_optim: Optimizer | None = None,
-        only_learn_policy_value_heads = True # in the paper, they do not finetune the entire dynamics model, they just learn the heads
+        only_learn_policy_value_heads = True, # in the paper, they do not finetune the entire dynamics model, they just learn the heads
+        use_signed_advantage = True,
+        eps = 1e-6
     ):
         latents = experience.latents
 
@@ -2117,7 +2119,12 @@ class DynamicsWorldModel(Module):
         # apparently they just use the sign of the advantage
         # https://arxiv.org/abs/2410.04166v1
 
-        advantage = (returns - old_values).sign()
+        advantage = returns - old_values
+
+        if use_signed_advantage:
+            advantage = advantage.sign()
+        else:
+            advantage = F.layer_norm(advantage, advantage.shape, eps = eps)
 
         # replay for the action logits and values
 
diff --git a/pyproject.toml b/pyproject.toml
index c01c408..8b02c2b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.65"
+version = "0.0.66"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py
index cda0419..99355b9 100644
--- a/tests/test_dreamer.py
+++ b/tests/test_dreamer.py
@@ -601,7 +601,11 @@ def test_cache_generate():
     generated, time_kv_cache = dynamics.generate(1, time_kv_cache = time_kv_cache, return_time_kv_cache = True)
 
 @param('vectorized', (False, True))
-def test_online_rl(vectorized):
+@param('use_signed_advantage', (False, True))
+def test_online_rl(
+    vectorized,
+    use_signed_advantage
+):
     from dreamer4.dreamer4 import DynamicsWorldModel, VideoTokenizer
 
     tokenizer = VideoTokenizer(
@@ -637,7 +641,7 @@
 
     one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized)
 
-    actor_loss, critic_loss = world_model_and_policy.learn_from_experience(one_experience)
+    actor_loss, critic_loss = world_model_and_policy.learn_from_experience(one_experience, use_signed_advantage = use_signed_advantage)
 
     actor_loss.backward()
     critic_loss.backward()
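
Note (not part of the patch): below is a minimal standalone sketch of the two advantage modes the diff toggles between, the signed advantage from https://arxiv.org/abs/2410.04166v1 versus a layer-normalized advantage. The helper name compute_advantage, the dummy (batch, timesteps) tensor shapes, and the eps default are illustrative assumptions, not code from the repo.

import torch
import torch.nn.functional as F

def compute_advantage(returns, old_values, use_signed_advantage = True, eps = 1e-6):
    # raw advantage: how much better the observed return is than the old value estimate
    advantage = returns - old_values

    if use_signed_advantage:
        # keep only the sign (+1 / -1 / 0), discarding the magnitude
        advantage = advantage.sign()
    else:
        # normalize across all elements to roughly zero mean and unit variance
        advantage = F.layer_norm(advantage, advantage.shape, eps = eps)

    return advantage

# usage with dummy data, (batch, timesteps) shaped returns and old value estimates

returns = torch.randn(4, 16)
old_values = torch.randn(4, 16)

signed = compute_advantage(returns, old_values, use_signed_advantage = True)
normalized = compute_advantage(returns, old_values, use_signed_advantage = False)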