with or without signed advantage

This commit is contained in:
lucidrains 2025-10-23 16:24:29 -07:00
parent fb3e026fe0
commit d0ffc6bfed
3 changed files with 16 additions and 5 deletions

View File

@@ -2094,7 +2094,9 @@ class DynamicsWorldModel(Module):
experience: Experience,
policy_optim: Optimizer | None = None,
value_optim: Optimizer | None = None,
only_learn_policy_value_heads = True # in the paper, they do not finetune the entire dynamics model, they just learn the heads
only_learn_policy_value_heads = True, # in the paper, they do not finetune the entire dynamics model, they just learn the heads
use_signed_advantage = True,
eps = 1e-6
):
latents = experience.latents
@@ -2117,7 +2119,12 @@ class DynamicsWorldModel(Module):
# apparently they just use the sign of the advantage
# https://arxiv.org/abs/2410.04166v1
advantage = (returns - old_values).sign()
advantage = returns - old_values
if use_signed_advantage:
advantage = advantage.sign()
else:
advantage = F.layer_norm(advantage, advantage.shape, eps = eps)
# replay for the action logits and values

View File

@@ -1,6 +1,6 @@
[project]
name = "dreamer4"
version = "0.0.65"
version = "0.0.66"
description = "Dreamer 4"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

View File

@@ -601,7 +601,11 @@ def test_cache_generate():
generated, time_kv_cache = dynamics.generate(1, time_kv_cache = time_kv_cache, return_time_kv_cache = True)
@param('vectorized', (False, True))
def test_online_rl(vectorized):
@param('use_signed_advantage', (False, True))
def test_online_rl(
vectorized,
use_signed_advantage
):
from dreamer4.dreamer4 import DynamicsWorldModel, VideoTokenizer
tokenizer = VideoTokenizer(
@@ -637,7 +641,7 @@ def test_online_rl(vectorized):
one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized)
actor_loss, critic_loss = world_model_and_policy.learn_from_experience(one_experience)
actor_loss, critic_loss = world_model_and_policy.learn_from_experience(one_experience, use_signed_advantage = use_signed_advantage)
actor_loss.backward()
critic_loss.backward()