prepare for evolution within dreams
This commit is contained in:
parent
892654d442
commit
ed0918c974
@ -1090,7 +1090,8 @@ class DynamicsWorldModel(Module):
|
||||
add_reward_embed_to_agent_token = False,
|
||||
add_reward_embed_dropout = 0.1,
|
||||
reward_loss_weight = 0.1,
|
||||
value_head_mlp_depth = 3
|
||||
value_head_mlp_depth = 3,
|
||||
num_latent_genes = 0 # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@ -1171,11 +1172,16 @@ class DynamicsWorldModel(Module):
|
||||
# they sum all the actions into a single token
|
||||
|
||||
self.num_agents = num_agents
|
||||
self.action_learned_embed = Parameter(torch.randn(self.num_agents, dim) * 1e-2)
|
||||
self.action_learned_embed = Parameter(randn(self.num_agents, dim) * 1e-2)
|
||||
|
||||
self.num_tasks = num_tasks
|
||||
self.task_embed = nn.Embedding(num_tasks, dim)
|
||||
|
||||
# learned set of latent genes
|
||||
|
||||
self.agent_has_genes = num_latent_genes > 0
|
||||
self.latent_genes = Parameter(randn(num_latent_genes, dim) * 1e-2)
|
||||
|
||||
# each agent token will have the reward embedding of the previous time step - but could eventually just give reward its own token
|
||||
|
||||
self.add_reward_embed_to_agent_token = add_reward_embed_to_agent_token
|
||||
@ -1398,6 +1404,7 @@ class DynamicsWorldModel(Module):
|
||||
step_sizes_log2 = None, # () | (b)
|
||||
tasks = None, # (b)
|
||||
rewards = None, # (b t)
|
||||
latent_gene_ids = None, # (b)
|
||||
return_pred_only = False,
|
||||
latent_is_noised = False,
|
||||
return_all_losses = False,
|
||||
@ -1514,6 +1521,14 @@ class DynamicsWorldModel(Module):
|
||||
task_embeds = self.task_embed(tasks)
|
||||
agent_tokens = agent_tokens + task_embeds
|
||||
|
||||
# maybe evolution
|
||||
|
||||
if exists(latent_gene_ids):
|
||||
assert exists(self.latent_genes)
|
||||
latent_genes = self.latent_genes[latent_gene_ids]
|
||||
|
||||
agent_tokens = einx.add('b ... d, b d', agent_tokens, latent_genes)
|
||||
|
||||
# handle agent tokens w/ actions and task embeds
|
||||
|
||||
agent_tokens = repeat(agent_tokens, 'b ... d -> b t ... d', t = time)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user