multiple agent tokens sharing the same state
commit 892654d442 (parent c4e0f46528)

@@ -1072,6 +1072,7 @@ class DynamicsWorldModel(Module):
         num_register_tokens = 8, # they claim register tokens led to better temporal consistency
         num_spatial_tokens = 2, # latents projected to greater number of spatial tokens
         num_latent_tokens = None,
+        num_agents = 1,
         num_tasks = 0,
         reward_encoder_kwargs: dict = dict(),
         depth = 4,
@@ -1169,7 +1170,8 @@ class DynamicsWorldModel(Module):

         # they sum all the actions into a single token

-        self.action_learned_embed = Parameter(torch.randn(dim) * 1e-2)
+        self.num_agents = num_agents
+        self.action_learned_embed = Parameter(torch.randn(self.num_agents, dim) * 1e-2)

         self.num_tasks = num_tasks
         self.task_embed = nn.Embedding(num_tasks, dim)
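
The learned action/agent embedding gains a leading agent axis, so each agent gets its own learned token while all of them share the same world state. A minimal sketch of that change (names and shapes here are illustrative, not the repository's exact API):

    import torch
    from torch import nn

    num_agents = 3
    dim = 512
    batch = 4

    # before: a single learned token of shape (dim,)
    # after:  one learned token per agent, shape (num_agents, dim)
    action_learned_embed = nn.Parameter(torch.randn(num_agents, dim) * 1e-2)

    # broadcast the per-agent tokens across the batch -> (batch, num_agents, dim)
    agent_tokens = action_learned_embed.unsqueeze(0).expand(batch, -1, -1)
    print(agent_tokens.shape)  # torch.Size([4, 3, 512])
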
@@ -1267,6 +1269,7 @@ class DynamicsWorldModel(Module):
         time_steps,
         num_steps = 4,
         batch_size = 1,
+        agent_index = 0,
         image_height = None,
         image_width = None,
         return_decoded_video = None,
@@ -1344,7 +1347,9 @@ class DynamicsWorldModel(Module):
             # take care of the rewards by predicting on the agent token embedding on the last denoising step

             if return_rewards_per_frame:
-                reward_logits = self.to_reward_pred(agent_embed[:, -1:])
+                one_agent_embed = agent_embed[:, -1:, agent_index]
+
+                reward_logits = self.to_reward_pred(one_agent_embed)
                 pred_reward = self.reward_encoder.bins_to_scalar_value(reward_logits, normalize = True)

                 decoded_rewards = cat((decoded_rewards, pred_reward), dim = 1)
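
During sampling only one agent's reward is decoded, chosen by the new agent_index argument. A rough sketch of that indexing, assuming agent_embed is laid out as (batch, time, num_agents, dim):

    import torch

    batch, time, num_agents, dim = 2, 6, 3, 512
    agent_embed = torch.randn(batch, time, num_agents, dim)

    agent_index = 0

    # keep the last frame as a length-1 time axis and pick a single agent's embedding
    one_agent_embed = agent_embed[:, -1:, agent_index]
    print(one_agent_embed.shape)  # torch.Size([2, 1, 512])
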
@@ -1501,7 +1506,7 @@ class DynamicsWorldModel(Module):

        # reinforcement learning related

-        agent_tokens = repeat(self.action_learned_embed, 'd -> b d', b = batch)
+        agent_tokens = repeat(self.action_learned_embed, '... d -> b ... d', b = batch)

        if exists(tasks):
            assert self.num_tasks > 0
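
The einops pattern changes from 'd -> b d' to '... d -> b ... d': the ellipsis absorbs any leading axes, so the same line batches the embedding whether it still has the old (dim,) layout or the new (num_agents, dim) one. A quick illustration with made-up shapes:

    import torch
    from einops import repeat

    batch = 4

    single_agent = torch.randn(512)     # old layout: (dim,)
    multi_agent  = torch.randn(3, 512)  # new layout: (num_agents, dim)

    # the ellipsis matches zero or more leading axes, so one pattern covers both layouts
    print(repeat(single_agent, '... d -> b ... d', b = batch).shape)  # torch.Size([4, 512])
    print(repeat(multi_agent,  '... d -> b ... d', b = batch).shape)  # torch.Size([4, 3, 512])
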
@@ -1511,7 +1516,7 @@ class DynamicsWorldModel(Module):

        # handle agent tokens w/ actions and task embeds

-        agent_tokens = repeat(agent_tokens, 'b d -> b t d', t = time)
+        agent_tokens = repeat(agent_tokens, 'b ... d -> b t ... d', t = time)

        # maybe add a reward embedding to agent tokens

@@ -1528,7 +1533,7 @@ class DynamicsWorldModel(Module):

            reward_embeds = pad_at_dim(reward_embeds, (1, -pop_last_reward), dim = -2, value = 0.) # shift as each agent token predicts the next reward

-            agent_tokens = agent_tokens + reward_embeds
+            agent_tokens = einx.add('b t ... d, b t', agent_tokens, reward_embeds)

        # main function, needs to be defined as such for shortcut training - additional calls for consistency loss

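
With the extra agent axis the straight addition is replaced by einx.add, which broadcasts the second operand over the axes it does not name ('... d' here). A small sketch of that broadcast, assuming reward_embeds reduces to one value per (batch, time) pair; the real tensor shapes in the repository may differ:

    import torch
    import einx

    batch, time, num_agents, dim = 2, 5, 3, 8

    agent_tokens  = torch.randn(batch, time, num_agents, dim)  # matches 'b t ... d'
    reward_embeds = torch.randn(batch, time)                   # matches 'b t'

    # reward_embeds is broadcast across the agent and feature axes before the add
    out = einx.add('b t ... d, b t', agent_tokens, reward_embeds)
    print(out.shape)  # torch.Size([2, 5, 3, 8])
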
@@ -1575,7 +1580,7 @@ class DynamicsWorldModel(Module):
            + num_spatial_tokens
        )

-        space_attend = get_attend_fn(causal = False, seq_len = space_seq_len, k_seq_len = space_seq_len, num_special_tokens = 1, **attend_kwargs) # space has an agent token on the right-hand side for reinforcement learning - cannot be attended to by modality
+        space_attend = get_attend_fn(causal = False, seq_len = space_seq_len, k_seq_len = space_seq_len, num_special_tokens = self.num_agents, **attend_kwargs) # space has an agent token on the right-hand side for reinforcement learning - cannot be attended to by modality

        time_attend = get_attend_fn(causal = True, seq_len = time, k_seq_len = time, **attend_kwargs)

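
Spatial attention now reserves self.num_agents special positions instead of one: the agent tokens sit on the right-hand side of the spatial sequence and may attend to everything, while modality tokens may not attend back to them. get_attend_fn is internal to the repository, so the rule is only sketched here as a boolean mask:

    import torch

    def spatial_mask(seq_len, num_special_tokens):
        # True = attention allowed; the last `num_special_tokens` positions are agent tokens
        mask = torch.ones(seq_len, seq_len, dtype = torch.bool)

        num_modality = seq_len - num_special_tokens

        # modality queries (rows) may not look at agent keys (rightmost columns)
        mask[:num_modality, num_modality:] = False
        return mask

    print(spatial_mask(seq_len = 6, num_special_tokens = 2).int())
    # rows 0-3 (modality) have zeros in the last two columns,
    # rows 4-5 (agent tokens) can attend everywhere
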
@@ -1715,6 +1720,10 @@ class DynamicsWorldModel(Module):
        reward_loss = self.zero

        if exists(rewards):
+
+            if rewards.ndim == 2: # (b t)
+                encoded_agent_tokens = reduce(encoded_agent_tokens, 'b t g d -> b t d', 'mean')
+
            reward_pred = self.to_reward_pred(encoded_agent_tokens)
            reward_loss = F.cross_entropy(reward_pred, two_hot_encoding)

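
When the reward targets are shared across agents and arrive as a (batch, time) tensor, the per-agent token embeddings are first averaged away so a single embedding predicts the shared reward. A quick einops illustration with assumed shapes:

    import torch
    from einops import reduce

    batch, time, num_agents, dim = 2, 6, 3, 512

    encoded_agent_tokens = torch.randn(batch, time, num_agents, dim)
    rewards = torch.randn(batch, time)  # one shared reward per frame, so rewards.ndim == 2

    if rewards.ndim == 2:
        # collapse the agent axis 'g' by mean before the reward prediction head
        encoded_agent_tokens = reduce(encoded_agent_tokens, 'b t g d -> b t d', 'mean')

    print(encoded_agent_tokens.shape)  # torch.Size([2, 6, 512])
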