prepare for evolution within dreams

parent 892654d442
commit ed0918c974
@@ -1090,7 +1090,8 @@ class DynamicsWorldModel(Module):
         add_reward_embed_to_agent_token = False,
         add_reward_embed_dropout = 0.1,
         reward_loss_weight = 0.1,
-        value_head_mlp_depth = 3
+        value_head_mlp_depth = 3,
+        num_latent_genes = 0 # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037
     ):
         super().__init__()

@@ -1171,11 +1172,16 @@ class DynamicsWorldModel(Module):
         # they sum all the actions into a single token

         self.num_agents = num_agents
-        self.action_learned_embed = Parameter(torch.randn(self.num_agents, dim) * 1e-2)
+        self.action_learned_embed = Parameter(randn(self.num_agents, dim) * 1e-2)

         self.num_tasks = num_tasks
         self.task_embed = nn.Embedding(num_tasks, dim)

+        # learned set of latent genes
+
+        self.agent_has_genes = num_latent_genes > 0
+        self.latent_genes = Parameter(randn(num_latent_genes, dim) * 1e-2)
+
         # each agent token will have the reward embedding of the previous time step - but could eventually just give reward its own token

         self.add_reward_embed_to_agent_token = add_reward_embed_to_agent_token
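
For reference, a minimal standalone sketch of the latent gene pool this hunk sets up, assuming only torch; `num_latent_genes`, `dim`, and the example ids below are illustrative values, not taken from the repository.

```python
import torch
from torch import randn
from torch.nn import Parameter

num_latent_genes = 8   # illustrative population size
dim = 512              # illustrative model dimension

# one learned vector per gene, small init as in the diff
latent_genes = Parameter(randn(num_latent_genes, dim) * 1e-2)

agent_has_genes = num_latent_genes > 0  # feature flag, False when the pool is empty

# each batch element picks one gene by id, as the forward pass does further down
latent_gene_ids = torch.tensor([3, 0, 5])   # (b,)
selected = latent_genes[latent_gene_ids]    # (b, dim)

print(agent_has_genes, selected.shape)  # True torch.Size([3, 512])
```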
@@ -1398,6 +1404,7 @@ class DynamicsWorldModel(Module):
         step_sizes_log2 = None, # () | (b)
         tasks = None, # (b)
         rewards = None, # (b t)
+        latent_gene_ids = None, # (b)
         return_pred_only = False,
         latent_is_noised = False,
         return_all_losses = False,
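
A hedged sketch of how the new `latent_gene_ids` argument would be supplied; the call itself is left commented out because the rest of the forward signature is elided in this diff.

```python
import torch

b = 4                    # batch of dream rollouts
num_latent_genes = 16    # must match the constructor kwarg above

# one gene id per rollout, choosing which learned gene conditions it
latent_gene_ids = torch.randint(0, num_latent_genes, (b,))  # (b,)

# world_model(..., latent_gene_ids = latent_gene_ids)  # alongside tasks / rewards
```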
@@ -1514,6 +1521,14 @@ class DynamicsWorldModel(Module):
             task_embeds = self.task_embed(tasks)
             agent_tokens = agent_tokens + task_embeds

+        # maybe evolution
+
+        if exists(latent_gene_ids):
+            assert exists(self.latent_genes)
+            latent_genes = self.latent_genes[latent_gene_ids]
+
+            agent_tokens = einx.add('b ... d, b d', agent_tokens, latent_genes)
+
         # handle agent tokens w/ actions and task embeds

         agent_tokens = repeat(agent_tokens, 'b ... d -> b t ... d', t = time)
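
The injection step uses `einx.add('b ... d, b d', ...)`, which broadcasts each batch element's gene vector over every intermediate (agent) dimension of the token tensor. A self-contained sketch with illustrative shapes, assuming einx and torch are installed:

```python
import torch
import einx

b, num_agents, dim = 3, 4, 512

agent_tokens = torch.randn(b, num_agents, dim)  # (b, num_agents, d)
latent_genes = torch.randn(b, dim)              # (b, d), one gene per batch element

out = einx.add('b ... d, b d', agent_tokens, latent_genes)
print(out.shape)  # torch.Size([3, 4, 512])

# equivalent plain-torch broadcast, for reference
assert torch.allclose(out, agent_tokens + latent_genes[:, None, :])
```

Since the genes enter as a plain additive bias on the agent tokens, swapping `latent_gene_ids` between rollouts changes agent behavior without touching any other weights — presumably the handle that the referenced evolution-within-dreams paper selects on across its population.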