always store old agent embeds and old action parameters when possible
This commit is contained in:
parent
3547344312
commit
a358a44a53
@ -2225,8 +2225,8 @@ class DynamicsWorldModel(Module):
|
||||
max_timesteps = 16,
|
||||
env_is_vectorized = False,
|
||||
use_time_kv_cache = True,
|
||||
store_agent_embed = False,
|
||||
store_old_action_unembeds = False,
|
||||
store_agent_embed = True,
|
||||
store_old_action_unembeds = True,
|
||||
):
|
||||
assert exists(self.video_tokenizer)
|
||||
|
||||
@ -2391,7 +2391,7 @@ class DynamicsWorldModel(Module):
|
||||
actions = (discrete_actions, continuous_actions),
|
||||
log_probs = (discrete_log_probs, continuous_log_probs),
|
||||
values = values,
|
||||
old_action_unembeds = self.action_embedder.unembed(acc_policy_embed, pred_head_index = 0) if store_old_action_unembeds else None,
|
||||
old_action_unembeds = self.action_embedder.unembed(acc_policy_embed, pred_head_index = 0) if exists(acc_policy_embed) and store_old_action_unembeds else None,
|
||||
agent_embed = acc_agent_embed if store_agent_embed else None,
|
||||
step_size = step_size,
|
||||
agent_index = agent_index,
|
||||
@ -2667,7 +2667,8 @@ class DynamicsWorldModel(Module):
|
||||
return_agent_actions = False,
|
||||
return_log_probs_and_values = False,
|
||||
return_time_kv_cache = False,
|
||||
store_agent_embed = False
|
||||
store_agent_embed = True,
|
||||
store_old_action_unembeds = True
|
||||
|
||||
): # (b t n d) | (b c t h w)
|
||||
|
||||
@ -2947,7 +2948,7 @@ class DynamicsWorldModel(Module):
|
||||
video = video,
|
||||
proprio = proprio if has_proprio else None,
|
||||
agent_embed = acc_agent_embed if store_agent_embed else None,
|
||||
old_action_unembeds = self.action_embedder.unembed(acc_policy_embed, pred_head_index = 0) if store_old_action_unembeds else None,
|
||||
old_action_unembeds = self.action_embedder.unembed(acc_policy_embed, pred_head_index = 0) if exists(acc_policy_embed) and store_old_action_unembeds else None,
|
||||
step_size = step_size,
|
||||
agent_index = agent_index,
|
||||
lens = experience_lens,
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "dreamer4"
|
||||
version = "0.0.93"
|
||||
version = "0.0.94"
|
||||
description = "Dreamer 4"
|
||||
authors = [
|
||||
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user