always store old agent embeds and old action parameters when possible

This commit is contained in:
lucidrains 2025-10-29 10:39:15 -07:00
parent 3547344312
commit a358a44a53
2 changed files with 7 additions and 6 deletions

View File

@ -2225,8 +2225,8 @@ class DynamicsWorldModel(Module):
max_timesteps = 16,
env_is_vectorized = False,
use_time_kv_cache = True,
store_agent_embed = False,
store_old_action_unembeds = False,
store_agent_embed = True,
store_old_action_unembeds = True,
):
assert exists(self.video_tokenizer)
@ -2391,7 +2391,7 @@ class DynamicsWorldModel(Module):
actions = (discrete_actions, continuous_actions),
log_probs = (discrete_log_probs, continuous_log_probs),
values = values,
old_action_unembeds = self.action_embedder.unembed(acc_policy_embed, pred_head_index = 0) if store_old_action_unembeds else None,
old_action_unembeds = self.action_embedder.unembed(acc_policy_embed, pred_head_index = 0) if exists(acc_policy_embed) and store_old_action_unembeds else None,
agent_embed = acc_agent_embed if store_agent_embed else None,
step_size = step_size,
agent_index = agent_index,
@ -2667,7 +2667,8 @@ class DynamicsWorldModel(Module):
return_agent_actions = False,
return_log_probs_and_values = False,
return_time_kv_cache = False,
store_agent_embed = False
store_agent_embed = True,
store_old_action_unembeds = True
): # (b t n d) | (b c t h w)
@ -2947,7 +2948,7 @@ class DynamicsWorldModel(Module):
video = video,
proprio = proprio if has_proprio else None,
agent_embed = acc_agent_embed if store_agent_embed else None,
old_action_unembeds = self.action_embedder.unembed(acc_policy_embed, pred_head_index = 0) if store_old_action_unembeds else None,
old_action_unembeds = self.action_embedder.unembed(acc_policy_embed, pred_head_index = 0) if exists(acc_policy_embed) and store_old_action_unembeds else None,
step_size = step_size,
agent_index = agent_index,
lens = experience_lens,

View File

@ -1,6 +1,6 @@
[project]
name = "dreamer4"
version = "0.0.93"
version = "0.0.94"
description = "Dreamer 4"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }