allow storing the agent embeddings during rollouts, whether in imagination or against the environment, for efficient policy optimization (this will also allow finetuning the world model for the heads)

lucidrains 2025-10-28 09:02:26 -07:00
parent 789f091c63
commit d476fa7b14
3 changed files with 27 additions and 5 deletions
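
Below is a minimal usage sketch of the new flag, mirroring the updated test further down. It assumes an already constructed `world_model_and_policy` (a DynamicsWorldModel) and an environment `env`; the constructor arguments and the import path for `combine_experiences` are not shown in this diff and are filled in here only for illustration.

# usage sketch (assumptions: `world_model_and_policy` and `env` are constructed elsewhere;
# `combine_experiences` is assumed to live in `dreamer4.dreamer4` alongside DynamicsWorldModel)

from dreamer4.dreamer4 import combine_experiences

# interact with the environment, asking for the agent embeddings to be accumulated
one_experience = world_model_and_policy.interact_with_env(env, max_timesteps = 8, store_agent_embed = True)
another_experience = world_model_and_policy.interact_with_env(env, max_timesteps = 16, store_agent_embed = True)

# the combined Experience now carries the embeddings on its new `agent_embed` field
experience = combine_experiences([one_experience, another_experience])
assert experience.agent_embed is not None

# policy optimization proceeds as before, with the embeddings available for reuse
actor_loss, critic_loss = world_model_and_policy.learn_from_experience(experience)
actor_loss.backward()

Whether `learn_from_experience` already consumes the stored embeddings, or whether that reuse (and the planned head finetuning) lands in a later commit, is not shown in this diff.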


@@ -77,6 +77,7 @@ class Experience:
latents: Tensor
video: Tensor | None = None
proprio: Tensor | None = None
agent_embed: Tensor | None = None
rewards: Tensor | None = None
actions: tuple[Tensor, Tensor] | None = None
log_probs: tuple[Tensor, Tensor] | None = None
@@ -2105,7 +2106,8 @@ class DynamicsWorldModel(Module):
step_size = 4,
max_timesteps = 16,
env_is_vectorized = False,
use_time_kv_cache = True
use_time_kv_cache = True,
store_agent_embed = False
):
assert exists(self.video_tokenizer)
@@ -2130,6 +2132,8 @@ class DynamicsWorldModel(Module):
values = None
latents = None
acc_agent_embed = None
# keep track of termination, for setting the `is_truncated` field on Experience and for early stopping interaction with env
is_terminated = full((batch,), False, device = device)
@@ -2251,6 +2255,8 @@ class DynamicsWorldModel(Module):
video = cat((video, next_frame), dim = 2)
rewards = safe_cat((rewards, reward), dim = 1)
acc_agent_embed = safe_cat((acc_agent_embed, agent_embed), dim = 1)
# package up one experience for learning
batch, device = latents.shape[0], latents.device
@@ -2262,6 +2268,7 @@ class DynamicsWorldModel(Module):
actions = (discrete_actions, continuous_actions),
log_probs = (discrete_log_probs, continuous_log_probs),
values = values,
agent_embed = acc_agent_embed if store_agent_embed else None,
step_size = step_size,
agent_index = agent_index,
is_truncated = is_truncated,
@@ -2491,6 +2498,7 @@ class DynamicsWorldModel(Module):
return_agent_actions = False,
return_log_probs_and_values = False,
return_time_kv_cache = False,
store_agent_embed = False
): # (b t n d) | (b c t h w)
@@ -2543,6 +2551,10 @@ class DynamicsWorldModel(Module):
decoded_continuous_log_probs = None
decoded_values = None
# maybe store agent embed
acc_agent_embed = None
# maybe return rewards
decoded_rewards = None
@@ -2651,6 +2663,10 @@ class DynamicsWorldModel(Module):
decoded_rewards = cat((decoded_rewards, pred_reward), dim = 1)
# maybe store agent embed
acc_agent_embed = safe_cat((acc_agent_embed, agent_embed), dim = 1)
# decode the agent actions if needed
if return_agent_actions:
@@ -2747,6 +2763,7 @@ class DynamicsWorldModel(Module):
latents = latents,
video = video,
proprio = proprio if has_proprio else None,
agent_embed = acc_agent_embed if store_agent_embed else None,
step_size = step_size,
agent_index = agent_index,
lens = experience_lens,


@@ -1,6 +1,6 @@
[project]
name = "dreamer4"
version = "0.0.83"
version = "0.0.85"
description = "Dreamer 4"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }


@@ -614,11 +614,13 @@ def test_cache_generate():
@param('use_signed_advantage', (False, True))
@param('env_can_terminate', (False, True))
@param('env_can_truncate', (False, True))
@param('store_agent_embed', (False, True))
def test_online_rl(
vectorized,
use_signed_advantage,
env_can_terminate,
env_can_truncate
env_can_truncate,
store_agent_embed
):
from dreamer4.dreamer4 import DynamicsWorldModel, VideoTokenizer
@@ -664,11 +666,14 @@ def test_online_rl(
# manually
one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = vectorized)
another_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized)
one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = vectorized, store_agent_embed = store_agent_embed)
another_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized, store_agent_embed = store_agent_embed)
combined_experience = combine_experiences([one_experience, another_experience])
if store_agent_embed:
assert exists(combined_experience.agent_embed)
actor_loss, critic_loss = world_model_and_policy.learn_from_experience(combined_experience, use_signed_advantage = use_signed_advantage)
actor_loss.backward()