multiple agent tokens sharing the same state

This commit is contained in:
lucidrains 2025-10-08 08:06:13 -07:00
parent c4e0f46528
commit 892654d442

View File

@ -1072,6 +1072,7 @@ class DynamicsWorldModel(Module):
num_register_tokens = 8, # they claim register tokens led to better temporal consistency num_register_tokens = 8, # they claim register tokens led to better temporal consistency
num_spatial_tokens = 2, # latents projected to greater number of spatial tokens num_spatial_tokens = 2, # latents projected to greater number of spatial tokens
num_latent_tokens = None, num_latent_tokens = None,
num_agents = 1,
num_tasks = 0, num_tasks = 0,
reward_encoder_kwargs: dict = dict(), reward_encoder_kwargs: dict = dict(),
depth = 4, depth = 4,
@ -1169,7 +1170,8 @@ class DynamicsWorldModel(Module):
# they sum all the actions into a single token # they sum all the actions into a single token
self.action_learned_embed = Parameter(torch.randn(dim) * 1e-2) self.num_agents = num_agents
self.action_learned_embed = Parameter(torch.randn(self.num_agents, dim) * 1e-2)
self.num_tasks = num_tasks self.num_tasks = num_tasks
self.task_embed = nn.Embedding(num_tasks, dim) self.task_embed = nn.Embedding(num_tasks, dim)
@ -1267,6 +1269,7 @@ class DynamicsWorldModel(Module):
time_steps, time_steps,
num_steps = 4, num_steps = 4,
batch_size = 1, batch_size = 1,
agent_index = 0,
image_height = None, image_height = None,
image_width = None, image_width = None,
return_decoded_video = None, return_decoded_video = None,
@ -1344,7 +1347,9 @@ class DynamicsWorldModel(Module):
# take care of the rewards by predicting on the agent token embedding on the last denoising step # take care of the rewards by predicting on the agent token embedding on the last denoising step
if return_rewards_per_frame: if return_rewards_per_frame:
reward_logits = self.to_reward_pred(agent_embed[:, -1:]) one_agent_embed = agent_embed[:, -1:, agent_index]
reward_logits = self.to_reward_pred(one_agent_embed)
pred_reward = self.reward_encoder.bins_to_scalar_value(reward_logits, normalize = True) pred_reward = self.reward_encoder.bins_to_scalar_value(reward_logits, normalize = True)
decoded_rewards = cat((decoded_rewards, pred_reward), dim = 1) decoded_rewards = cat((decoded_rewards, pred_reward), dim = 1)
@ -1501,7 +1506,7 @@ class DynamicsWorldModel(Module):
# reinforcement learning related # reinforcement learning related
agent_tokens = repeat(self.action_learned_embed, 'd -> b d', b = batch) agent_tokens = repeat(self.action_learned_embed, '... d -> b ... d', b = batch)
if exists(tasks): if exists(tasks):
assert self.num_tasks > 0 assert self.num_tasks > 0
@ -1511,7 +1516,7 @@ class DynamicsWorldModel(Module):
# handle agent tokens w/ actions and task embeds # handle agent tokens w/ actions and task embeds
agent_tokens = repeat(agent_tokens, 'b d -> b t d', t = time) agent_tokens = repeat(agent_tokens, 'b ... d -> b t ... d', t = time)
# maybe add a reward embedding to agent tokens # maybe add a reward embedding to agent tokens
@ -1528,7 +1533,7 @@ class DynamicsWorldModel(Module):
reward_embeds = pad_at_dim(reward_embeds, (1, -pop_last_reward), dim = -2, value = 0.) # shift as each agent token predicts the next reward reward_embeds = pad_at_dim(reward_embeds, (1, -pop_last_reward), dim = -2, value = 0.) # shift as each agent token predicts the next reward
agent_tokens = agent_tokens + reward_embeds agent_tokens = einx.add('b t ... d, b t', agent_tokens, reward_embeds)
# main function, needs to be defined as such for shortcut training - additional calls for consistency loss # main function, needs to be defined as such for shortcut training - additional calls for consistency loss
@ -1575,7 +1580,7 @@ class DynamicsWorldModel(Module):
+ num_spatial_tokens + num_spatial_tokens
) )
space_attend = get_attend_fn(causal = False, seq_len = space_seq_len, k_seq_len = space_seq_len, num_special_tokens = 1, **attend_kwargs) # space has an agent token on the right-hand side for reinforcement learning - cannot be attended to by modality space_attend = get_attend_fn(causal = False, seq_len = space_seq_len, k_seq_len = space_seq_len, num_special_tokens = self.num_agents, **attend_kwargs) # space has an agent token on the right-hand side for reinforcement learning - cannot be attended to by modality
time_attend = get_attend_fn(causal = True, seq_len = time, k_seq_len = time, **attend_kwargs) time_attend = get_attend_fn(causal = True, seq_len = time, k_seq_len = time, **attend_kwargs)
@ -1715,6 +1720,10 @@ class DynamicsWorldModel(Module):
reward_loss = self.zero reward_loss = self.zero
if exists(rewards): if exists(rewards):
if rewards.ndim == 2: # (b t)
encoded_agent_tokens = reduce(encoded_agent_tokens, 'b t g d -> b t d', 'mean')
reward_pred = self.to_reward_pred(encoded_agent_tokens) reward_pred = self.to_reward_pred(encoded_agent_tokens)
reward_loss = F.cross_entropy(reward_pred, two_hot_encoding) reward_loss = F.cross_entropy(reward_pred, two_hot_encoding)