Always store old agent embeds and old action unembeds (action parameters) when possible

This commit is contained in:
lucidrains 2025-10-29 10:39:15 -07:00
parent 3547344312
commit a358a44a53
2 changed files with 7 additions and 6 deletions

View File

@ -2225,8 +2225,8 @@ class DynamicsWorldModel(Module):
max_timesteps = 16, max_timesteps = 16,
env_is_vectorized = False, env_is_vectorized = False,
use_time_kv_cache = True, use_time_kv_cache = True,
store_agent_embed = False, store_agent_embed = True,
store_old_action_unembeds = False, store_old_action_unembeds = True,
): ):
assert exists(self.video_tokenizer) assert exists(self.video_tokenizer)
@ -2391,7 +2391,7 @@ class DynamicsWorldModel(Module):
actions = (discrete_actions, continuous_actions), actions = (discrete_actions, continuous_actions),
log_probs = (discrete_log_probs, continuous_log_probs), log_probs = (discrete_log_probs, continuous_log_probs),
values = values, values = values,
old_action_unembeds = self.action_embedder.unembed(acc_policy_embed, pred_head_index = 0) if store_old_action_unembeds else None, old_action_unembeds = self.action_embedder.unembed(acc_policy_embed, pred_head_index = 0) if exists(acc_policy_embed) and store_old_action_unembeds else None,
agent_embed = acc_agent_embed if store_agent_embed else None, agent_embed = acc_agent_embed if store_agent_embed else None,
step_size = step_size, step_size = step_size,
agent_index = agent_index, agent_index = agent_index,
@ -2667,7 +2667,8 @@ class DynamicsWorldModel(Module):
return_agent_actions = False, return_agent_actions = False,
return_log_probs_and_values = False, return_log_probs_and_values = False,
return_time_kv_cache = False, return_time_kv_cache = False,
store_agent_embed = False store_agent_embed = True,
store_old_action_unembeds = True
): # (b t n d) | (b c t h w) ): # (b t n d) | (b c t h w)
@ -2947,7 +2948,7 @@ class DynamicsWorldModel(Module):
video = video, video = video,
proprio = proprio if has_proprio else None, proprio = proprio if has_proprio else None,
agent_embed = acc_agent_embed if store_agent_embed else None, agent_embed = acc_agent_embed if store_agent_embed else None,
old_action_unembeds = self.action_embedder.unembed(acc_policy_embed, pred_head_index = 0) if store_old_action_unembeds else None, old_action_unembeds = self.action_embedder.unembed(acc_policy_embed, pred_head_index = 0) if exists(acc_policy_embed) and store_old_action_unembeds else None,
step_size = step_size, step_size = step_size,
agent_index = agent_index, agent_index = agent_index,
lens = experience_lens, lens = experience_lens,

View File

@ -1,6 +1,6 @@
[project] [project]
name = "dreamer4" name = "dreamer4"
version = "0.0.93" version = "0.0.94"
description = "Dreamer 4" description = "Dreamer 4"
authors = [ authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" } { name = "Phil Wang", email = "lucidrains@gmail.com" }