diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index 9023cac..2cff56f 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -1072,6 +1072,7 @@ class DynamicsWorldModel(Module): num_register_tokens = 8, # they claim register tokens led to better temporal consistency num_spatial_tokens = 2, # latents projected to greater number of spatial tokens num_latent_tokens = None, + num_agents = 1, num_tasks = 0, reward_encoder_kwargs: dict = dict(), depth = 4, @@ -1169,7 +1170,8 @@ class DynamicsWorldModel(Module): # they sum all the actions into a single token - self.action_learned_embed = Parameter(torch.randn(dim) * 1e-2) + self.num_agents = num_agents + self.action_learned_embed = Parameter(torch.randn(self.num_agents, dim) * 1e-2) self.num_tasks = num_tasks self.task_embed = nn.Embedding(num_tasks, dim) @@ -1267,6 +1269,7 @@ class DynamicsWorldModel(Module): time_steps, num_steps = 4, batch_size = 1, + agent_index = 0, image_height = None, image_width = None, return_decoded_video = None, @@ -1344,7 +1347,9 @@ class DynamicsWorldModel(Module): # take care of the rewards by predicting on the agent token embedding on the last denoising step if return_rewards_per_frame: - reward_logits = self.to_reward_pred(agent_embed[:, -1:]) + one_agent_embed = agent_embed[:, -1:, agent_index] + + reward_logits = self.to_reward_pred(one_agent_embed) pred_reward = self.reward_encoder.bins_to_scalar_value(reward_logits, normalize = True) decoded_rewards = cat((decoded_rewards, pred_reward), dim = 1) @@ -1501,7 +1506,7 @@ class DynamicsWorldModel(Module): # reinforcement learning related - agent_tokens = repeat(self.action_learned_embed, 'd -> b d', b = batch) + agent_tokens = repeat(self.action_learned_embed, '... d -> b ... d', b = batch) if exists(tasks): assert self.num_tasks > 0 @@ -1511,7 +1516,7 @@ class DynamicsWorldModel(Module): # handle agent tokens w/ actions and task embeds - agent_tokens = repeat(agent_tokens, 'b d -> b t d', t = time) + agent_tokens = repeat(agent_tokens, 'b ... d -> b t ... d', t = time) # maybe add a reward embedding to agent tokens @@ -1528,7 +1533,7 @@ class DynamicsWorldModel(Module): reward_embeds = pad_at_dim(reward_embeds, (1, -pop_last_reward), dim = -2, value = 0.) # shift as each agent token predicts the next reward - agent_tokens = agent_tokens + reward_embeds + agent_tokens = einx.add('b t ... d, b t', agent_tokens, reward_embeds) # main function, needs to be defined as such for shortcut training - additional calls for consistency loss @@ -1575,7 +1580,7 @@ class DynamicsWorldModel(Module): + num_spatial_tokens ) - space_attend = get_attend_fn(causal = False, seq_len = space_seq_len, k_seq_len = space_seq_len, num_special_tokens = 1, **attend_kwargs) # space has an agent token on the right-hand side for reinforcement learning - cannot be attended to by modality + space_attend = get_attend_fn(causal = False, seq_len = space_seq_len, k_seq_len = space_seq_len, num_special_tokens = self.num_agents, **attend_kwargs) # space has an agent token on the right-hand side for reinforcement learning - cannot be attended to by modality time_attend = get_attend_fn(causal = True, seq_len = time, k_seq_len = time, **attend_kwargs) @@ -1715,6 +1720,10 @@ class DynamicsWorldModel(Module): reward_loss = self.zero if exists(rewards): + + if rewards.ndim == 2: # (b t) + encoded_agent_tokens = reduce(encoded_agent_tokens, 'b t g d -> b t d', 'mean') + reward_pred = self.to_reward_pred(encoded_agent_tokens) reward_loss = F.cross_entropy(reward_pred, two_hot_encoding)