diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index 153a13e..87ca6a1 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -77,7 +77,7 @@ class Experience:
     latents: Tensor
     video: Tensor | None = None
     proprio: Tensor | None = None
-    agent_embed: Tensor | None = None,
+    agent_embed: Tensor | None = None
     rewards: Tensor | None = None
     actions: tuple[Tensor, Tensor] | None = None
     log_probs: tuple[Tensor, Tensor] | None = None
@@ -2295,6 +2295,7 @@ class DynamicsWorldModel(Module):
         old_log_probs = experience.log_probs
         old_values = experience.values
         rewards = experience.rewards
+        agent_embeds = experience.agent_embed
         step_size = experience.step_size
         agent_index = experience.agent_index
 
@@ -2374,32 +2375,38 @@ class DynamicsWorldModel(Module):
             advantage = F.layer_norm(advantage, advantage.shape, eps = eps)
 
         # replay for the action logits and values
+        # but only do so if fine tuning the entire world model for RL
 
         discrete_actions, continuous_actions = actions
 
-        with world_model_forward_context():
-            _, (agent_embed, _) = self.forward(
-                latents = latents,
-                signal_levels = self.max_steps - 1,
-                step_sizes = step_size,
-                rewards = rewards,
-                discrete_actions = discrete_actions,
-                continuous_actions = continuous_actions,
-                latent_is_noised = True,
-                return_pred_only = True,
-                return_intermediates = True
-            )
+        if (
+            not only_learn_policy_value_heads or
+            not exists(agent_embeds)
+        ):
 
-        agent_embed = agent_embed[..., agent_index, :]
+            with world_model_forward_context():
+                _, (agent_embeds, _) = self.forward(
+                    latents = latents,
+                    signal_levels = self.max_steps - 1,
+                    step_sizes = step_size,
+                    rewards = rewards,
+                    discrete_actions = discrete_actions,
+                    continuous_actions = continuous_actions,
+                    latent_is_noised = True,
+                    return_pred_only = True,
+                    return_intermediates = True
+                )
+
+            agent_embeds = agent_embeds[..., agent_index, :]
 
         # maybe detach agent embed
 
         if only_learn_policy_value_heads:
-            agent_embed = agent_embed.detach()
+            agent_embeds = agent_embeds.detach()
 
         # ppo
 
-        policy_embed = self.policy_head(agent_embed)
+        policy_embed = self.policy_head(agent_embeds)
 
         log_probs, entropies = self.action_embedder.log_probs(policy_embed, pred_head_index = 0, discrete_targets = discrete_actions, continuous_targets = continuous_actions, return_entropies = True)
 
@@ -2448,7 +2455,7 @@ class DynamicsWorldModel(Module):
 
         # value loss
 
-        value_bins = self.value_head(agent_embed)
+        value_bins = self.value_head(agent_embeds)
 
         values = self.reward_encoder.bins_to_scalar_value(value_bins)
         clipped_values = old_values + (values - old_values).clamp(-self.value_clip, self.value_clip)
diff --git a/pyproject.toml b/pyproject.toml
index 1448b38..f7d03d1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.85"
+version = "0.0.87"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }