able to store the agent embeddings during rollouts with imagination or the environment, for efficient policy optimization (this will also allow for finetuning the world model through the heads)

lucidrains 2025-10-28 09:02:26 -07:00
parent 789f091c63
commit d476fa7b14
3 changed files with 27 additions and 5 deletions
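Going by the changes below, the new flag is exposed on `interact_with_env` (and on the imagination rollout path), and the accumulated embeddings ride along on the returned `Experience`. A minimal usage sketch, assuming a `world_model_and_policy` (a `DynamicsWorldModel`) and an `env` already set up as in the test file:

```python
# assumes `world_model_and_policy` and `env` are constructed beforehand,
# mirroring the call pattern in test_online_rl below

experience = world_model_and_policy.interact_with_env(
    env,
    max_timesteps = 8,
    store_agent_embed = True   # new flag added in this commit
)

# agent embeddings collected over the rollout are now stored on the Experience
assert experience.agent_embed is not None
```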

View File

@@ -77,6 +77,7 @@ class Experience:
     latents: Tensor
     video: Tensor | None = None
     proprio: Tensor | None = None
+    agent_embed: Tensor | None = None
     rewards: Tensor | None = None
     actions: tuple[Tensor, Tensor] | None = None
     log_probs: tuple[Tensor, Tensor] | None = None
@@ -2105,7 +2106,8 @@ class DynamicsWorldModel(Module):
         step_size = 4,
         max_timesteps = 16,
         env_is_vectorized = False,
-        use_time_kv_cache = True
+        use_time_kv_cache = True,
+        store_agent_embed = False
     ):
         assert exists(self.video_tokenizer)
@@ -2130,6 +2132,8 @@ class DynamicsWorldModel(Module):
         values = None
         latents = None

+        acc_agent_embed = None
+
         # keep track of termination, for setting the `is_truncated` field on Experience and for early stopping interaction with env

         is_terminated = full((batch,), False, device = device)
@@ -2251,6 +2255,8 @@ class DynamicsWorldModel(Module):
             video = cat((video, next_frame), dim = 2)
             rewards = safe_cat((rewards, reward), dim = 1)

+            acc_agent_embed = safe_cat((acc_agent_embed, agent_embed), dim = 1)
+
         # package up one experience for learning

         batch, device = latents.shape[0], latents.device
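The accumulation above leans on a `safe_cat` helper that tolerates a `None` accumulator on the first step. The repository's exact implementation is not shown in this diff; a minimal sketch of the assumed behavior:

```python
import torch

def safe_cat(tensors, dim = -1):
    # drop any None entries, then concatenate the rest along `dim`
    # if everything is None (e.g. first timestep), return None
    tensors = [t for t in tensors if t is not None]

    if len(tensors) == 0:
        return None

    if len(tensors) == 1:
        return tensors[0]

    return torch.cat(tensors, dim = dim)
```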
@@ -2262,6 +2268,7 @@ class DynamicsWorldModel(Module):
             actions = (discrete_actions, continuous_actions),
             log_probs = (discrete_log_probs, continuous_log_probs),
             values = values,
+            agent_embed = acc_agent_embed if store_agent_embed else None,
             step_size = step_size,
             agent_index = agent_index,
             is_truncated = is_truncated,
@@ -2491,6 +2498,7 @@ class DynamicsWorldModel(Module):
         return_agent_actions = False,
         return_log_probs_and_values = False,
         return_time_kv_cache = False,
+        store_agent_embed = False
     ): # (b t n d) | (b c t h w)
@@ -2543,6 +2551,10 @@ class DynamicsWorldModel(Module):
         decoded_continuous_log_probs = None
         decoded_values = None

+        # maybe store agent embed
+
+        acc_agent_embed = None
+
         # maybe return rewards

         decoded_rewards = None
@@ -2651,6 +2663,10 @@ class DynamicsWorldModel(Module):
                 decoded_rewards = cat((decoded_rewards, pred_reward), dim = 1)

+            # maybe store agent embed
+
+            acc_agent_embed = safe_cat((acc_agent_embed, agent_embed), dim = 1)
+
             # decode the agent actions if needed

             if return_agent_actions:
@@ -2747,6 +2763,7 @@ class DynamicsWorldModel(Module):
             latents = latents,
             video = video,
             proprio = proprio if has_proprio else None,
+            agent_embed = acc_agent_embed if store_agent_embed else None,
             step_size = step_size,
             agent_index = agent_index,
             lens = experience_lens,

View File

@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.83"
+version = "0.0.85"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

View File

@@ -614,11 +614,13 @@ def test_cache_generate():
 @param('use_signed_advantage', (False, True))
 @param('env_can_terminate', (False, True))
 @param('env_can_truncate', (False, True))
+@param('store_agent_embed', (False, True))
 def test_online_rl(
     vectorized,
     use_signed_advantage,
     env_can_terminate,
-    env_can_truncate
+    env_can_truncate,
+    store_agent_embed
 ):
     from dreamer4.dreamer4 import DynamicsWorldModel, VideoTokenizer
@@ -664,11 +666,14 @@ def test_online_rl(
     # manually

-    one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = vectorized)
-    another_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized)
+    one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = vectorized, store_agent_embed = store_agent_embed)
+    another_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized, store_agent_embed = store_agent_embed)

     combined_experience = combine_experiences([one_experience, another_experience])

+    if store_agent_embed:
+        assert exists(combined_experience.agent_embed)
+
     actor_loss, critic_loss = world_model_and_policy.learn_from_experience(combined_experience, use_signed_advantage = use_signed_advantage)

     actor_loss.backward()