with or without signed advantage
This commit is contained in:
parent
fb3e026fe0
commit
d0ffc6bfed
@ -2094,7 +2094,9 @@ class DynamicsWorldModel(Module):
|
||||
experience: Experience,
|
||||
policy_optim: Optimizer | None = None,
|
||||
value_optim: Optimizer | None = None,
|
||||
only_learn_policy_value_heads = True # in the paper, they do not finetune the entire dynamics model, they just learn the heads
|
||||
only_learn_policy_value_heads = True, # in the paper, they do not finetune the entire dynamics model, they just learn the heads
|
||||
use_signed_advantage = True,
|
||||
eps = 1e-6
|
||||
):
|
||||
|
||||
latents = experience.latents
|
||||
@ -2117,7 +2119,12 @@ class DynamicsWorldModel(Module):
|
||||
# apparently they just use the sign of the advantage
|
||||
# https://arxiv.org/abs/2410.04166v1
|
||||
|
||||
advantage = (returns - old_values).sign()
|
||||
advantage = returns - old_values
|
||||
|
||||
if use_signed_advantage:
|
||||
advantage = advantage.sign()
|
||||
else:
|
||||
advantage = F.layer_norm(advantage, advantage.shape, eps = eps)
|
||||
|
||||
# replay for the action logits and values
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "dreamer4"
|
||||
version = "0.0.65"
|
||||
version = "0.0.66"
|
||||
description = "Dreamer 4"
|
||||
authors = [
|
||||
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
||||
|
||||
@ -601,7 +601,11 @@ def test_cache_generate():
|
||||
generated, time_kv_cache = dynamics.generate(1, time_kv_cache = time_kv_cache, return_time_kv_cache = True)
|
||||
|
||||
@param('vectorized', (False, True))
|
||||
def test_online_rl(vectorized):
|
||||
@param('use_signed_advantage', (False, True))
|
||||
def test_online_rl(
|
||||
vectorized,
|
||||
use_signed_advantage
|
||||
):
|
||||
from dreamer4.dreamer4 import DynamicsWorldModel, VideoTokenizer
|
||||
|
||||
tokenizer = VideoTokenizer(
|
||||
@ -637,7 +641,7 @@ def test_online_rl(vectorized):
|
||||
|
||||
one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized)
|
||||
|
||||
actor_loss, critic_loss = world_model_and_policy.learn_from_experience(one_experience)
|
||||
actor_loss, critic_loss = world_model_and_policy.learn_from_experience(one_experience, use_signed_advantage = use_signed_advantage)
|
||||
|
||||
actor_loss.backward()
|
||||
critic_loss.backward()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user