From ed0918c97423826f31654b293d3f3154805928b6 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Wed, 8 Oct 2025 08:13:16 -0700 Subject: [PATCH] prepare for evolution within dreams --- dreamer4/dreamer4.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index 2cff56f..b156843 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -1090,7 +1090,8 @@ class DynamicsWorldModel(Module): add_reward_embed_to_agent_token = False, add_reward_embed_dropout = 0.1, reward_loss_weight = 0.1, - value_head_mlp_depth = 3 + value_head_mlp_depth = 3, + num_latent_genes = 0 # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037 ): super().__init__() @@ -1171,11 +1172,16 @@ class DynamicsWorldModel(Module): # they sum all the actions into a single token self.num_agents = num_agents - self.action_learned_embed = Parameter(torch.randn(self.num_agents, dim) * 1e-2) + self.action_learned_embed = Parameter(randn(self.num_agents, dim) * 1e-2) self.num_tasks = num_tasks self.task_embed = nn.Embedding(num_tasks, dim) + # learned set of latent genes + + self.agent_has_genes = num_latent_genes > 0 + self.latent_genes = Parameter(randn(num_latent_genes, dim) * 1e-2) + # each agent token will have the reward embedding of the previous time step - but could eventually just give reward its own token self.add_reward_embed_to_agent_token = add_reward_embed_to_agent_token @@ -1398,6 +1404,7 @@ class DynamicsWorldModel(Module): step_sizes_log2 = None, # () | (b) tasks = None, # (b) rewards = None, # (b t) + latent_gene_ids = None, # (b) return_pred_only = False, latent_is_noised = False, return_all_losses = False, @@ -1514,6 +1521,14 @@ class DynamicsWorldModel(Module): task_embeds = self.task_embed(tasks) agent_tokens = agent_tokens + task_embeds + # maybe evolution + + if exists(latent_gene_ids): + assert exists(self.latent_genes) + latent_genes = self.latent_genes[latent_gene_ids] + + agent_tokens = einx.add('b ... d, b d', agent_tokens, latent_genes) + # handle agent tokens w/ actions and task embeds agent_tokens = repeat(agent_tokens, 'b ... d -> b t ... d', t = time)