multiple agent tokens sharing the same state
This commit is contained in:
parent
c4e0f46528
commit
892654d442
@ -1072,6 +1072,7 @@ class DynamicsWorldModel(Module):
|
||||
num_register_tokens = 8, # they claim register tokens led to better temporal consistency
|
||||
num_spatial_tokens = 2, # latents projected to greater number of spatial tokens
|
||||
num_latent_tokens = None,
|
||||
num_agents = 1,
|
||||
num_tasks = 0,
|
||||
reward_encoder_kwargs: dict = dict(),
|
||||
depth = 4,
|
||||
@ -1169,7 +1170,8 @@ class DynamicsWorldModel(Module):
|
||||
|
||||
# they sum all the actions into a single token
|
||||
|
||||
self.action_learned_embed = Parameter(torch.randn(dim) * 1e-2)
|
||||
self.num_agents = num_agents
|
||||
self.action_learned_embed = Parameter(torch.randn(self.num_agents, dim) * 1e-2)
|
||||
|
||||
self.num_tasks = num_tasks
|
||||
self.task_embed = nn.Embedding(num_tasks, dim)
|
||||
@ -1267,6 +1269,7 @@ class DynamicsWorldModel(Module):
|
||||
time_steps,
|
||||
num_steps = 4,
|
||||
batch_size = 1,
|
||||
agent_index = 0,
|
||||
image_height = None,
|
||||
image_width = None,
|
||||
return_decoded_video = None,
|
||||
@ -1344,7 +1347,9 @@ class DynamicsWorldModel(Module):
|
||||
# take care of the rewards by predicting on the agent token embedding on the last denoising step
|
||||
|
||||
if return_rewards_per_frame:
|
||||
reward_logits = self.to_reward_pred(agent_embed[:, -1:])
|
||||
one_agent_embed = agent_embed[:, -1:, agent_index]
|
||||
|
||||
reward_logits = self.to_reward_pred(one_agent_embed)
|
||||
pred_reward = self.reward_encoder.bins_to_scalar_value(reward_logits, normalize = True)
|
||||
|
||||
decoded_rewards = cat((decoded_rewards, pred_reward), dim = 1)
|
||||
@ -1501,7 +1506,7 @@ class DynamicsWorldModel(Module):
|
||||
|
||||
# reinforcement learning related
|
||||
|
||||
agent_tokens = repeat(self.action_learned_embed, 'd -> b d', b = batch)
|
||||
agent_tokens = repeat(self.action_learned_embed, '... d -> b ... d', b = batch)
|
||||
|
||||
if exists(tasks):
|
||||
assert self.num_tasks > 0
|
||||
@ -1511,7 +1516,7 @@ class DynamicsWorldModel(Module):
|
||||
|
||||
# handle agent tokens w/ actions and task embeds
|
||||
|
||||
agent_tokens = repeat(agent_tokens, 'b d -> b t d', t = time)
|
||||
agent_tokens = repeat(agent_tokens, 'b ... d -> b t ... d', t = time)
|
||||
|
||||
# maybe add a reward embedding to agent tokens
|
||||
|
||||
@ -1528,7 +1533,7 @@ class DynamicsWorldModel(Module):
|
||||
|
||||
reward_embeds = pad_at_dim(reward_embeds, (1, -pop_last_reward), dim = -2, value = 0.) # shift as each agent token predicts the next reward
|
||||
|
||||
agent_tokens = agent_tokens + reward_embeds
|
||||
agent_tokens = einx.add('b t ... d, b t', agent_tokens, reward_embeds)
|
||||
|
||||
# main function, needs to be defined as such for shortcut training - additional calls for consistency loss
|
||||
|
||||
@ -1575,7 +1580,7 @@ class DynamicsWorldModel(Module):
|
||||
+ num_spatial_tokens
|
||||
)
|
||||
|
||||
space_attend = get_attend_fn(causal = False, seq_len = space_seq_len, k_seq_len = space_seq_len, num_special_tokens = 1, **attend_kwargs) # space has an agent token on the right-hand side for reinforcement learning - cannot be attended to by modality
|
||||
space_attend = get_attend_fn(causal = False, seq_len = space_seq_len, k_seq_len = space_seq_len, num_special_tokens = self.num_agents, **attend_kwargs) # space has an agent token on the right-hand side for reinforcement learning - cannot be attended to by modality
|
||||
|
||||
time_attend = get_attend_fn(causal = True, seq_len = time, k_seq_len = time, **attend_kwargs)
|
||||
|
||||
@ -1715,6 +1720,10 @@ class DynamicsWorldModel(Module):
|
||||
reward_loss = self.zero
|
||||
|
||||
if exists(rewards):
|
||||
|
||||
if rewards.ndim == 2: # (b t)
|
||||
encoded_agent_tokens = reduce(encoded_agent_tokens, 'b t g d -> b t d', 'mean')
|
||||
|
||||
reward_pred = self.to_reward_pred(encoded_agent_tokens)
|
||||
reward_loss = F.cross_entropy(reward_pred, two_hot_encoding)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user