prepare for evolution within dreams

lucidrains 2025-10-08 08:13:16 -07:00
parent 892654d442
commit ed0918c974


@@ -1090,7 +1090,8 @@ class DynamicsWorldModel(Module):
         add_reward_embed_to_agent_token = False,
         add_reward_embed_dropout = 0.1,
         reward_loss_weight = 0.1,
-        value_head_mlp_depth = 3
+        value_head_mlp_depth = 3,
+        num_latent_genes = 0 # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037
     ):
         super().__init__()
@@ -1171,11 +1172,16 @@ class DynamicsWorldModel(Module):
         # they sum all the actions into a single token
 
         self.num_agents = num_agents
-        self.action_learned_embed = Parameter(torch.randn(self.num_agents, dim) * 1e-2)
+        self.action_learned_embed = Parameter(randn(self.num_agents, dim) * 1e-2)
 
         self.num_tasks = num_tasks
         self.task_embed = nn.Embedding(num_tasks, dim)
 
+        # learned set of latent genes
+
+        self.agent_has_genes = num_latent_genes > 0
+        self.latent_genes = Parameter(randn(num_latent_genes, dim) * 1e-2)
+
         # each agent token will have the reward embedding of the previous time step - but could eventually just give reward its own token
 
         self.add_reward_embed_to_agent_token = add_reward_embed_to_agent_token
@@ -1398,6 +1404,7 @@ class DynamicsWorldModel(Module):
         step_sizes_log2 = None, # () | (b)
         tasks = None, # (b)
         rewards = None, # (b t)
+        latent_gene_ids = None, # (b)
         return_pred_only = False,
         latent_is_noised = False,
         return_all_losses = False,
@@ -1514,6 +1521,14 @@ class DynamicsWorldModel(Module):
             task_embeds = self.task_embed(tasks)
             agent_tokens = agent_tokens + task_embeds
 
+        # maybe evolution
+
+        if exists(latent_gene_ids):
+            assert exists(self.latent_genes)
+
+            latent_genes = self.latent_genes[latent_gene_ids]
+            agent_tokens = einx.add('b ... d, b d', agent_tokens, latent_genes)
+
         # handle agent tokens w/ actions and task embeds
 
         agent_tokens = repeat(agent_tokens, 'b ... d -> b t ... d', t = time)
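For context, a minimal end-to-end sketch of the conditioning this last hunk adds: one gene id per batch element selects a row of the bank, and the `einx.add('b ... d, b d', ...)` pattern broadcasts that row across all of the sample's agent tokens. Only the bank construction, the row lookup, and the `einx.add` call mirror the diff; the shapes and toy tensors are assumptions for illustration:

    import torch
    import einx
    from torch import randn
    from torch.nn import Parameter

    batch, agents, dim = 4, 3, 512
    num_latent_genes = 16

    latent_genes = Parameter(randn(num_latent_genes, dim) * 1e-2)  # gene bank, as in the diff
    agent_tokens = randn(batch, agents, dim)                       # (b ... d), toy stand-in

    latent_gene_ids = torch.randint(0, num_latent_genes, (batch,)) # (b), one gene per sample

    genes = latent_genes[latent_gene_ids]                          # (b, d) row lookup
    out = einx.add('b ... d, b d', agent_tokens, genes)            # broadcast add over agent dims

    assert out.shape == agent_tokens.shape

The design keeps evolution outside the network: selection and recombination only permute or mix rows of `latent_genes` (and the integer ids fed per sample), so the dynamics model itself needs no architectural change beyond this additive conditioning.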