From 2a902eaaf71451a08740e943d17a0961456cbe32 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Thu, 16 Oct 2025 06:41:02 -0700
Subject: [PATCH] allow reward tokens to be attended to as state optionally,
 DT-esque. figure out multi-agent scenario once i get around to it

---
 dreamer4/dreamer4.py | 29 ++++++++++++++++++++---------
 pyproject.toml       |  2 +-
 2 files changed, 21 insertions(+), 10 deletions(-)

diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index 1c4a988..6461a81 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -1564,6 +1564,8 @@ class DynamicsWorldModel(Module):
         self.agent_learned_embed = Parameter(randn(self.num_agents, dim) * 1e-2)
         self.action_learned_embed = Parameter(randn(self.num_agents, dim) * 1e-2)
 
+        self.reward_learned_embed = Parameter(randn(self.num_agents, dim) * 1e-2)
+
         self.num_tasks = num_tasks
         self.task_embed = nn.Embedding(num_tasks, dim)
 
@@ -1950,7 +1952,9 @@ class DynamicsWorldModel(Module):
 
         agent_tokens = repeat(agent_tokens, 'b ... d -> b t ... d', t = time)
 
-        # maybe add a reward embedding to agent tokens
+        # maybe reward tokens
+
+        reward_tokens = agent_tokens[:, :, 0:0]
 
         if exists(rewards):
             two_hot_encoding = self.reward_encoder(rewards)
@@ -1959,13 +1963,15 @@ class DynamicsWorldModel(Module):
                 self.add_reward_embed_to_agent_token and
                 (not self.training or not sample_prob(self.add_reward_embed_dropout)) # a bit of noise goes a long way
             ):
-                reward_embeds = self.reward_encoder.embed(two_hot_encoding)
+                assert self.num_agents == 1
 
-                pop_last_reward = int(reward_embeds.shape[1] == agent_tokens.shape[1]) # the last reward is popped off during training, during inference, it is not known yet, so need to handle this edge case
+                reward_tokens = self.reward_encoder.embed(two_hot_encoding)
 
-                reward_embeds = pad_at_dim(reward_embeds, (1, -pop_last_reward), dim = -2, value = 0.) # shift as each agent token predicts the next reward
+                pop_last_reward = int(reward_tokens.shape[1] == agent_tokens.shape[1]) # the last reward is popped off during training, during inference, it is not known yet, so need to handle this edge case
 
-                agent_tokens = add('b t ... d, b t d', agent_tokens, reward_embeds)
+                reward_tokens = pad_at_dim(reward_tokens, (1, -pop_last_reward), dim = -2, value = 0.) # shift as each agent token predicts the next reward
+
+                reward_tokens = add('1 d, b t d', self.reward_learned_embed, reward_tokens)
 
         # maybe create the action tokens
 
@@ -1992,7 +1998,7 @@ class DynamicsWorldModel(Module):
         # main function, needs to be defined as such for shortcut training - additional calls for consistency loss
 
-        def get_prediction(noised_latents, signal_levels, step_sizes_log2, action_tokens, agent_tokens, return_agent_tokens = False):
+        def get_prediction(noised_latents, signal_levels, step_sizes_log2, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = False):
 
             # latents to spatial tokens
 
             space_tokens = self.latents_to_spatial_tokens(noised_latents)
@@ -2005,6 +2011,10 @@ class DynamicsWorldModel(Module):
 
             num_action_tokens = 1 if not is_empty(action_tokens) else 0
 
+            # reward tokens
+
+            num_reward_tokens = 1 if not is_empty(reward_tokens) else 0
+
             # pack to tokens
 
             # [signal + step size embed] [latent space tokens] [register] [actions / agent]
@@ -2035,6 +2045,7 @@ class DynamicsWorldModel(Module):
             space_seq_len = (
                 + 1                          # signal + step
                 + num_action_tokens          # past action tokens - todo: account for multi-agent
+                + num_reward_tokens          # maybe allow world model being fine-tuned in phase 3 to see rewards as state
                 + self.num_agents            # action / agent tokens
                 + self.num_register_tokens
                 + num_spatial_tokens
@@ -2091,7 +2102,7 @@ class DynamicsWorldModel(Module):
 
         # forward the network
 
-        pred, encoded_agent_tokens = get_prediction(noised_latents, signal_levels, step_sizes_log2, action_tokens, agent_tokens, return_agent_tokens = True)
+        pred, encoded_agent_tokens = get_prediction(noised_latents, signal_levels, step_sizes_log2, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = True)
 
         if return_pred_only:
             if not return_agent_tokens:
@@ -2128,7 +2139,7 @@ class DynamicsWorldModel(Module):
             step_sizes_log2_minus_one = step_sizes_log2 - 1 # which equals d / 2
             half_step_size = 2 ** step_sizes_log2_minus_one
 
-            first_step_pred = get_prediction_no_grad(noised_latents, signal_levels, step_sizes_log2_minus_one, action_tokens, agent_tokens)
+            first_step_pred = get_prediction_no_grad(noised_latents, signal_levels, step_sizes_log2_minus_one, action_tokens, reward_tokens, agent_tokens)
 
             # first derive b'
 
@@ -2147,7 +2158,7 @@ class DynamicsWorldModel(Module):
            # get second prediction for b''
 
            signal_levels_plus_half_step = signal_levels + half_step_size[:, None]
 
-            second_step_pred = get_prediction_no_grad(denoised_latent, signal_levels_plus_half_step, step_sizes_log2_minus_one, action_tokens, agent_tokens)
+            second_step_pred = get_prediction_no_grad(denoised_latent, signal_levels_plus_half_step, step_sizes_log2_minus_one, action_tokens, reward_tokens, agent_tokens)
 
             if is_v_space_pred:
                 second_step_pred_flow = second_step_pred
diff --git a/pyproject.toml b/pyproject.toml
index 343fb20..11fc5ad 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.20"
+version = "0.0.21"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
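A minimal sketch of the reward-as-state mechanism the patch adds, assuming the single-agent case the diff asserts. It is written against plain PyTorch; the tensor names and the helper function are illustrative stand-ins, not dreamer4's actual API (the real logic lives in DynamicsWorldModel above and uses the repo's own pad_at_dim, reward_encoder.embed and einops-style add helpers).

    # illustrative sketch only - a simplified re-implementation of the reward-token
    # shift introduced by this patch, not the dreamer4 API

    import torch
    import torch.nn.functional as F
    from torch import nn

    batch, time, dim = 2, 8, 16

    # learned embedding marking a token as a reward token,
    # analogous to self.reward_learned_embed in the diff (single agent)
    reward_learned_embed = nn.Parameter(torch.randn(dim) * 1e-2)

    # stand-in for self.reward_encoder.embed(two_hot_encoding)
    reward_embeds = torch.randn(batch, time, dim)

    def make_reward_tokens(reward_embeds, num_agent_steps):
        # shift right by one timestep so each agent token only sees past rewards
        # and is trained to predict the next one (the DT-esque setup).
        # during training all rewards are known, so the last one is popped off;
        # at inference the final reward does not exist yet, so nothing is dropped
        pop_last = reward_embeds.shape[1] == num_agent_steps

        shifted = F.pad(reward_embeds, (0, 0, 1, 0), value = 0.)  # one zero reward padded at the front of the time dim

        if pop_last:
            shifted = shifted[:, :-1]  # drop the final reward during training

        return shifted + reward_learned_embed  # tag with the learned reward embedding

    reward_tokens = make_reward_tokens(reward_embeds, num_agent_steps = time)
    print(reward_tokens.shape)  # torch.Size([2, 8, 16])

These reward tokens are then counted into space_seq_len and attended to alongside the action and agent tokens, which is what the num_reward_tokens bookkeeping in the diff handles.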