diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index 8435d70..802942e 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -1462,7 +1462,10 @@ class DynamicsModel(Module): if self.add_reward_embed_to_agent_token: reward_embeds = self.reward_encoder.embed(two_hot_encoding) - reward_embeds = pad_at_dim(reward_embeds, (1, -1), dim = -2, value = 0.) # shift as each agent token predicts the next reward + + pop_last_reward = int(reward_embeds.shape[1] == agent_tokens.shape[1]) # the last reward is popped off during training, during inference, it is not known yet, so need to handle this edge case + + reward_embeds = pad_at_dim(reward_embeds, (1, -pop_last_reward), dim = -2, value = 0.) # shift as each agent token predicts the next reward agent_tokens = agent_tokens + reward_embeds