Always store old agent embeds and old action parameters (action unembeds) when possible
This commit is contained in:
parent
3547344312
commit
a358a44a53
@ -2225,8 +2225,8 @@ class DynamicsWorldModel(Module):
|
|||||||
max_timesteps = 16,
|
max_timesteps = 16,
|
||||||
env_is_vectorized = False,
|
env_is_vectorized = False,
|
||||||
use_time_kv_cache = True,
|
use_time_kv_cache = True,
|
||||||
store_agent_embed = False,
|
store_agent_embed = True,
|
||||||
store_old_action_unembeds = False,
|
store_old_action_unembeds = True,
|
||||||
):
|
):
|
||||||
assert exists(self.video_tokenizer)
|
assert exists(self.video_tokenizer)
|
||||||
|
|
||||||
@ -2391,7 +2391,7 @@ class DynamicsWorldModel(Module):
|
|||||||
actions = (discrete_actions, continuous_actions),
|
actions = (discrete_actions, continuous_actions),
|
||||||
log_probs = (discrete_log_probs, continuous_log_probs),
|
log_probs = (discrete_log_probs, continuous_log_probs),
|
||||||
values = values,
|
values = values,
|
||||||
old_action_unembeds = self.action_embedder.unembed(acc_policy_embed, pred_head_index = 0) if store_old_action_unembeds else None,
|
old_action_unembeds = self.action_embedder.unembed(acc_policy_embed, pred_head_index = 0) if exists(acc_policy_embed) and store_old_action_unembeds else None,
|
||||||
agent_embed = acc_agent_embed if store_agent_embed else None,
|
agent_embed = acc_agent_embed if store_agent_embed else None,
|
||||||
step_size = step_size,
|
step_size = step_size,
|
||||||
agent_index = agent_index,
|
agent_index = agent_index,
|
||||||
@ -2667,7 +2667,8 @@ class DynamicsWorldModel(Module):
|
|||||||
return_agent_actions = False,
|
return_agent_actions = False,
|
||||||
return_log_probs_and_values = False,
|
return_log_probs_and_values = False,
|
||||||
return_time_kv_cache = False,
|
return_time_kv_cache = False,
|
||||||
store_agent_embed = False
|
store_agent_embed = True,
|
||||||
|
store_old_action_unembeds = True
|
||||||
|
|
||||||
): # (b t n d) | (b c t h w)
|
): # (b t n d) | (b c t h w)
|
||||||
|
|
||||||
@ -2947,7 +2948,7 @@ class DynamicsWorldModel(Module):
|
|||||||
video = video,
|
video = video,
|
||||||
proprio = proprio if has_proprio else None,
|
proprio = proprio if has_proprio else None,
|
||||||
agent_embed = acc_agent_embed if store_agent_embed else None,
|
agent_embed = acc_agent_embed if store_agent_embed else None,
|
||||||
old_action_unembeds = self.action_embedder.unembed(acc_policy_embed, pred_head_index = 0) if store_old_action_unembeds else None,
|
old_action_unembeds = self.action_embedder.unembed(acc_policy_embed, pred_head_index = 0) if exists(acc_policy_embed) and store_old_action_unembeds else None,
|
||||||
step_size = step_size,
|
step_size = step_size,
|
||||||
agent_index = agent_index,
|
agent_index = agent_index,
|
||||||
lens = experience_lens,
|
lens = experience_lens,
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "dreamer4"
|
name = "dreamer4"
|
||||||
version = "0.0.93"
|
version = "0.0.94"
|
||||||
description = "Dreamer 4"
|
description = "Dreamer 4"
|
||||||
authors = [
|
authors = [
|
||||||
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user