From 0dba73428045dac0758b4cc6e51b59c514108699 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Fri, 17 Oct 2025 08:00:47 -0700
Subject: [PATCH] start the learning in dreams portion

---
 dreamer4/dreamer4.py  | 126 ++++++++++++++++++++++++++++++++++++++----
 pyproject.toml        |   2 +-
 tests/test_dreamer.py |  20 +++++--
 3 files changed, 131 insertions(+), 17 deletions(-)

diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index a0fc1b8..8c88680 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -72,6 +72,8 @@ class WorldModelGenerations:
     actions: tuple[Tensor, Tensor] | None = None
     log_probs: tuple[Tensor, Tensor] | None = None
     values: Tensor | None = None
+    step_size: int | None = None
+    agent_index: int = 0

 # helpers

@@ -107,7 +109,7 @@ def is_empty(t):
 def log(t, eps = 1e-20):
     return t.clamp(min = eps).log()

-def safe_cat(*tensors, dim):
+def safe_cat(tensors, dim):
     tensors = [*filter(exists, tensors)]

     if len(tensors) == 0:
@@ -648,7 +650,7 @@ class ActionEmbedder(Module):
 def calc_gae(
     rewards,
     values,
-    masks,
+    masks = None,
     gamma = 0.99,
     lam = 0.95,
     use_accelerated = None
@@ -656,6 +658,9 @@ def calc_gae(
     assert values.shape[-1] == rewards.shape[-1]
     use_accelerated = default(use_accelerated, rewards.is_cuda)

+    if not exists(masks):
+        masks = torch.ones_like(values)
+
     values = F.pad(values, (0, 1), value = 0.)
     values, values_next = values[..., :-1], values[..., 1:]

@@ -1456,7 +1461,13 @@ class DynamicsWorldModel(Module):
         policy_head_mlp_depth = 3,
         behavior_clone_weight = 0.1,
         num_latent_genes = 0, # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037
-        num_residual_streams = 1
+        num_residual_streams = 1,
+        gae_discount_factor = 0.997,
+        gae_lambda = 0.95,
+        ppo_eps_clip = 0.2,
+        value_clip = 0.4,
+        policy_entropy_weight = .01,
+        gae_use_accelerated = False
     ):
         super().__init__()

@@ -1614,6 +1625,16 @@ class DynamicsWorldModel(Module):
             final_norm = False
         )

+        # ppo related
+
+        self.gae_use_accelerated = gae_use_accelerated
+        self.gae_discount_factor = gae_discount_factor
+        self.gae_lambda = gae_lambda
+
+        self.ppo_eps_clip = ppo_eps_clip
+        self.value_clip = value_clip
+        self.policy_entropy_weight = policy_entropy_weight
+
         # zero

         self.register_buffer('zero', tensor(0.), persistent = False)
@@ -1642,6 +1663,87 @@ class DynamicsWorldModel(Module):

         return list(set(params) - set(self.video_tokenizer.parameters()))

+    def learn_policy_from_generations(
+        self,
+        generation: WorldModelGenerations
+    ):
+        latents = generation.latents
+        actions = generation.actions
+        old_log_probs = generation.log_probs
+        old_values = generation.values
+        rewards = generation.rewards
+
+        step_size = generation.step_size
+        agent_index = generation.agent_index
+
+        assert all([*map(exists, (old_log_probs, actions, old_values, rewards, step_size))]), 'the generations need to contain the actions, log probs, values, rewards, and step size for policy optimization'
+
+        returns = calc_gae(rewards, old_values, gamma = self.gae_discount_factor, lam = self.gae_lambda, use_accelerated = self.gae_use_accelerated)
+
+        # apparently they just use the sign of the advantage
+        # https://arxiv.org/abs/2410.04166v1
+
+        advantage = (returns - old_values).sign()
+
+        # replay for the action logits and values
+
+        discrete_actions, continuous_actions = actions
+
+        _, agent_embed = self.forward(
+            latents = latents,
+            signal_levels = self.max_steps - 1,
+            step_sizes = step_size,
+            rewards = rewards,
+            discrete_actions = discrete_actions,
+            continuous_actions = continuous_actions,
+            latent_is_noised = True,
+            return_pred_only = True,
+            return_agent_tokens = True
+        )
+
+        agent_embed = agent_embed[..., agent_index, :]
+
+        # ppo
+
+        policy_embed = self.policy_head(agent_embed)
+
+        log_probs = self.action_embedder.log_probs(policy_embed, discrete_targets = discrete_actions, continuous_targets = continuous_actions)
+
+        # concat discrete and continuous actions into one for optimizing
+
+        old_log_probs = safe_cat(old_log_probs, dim = -1)
+        log_probs = safe_cat(log_probs, dim = -1)
+
+        ratio = (log_probs - old_log_probs).exp()
+
+        clipped_ratio = ratio.clamp(1. - self.ppo_eps_clip, 1. + self.ppo_eps_clip)
+
+        advantage = rearrange(advantage, '... -> ... 1') # broadcast across all actions
+
+        # clipped surrogate loss
+
+        policy_loss = -torch.min(ratio * advantage, clipped_ratio * advantage)
+
+        policy_loss = reduce(policy_loss, 'b t na -> b t', 'sum')
+        policy_loss = policy_loss.mean()
+
+        # value loss
+
+        value_bins = self.value_head(agent_embed)
+        values = self.reward_encoder.bins_to_scalar_value(value_bins)
+
+        clipped_values = old_values + (values - old_values).clamp(-self.value_clip, self.value_clip)
+        clipped_value_bins = self.reward_encoder(clipped_values)
+
+        return_bins = self.reward_encoder(returns)
+
+        value_loss_1 = F.cross_entropy(value_bins, return_bins, reduction = 'none')
+        value_loss_2 = F.cross_entropy(clipped_value_bins, return_bins, reduction = 'none')
+
+        value_loss = torch.maximum(value_loss_1, value_loss_2).mean()
+
+        return policy_loss, value_loss
+
     @torch.no_grad()
     def generate(
         self,
@@ -1761,8 +1863,8 @@ class DynamicsWorldModel(Module):

                 sampled_discrete_actions, sampled_continuous_actions = self.action_embedder.sample(policy_embed)

-                decoded_discrete_actions = safe_cat(decoded_discrete_actions, sampled_discrete_actions, dim = 1)
-                decoded_continuous_actions = safe_cat(decoded_continuous_actions, sampled_continuous_actions, dim = 1)
+                decoded_discrete_actions = safe_cat((decoded_discrete_actions, sampled_discrete_actions), dim = 1)
+                decoded_continuous_actions = safe_cat((decoded_continuous_actions, sampled_continuous_actions), dim = 1)

                 if return_log_probs_and_values:
                     discrete_log_probs, continuous_log_probs = self.action_embedder.log_probs(
@@ -1771,13 +1873,13 @@ class DynamicsWorldModel(Module):
                         continuous_targets = sampled_continuous_actions,
                     )

-                    decoded_discrete_log_probs = safe_cat(decoded_discrete_log_probs, discrete_log_probs, dim = 1)
-                    decoded_continuous_log_probs = safe_cat(decoded_continuous_log_probs, continuous_log_probs, dim = 1)
+                    decoded_discrete_log_probs = safe_cat((decoded_discrete_log_probs, discrete_log_probs), dim = 1)
+                    decoded_continuous_log_probs = safe_cat((decoded_continuous_log_probs, continuous_log_probs), dim = 1)

                     value_bins = self.value_head(one_agent_embed)
                     values = self.reward_encoder.bins_to_scalar_value(value_bins)

-                    decoded_values = safe_cat(decoded_values, values, dim = 1)
+                    decoded_values = safe_cat((decoded_values, values), dim = 1)

             # concat the denoised latent

@@ -1812,7 +1914,12 @@ class DynamicsWorldModel(Module):

         # returning agent actions, rewards, and log probs + values for policy optimization

-        gen = WorldModelGenerations(latents = latents, video = video)
+        gen = WorldModelGenerations(
+            latents = latents,
+            video = video,
+            step_size = step_size,
+            agent_index = agent_index
+        )

         if return_rewards_per_frame:
             gen.rewards = decoded_rewards
@@ -2228,6 +2335,5 @@ class Dreamer(Module):
         self,
         video_tokenizer: VideoTokenizer,
         dynamics_model: DynamicsWorldModel,
-        discount_factor = 0.997
     ):
         super().__init__()
diff --git a/pyproject.toml b/pyproject.toml
index 9c41677..df0cd54 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.26"
+version = "0.0.27"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py
index 11ed491..2b6aba3 100644
--- a/tests/test_dreamer.py
+++ b/tests/test_dreamer.py
@@ -216,24 +216,32 @@ def test_action_with_world_model():
     discrete_actions = torch.randint(0, 4, (1, 4, 1))

     gen = dynamics.generate(
-        10,
+        16,
+        batch_size = 4,
         return_rewards_per_frame = True,
         return_agent_actions = True,
         return_log_probs_and_values = True
     )

-    assert gen.video.shape == (1, 3, 10, 256, 256)
-    assert gen.rewards.shape == (1, 10)
+    assert gen.video.shape == (4, 3, 16, 256, 256)
+    assert gen.rewards.shape == (4, 16)

     discrete_actions, continuous_actions = gen.actions

-    assert discrete_actions.shape == (1, 10, 1)
+    assert discrete_actions.shape == (4, 16, 1)
     assert continuous_actions is None

     discrete_log_probs, _ = gen.log_probs

-    assert discrete_log_probs.shape == (1, 10, 1)
-    assert gen.values.shape == (1, 10)
+    assert discrete_log_probs.shape == (4, 16, 1)
+    assert gen.values.shape == (4, 16)
+
+    # take a reinforcement learning step
+
+    actor_loss, critic_loss = dynamics.learn_policy_from_generations(gen)
+
+    actor_loss.backward(retain_graph = True)
+    critic_loss.backward()

 def test_action_embedder():
     from dreamer4.dreamer4 import ActionEmbedder
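
Note: a minimal usage sketch of the new policy learning step, mirroring `test_action_with_world_model` in `tests/test_dreamer.py`. It assumes `dynamics` is a `DynamicsWorldModel` constructed as in that test; nothing here is additional API beyond what the patch introduces.

    # dream a rollout, keeping actions, log probs, values and rewards per frame
    gen = dynamics.generate(
        16,
        batch_size = 4,
        return_rewards_per_frame = True,
        return_agent_actions = True,
        return_log_probs_and_values = True
    )

    # PPO clipped surrogate (actor) and clipped value (critic) losses from the generations
    actor_loss, critic_loss = dynamics.learn_policy_from_generations(gen)

    # the two losses share the replay forward pass, hence retain_graph on the first backward
    actor_loss.backward(retain_graph = True)
    critic_loss.backward()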