allow reward tokens to be attended to as state optionally, DT-esque. figure out multi-agent scenario once i get around to it
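For context on the commit title: "DT-esque" refers to Decision Transformer style conditioning, where past rewards become tokens the world model can attend to as part of its state rather than only being prediction targets. Below is a minimal illustrative sketch of that idea, not this repository's code; the module and argument names are assumptions.

import torch
import torch.nn.functional as F
from torch import nn

class RewardAsStateSketch(nn.Module):
    # toy module: turns past rewards into tokens and concatenates them into the attended sequence
    def __init__(self, dim, reward_bins = 41):
        super().__init__()
        self.embed_reward = nn.Linear(reward_bins, dim)                   # from a two-hot reward encoding to a token
        self.reward_learned_embed = nn.Parameter(torch.randn(dim) * 1e-2)

    def forward(self, tokens, reward_two_hot = None):
        # tokens: (b, t, d) sequence the world model already attends over
        # reward_two_hot: (b, t, bins) encoded past rewards, or None when rewards are hidden
        if reward_two_hot is None:
            return tokens                                                 # reward conditioning stays optional

        reward_tokens = self.embed_reward(reward_two_hot)
        reward_tokens = F.pad(reward_tokens, (0, 0, 1, -1), value = 0.)   # shift right: step t only sees reward t - 1
        reward_tokens = reward_tokens + self.reward_learned_embed
        return torch.cat((reward_tokens, tokens), dim = 1)                # rewards now live in the attended state

tokens = torch.randn(2, 5, 32)
rewards = torch.softmax(torch.randn(2, 5, 41), dim = -1)
print(RewardAsStateSketch(32)(tokens, rewards).shape)                     # torch.Size([2, 10, 32])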

lucidrains 2025-10-16 06:41:02 -07:00
parent d28251e9f9
commit 2a902eaaf7
2 changed files with 21 additions and 10 deletions


@@ -1564,6 +1564,8 @@ class DynamicsWorldModel(Module):
         self.agent_learned_embed = Parameter(randn(self.num_agents, dim) * 1e-2)
         self.action_learned_embed = Parameter(randn(self.num_agents, dim) * 1e-2)
+        self.reward_learned_embed = Parameter(randn(self.num_agents, dim) * 1e-2)
         self.num_tasks = num_tasks
         self.task_embed = nn.Embedding(num_tasks, dim)
@@ -1950,7 +1952,9 @@ class DynamicsWorldModel(Module):
         agent_tokens = repeat(agent_tokens, 'b ... d -> b t ... d', t = time)
-        # maybe add a reward embedding to agent tokens
+        # maybe reward tokens
+        reward_tokens = agent_tokens[:, :, 0:0]
         if exists(rewards):
             two_hot_encoding = self.reward_encoder(rewards)
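The reward_encoder call above produces a two-hot encoding of the scalar reward. As a reminder of what that means, here is a generic sketch in the spirit of DreamerV3-style two-hot encodings (the repository's actual reward_encoder may differ): the reward is clamped onto a fixed grid of bins and its mass is split between the two nearest bins.

import torch

def two_hot(rewards, bins):
    # rewards: (...,) scalar rewards; bins: (num_bins,) sorted bin values
    rewards = rewards.clamp(bins[0], bins[-1])
    idx_hi = torch.searchsorted(bins, rewards).clamp(1, len(bins) - 1)
    idx_lo = idx_hi - 1
    lo, hi = bins[idx_lo], bins[idx_hi]
    weight_hi = (rewards - lo) / (hi - lo).clamp(min = 1e-8)
    encoding = torch.zeros(*rewards.shape, len(bins))
    encoding.scatter_(-1, idx_lo.unsqueeze(-1), (1. - weight_hi).unsqueeze(-1))
    encoding.scatter_(-1, idx_hi.unsqueeze(-1), weight_hi.unsqueeze(-1))
    return encoding

bins = torch.linspace(-1., 1., 5)           # [-1, -0.5, 0, 0.5, 1]
print(two_hot(torch.tensor([0.25]), bins))  # mass split between the 0 and 0.5 bins: [0, 0, 0.5, 0.5, 0]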
@@ -1959,13 +1963,15 @@ class DynamicsWorldModel(Module):
             self.add_reward_embed_to_agent_token and
             (not self.training or not sample_prob(self.add_reward_embed_dropout)) # a bit of noise goes a long way
         ):
-            reward_embeds = self.reward_encoder.embed(two_hot_encoding)
             assert self.num_agents == 1
-            pop_last_reward = int(reward_embeds.shape[1] == agent_tokens.shape[1]) # the last reward is popped off during training, during inference, it is not known yet, so need to handle this edge case
-            reward_embeds = pad_at_dim(reward_embeds, (1, -pop_last_reward), dim = -2, value = 0.) # shift as each agent token predicts the next reward
-            agent_tokens = add('b t ... d, b t d', agent_tokens, reward_embeds)
+            reward_tokens = self.reward_encoder.embed(two_hot_encoding)
+            pop_last_reward = int(reward_tokens.shape[1] == agent_tokens.shape[1]) # the last reward is popped off during training, during inference, it is not known yet, so need to handle this edge case
+            reward_tokens = pad_at_dim(reward_tokens, (1, -pop_last_reward), dim = -2, value = 0.) # shift as each agent token predicts the next reward
+            reward_tokens = add('1 d, b t d', self.reward_learned_embed, reward_tokens)
         # maybe create the action tokens
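The pad/pop logic in the hunk above aligns rewards with the agent tokens: each agent token predicts the reward of its own step, so it may only see the reward from the previous step. A small sketch of that shift, with torch.nn.functional.pad standing in for the repository's pad_at_dim helper (assumed to behave equivalently for this case):

import torch
import torch.nn.functional as F

def shift_rewards(reward_tokens, agent_tokens):
    # reward_tokens: (b, t, d) during training, (b, t - 1, d) at inference (last reward not yet known)
    # agent_tokens:  (b, t, d)
    pop_last_reward = int(reward_tokens.shape[1] == agent_tokens.shape[1])
    # pad one zero token at the front of the time dim, trim the last token only when it exists
    return F.pad(reward_tokens, (0, 0, 1, -pop_last_reward), value = 0.)

b, t, d = 2, 5, 8
agent_tokens = torch.randn(b, t, d)
print(shift_rewards(torch.randn(b, t, d), agent_tokens).shape)      # training:  torch.Size([2, 5, 8])
print(shift_rewards(torch.randn(b, t - 1, d), agent_tokens).shape)  # inference: torch.Size([2, 5, 8])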
@@ -1992,7 +1998,7 @@ class DynamicsWorldModel(Module):
         # main function, needs to be defined as such for shortcut training - additional calls for consistency loss
-        def get_prediction(noised_latents, signal_levels, step_sizes_log2, action_tokens, agent_tokens, return_agent_tokens = False):
+        def get_prediction(noised_latents, signal_levels, step_sizes_log2, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = False):
             # latents to spatial tokens
             space_tokens = self.latents_to_spatial_tokens(noised_latents)
@@ -2005,6 +2011,10 @@ class DynamicsWorldModel(Module):
             num_action_tokens = 1 if not is_empty(action_tokens) else 0
+            # reward tokens
+            num_reward_tokens = 1 if not is_empty(reward_tokens) else 0
             # pack to tokens
             # [signal + step size embed] [latent space tokens] [register] [actions / agent]
@@ -2035,6 +2045,7 @@ class DynamicsWorldModel(Module):
             space_seq_len = (
                 + 1 # signal + step
                 + num_action_tokens # past action tokens - todo: account for multi-agent
+                + num_reward_tokens # maybe allow world model being fine-tuned in phase 3 to see rewards as state
                 + self.num_agents # action / agent tokens
                 + self.num_register_tokens
                 + num_spatial_tokens
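With the new num_reward_tokens term, the per-timestep sequence length grows by one whenever rewards are provided. A toy computation with placeholder numbers (the register and spatial token counts below are illustrative, not the repository's defaults):

def space_seq_len(num_spatial_tokens, num_agents = 1, num_register_tokens = 4,
                  has_action = True, has_reward = True):
    return (
        1                      # signal + step size embed
        + int(has_action)      # past action token
        + int(has_reward)      # optional reward-as-state token added by this commit
        + num_agents           # action / agent tokens
        + num_register_tokens  # register tokens
        + num_spatial_tokens   # latent space tokens
    )

print(space_seq_len(num_spatial_tokens = 64))                      # 72
print(space_seq_len(num_spatial_tokens = 64, has_reward = False))  # 71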
@@ -2091,7 +2102,7 @@ class DynamicsWorldModel(Module):
         # forward the network
-        pred, encoded_agent_tokens = get_prediction(noised_latents, signal_levels, step_sizes_log2, action_tokens, agent_tokens, return_agent_tokens = True)
+        pred, encoded_agent_tokens = get_prediction(noised_latents, signal_levels, step_sizes_log2, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = True)
         if return_pred_only:
             if not return_agent_tokens:
@@ -2128,7 +2139,7 @@ class DynamicsWorldModel(Module):
             step_sizes_log2_minus_one = step_sizes_log2 - 1 # which equals d / 2
             half_step_size = 2 ** step_sizes_log2_minus_one
-            first_step_pred = get_prediction_no_grad(noised_latents, signal_levels, step_sizes_log2_minus_one, action_tokens, agent_tokens)
+            first_step_pred = get_prediction_no_grad(noised_latents, signal_levels, step_sizes_log2_minus_one, action_tokens, reward_tokens, agent_tokens)
             # first derive b'
@@ -2147,7 +2158,7 @@ class DynamicsWorldModel(Module):
             # get second prediction for b''
             signal_levels_plus_half_step = signal_levels + half_step_size[:, None]
-            second_step_pred = get_prediction_no_grad(denoised_latent, signal_levels_plus_half_step, step_sizes_log2_minus_one, action_tokens, agent_tokens)
+            second_step_pred = get_prediction_no_grad(denoised_latent, signal_levels_plus_half_step, step_sizes_log2_minus_one, action_tokens, reward_tokens, agent_tokens)
             if is_v_space_pred:
                 second_step_pred_flow = second_step_pred
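For context on the two calls above: this is the shortcut-model consistency step, where a prediction taken with step size 2^k is trained to agree with the composition of two predictions taken with step size 2^(k-1) (b' and b'' in the comments). A simplified sketch of that target; the model signature and shapes here are illustrative, and the real code additionally distinguishes flow- and v-space predictions:

import torch

def shortcut_consistency_target(model, latents, signal_levels, step_sizes_log2, cond):
    # model(latents, signal_levels, step_sizes_log2, cond) -> predicted flow (assumed signature)
    half_log2 = step_sizes_log2 - 1                     # half the step, i.e. d / 2
    half = 2. ** half_log2
    with torch.no_grad():
        flow_1 = model(latents, signal_levels, half_log2, cond)              # b'
        latents_mid = latents + flow_1 * half                                 # take the first half step
        flow_2 = model(latents_mid, signal_levels + half, half_log2, cond)    # b''
    return (flow_1 + flow_2) / 2                        # target the full-step prediction should match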


@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.20"
+version = "0.0.21"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }