allow reward tokens to be optionally attended to as state, DT-esque (Decision Transformer). figure out the multi-agent scenario once i get around to it
parent d28251e9f9
commit 2a902eaaf7
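The change, in one picture: rewards are encoded into tokens, shifted one step back so a frame only ever sees rewards that precede it, tagged with a learned reward embedding, and then simply attended to alongside the rest of the state, Decision-Transformer style. The sketch below is a minimal illustration of that idea under assumed names and shapes (ToyRewardAsState, the linear reward projection and the token layout are stand-ins, not the repo's actual modules).

import torch
from torch import nn

class ToyRewardAsState(nn.Module):
    # all names and shapes here are illustrative assumptions, not the repo's modules

    def __init__(self, dim, num_agents = 1):
        super().__init__()
        self.reward_proj = nn.Linear(1, dim)  # stand-in for the two-hot reward encoder + embed
        self.reward_learned_embed = nn.Parameter(torch.randn(num_agents, dim) * 1e-2)

    def forward(self, state_tokens, rewards = None):
        # state_tokens: (batch, time, num_state_tokens, dim); rewards: (batch, time) or None

        if rewards is None:
            reward_tokens = state_tokens[:, :, 0:0]  # zero-width placeholder, nothing gets attended
        else:
            reward_tokens = self.reward_proj(rewards[..., None])             # (b, t, d)
            reward_tokens = nn.functional.pad(reward_tokens, (0, 0, 1, -1))  # shift right: step t sees reward t - 1
            reward_tokens = reward_tokens + self.reward_learned_embed        # tag as reward tokens
            reward_tokens = reward_tokens[:, :, None, :]                     # (b, t, 1, d)

        # rewards become just more tokens in the attended state sequence
        return torch.cat((reward_tokens, state_tokens), dim = 2)

state_tokens = torch.randn(2, 8, 4, 32)
rewards = torch.randn(2, 8)

conditioner = ToyRewardAsState(dim = 32)
print(conditioner(state_tokens, rewards).shape)  # torch.Size([2, 8, 5, 32])
print(conditioner(state_tokens).shape)           # torch.Size([2, 8, 4, 32])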
@@ -1564,6 +1564,8 @@ class DynamicsWorldModel(Module):
         self.agent_learned_embed = Parameter(randn(self.num_agents, dim) * 1e-2)
 
         self.action_learned_embed = Parameter(randn(self.num_agents, dim) * 1e-2)
 
+        self.reward_learned_embed = Parameter(randn(self.num_agents, dim) * 1e-2)
+
         self.num_tasks = num_tasks
         self.task_embed = nn.Embedding(num_tasks, dim)
@@ -1950,7 +1952,9 @@ class DynamicsWorldModel(Module):
 
         agent_tokens = repeat(agent_tokens, 'b ... d -> b t ... d', t = time)
 
-        # maybe add a reward embedding to agent tokens
+        # maybe reward tokens
+
+        reward_tokens = agent_tokens[:, :, 0:0]
 
         if exists(rewards):
             two_hot_encoding = self.reward_encoder(rewards)
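The agent_tokens[:, :, 0:0] placeholder above is a zero-width slice: when no rewards are conditioned on, reward_tokens is an empty tensor with the right rank, dtype and device, so the later emptiness check and token packing need no special casing. A small illustration, with a plausible stand-in for the repo's is_empty helper (an assumption):

import torch

def is_empty(t):  # stand-in helper, assumed semantics
    return t.numel() == 0

agent_tokens = torch.randn(2, 6, 1, 32)   # (batch, time, num_agents, dim)
reward_tokens = agent_tokens[:, :, 0:0]   # (2, 6, 0, 32): empty, but same rank, dtype and device

print(is_empty(reward_tokens))                                  # True
print(torch.cat((reward_tokens, agent_tokens), dim = 2).shape)  # torch.Size([2, 6, 1, 32])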
@@ -1959,13 +1963,15 @@ class DynamicsWorldModel(Module):
             self.add_reward_embed_to_agent_token and
             (not self.training or not sample_prob(self.add_reward_embed_dropout)) # a bit of noise goes a long way
         ):
-            reward_embeds = self.reward_encoder.embed(two_hot_encoding)
-            assert self.num_agents == 1
+            reward_tokens = self.reward_encoder.embed(two_hot_encoding)
 
-            pop_last_reward = int(reward_embeds.shape[1] == agent_tokens.shape[1]) # the last reward is popped off during training, during inference, it is not known yet, so need to handle this edge case
+            pop_last_reward = int(reward_tokens.shape[1] == agent_tokens.shape[1]) # the last reward is popped off during training, during inference, it is not known yet, so need to handle this edge case
 
-            reward_embeds = pad_at_dim(reward_embeds, (1, -pop_last_reward), dim = -2, value = 0.) # shift as each agent token predicts the next reward
+            reward_tokens = pad_at_dim(reward_tokens, (1, -pop_last_reward), dim = -2, value = 0.) # shift as each agent token predicts the next reward
 
-            agent_tokens = add('b t ... d, b t d', agent_tokens, reward_embeds)
+            reward_tokens = add('1 d, b t d', self.reward_learned_embed, reward_tokens)
 
         # maybe create the action tokens
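The pop_last_reward / pad logic above is the causal-shift bookkeeping: each agent token at step t predicts the reward of step t, so it may only attend to rewards from earlier steps. During training the reward sequence is as long as the agent token sequence, so the final reward is popped before the right-shift; during inference the current reward is not known yet, so only the shift applies. A minimal sketch, with torch.nn.functional.pad standing in for the repo's pad_at_dim helper (an assumption about its semantics):

import torch
import torch.nn.functional as F

def shift_rewards(reward_tokens, agent_len):
    # reward_tokens: (batch, time_r, dim)
    # training:  time_r == agent_len, so the last (not yet predicted) reward is popped
    # inference: time_r == agent_len - 1, so only the right-shift is needed
    pop_last_reward = int(reward_tokens.shape[1] == agent_len)
    return F.pad(reward_tokens, (0, 0, 1, -pop_last_reward), value = 0.)

train_rewards = torch.randn(2, 8, 32)
print(shift_rewards(train_rewards, agent_len = 8).shape)    # torch.Size([2, 8, 32]), step 0 gets zeros

rollout_rewards = torch.randn(2, 7, 32)
print(shift_rewards(rollout_rewards, agent_len = 8).shape)  # torch.Size([2, 8, 32])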
@@ -1992,7 +1998,7 @@ class DynamicsWorldModel(Module):
         # main function, needs to be defined as such for shortcut training - additional calls for consistency loss
 
-        def get_prediction(noised_latents, signal_levels, step_sizes_log2, action_tokens, agent_tokens, return_agent_tokens = False):
+        def get_prediction(noised_latents, signal_levels, step_sizes_log2, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = False):
 
             # latents to spatial tokens
 
             space_tokens = self.latents_to_spatial_tokens(noised_latents)
@@ -2005,6 +2011,10 @@ class DynamicsWorldModel(Module):
 
             num_action_tokens = 1 if not is_empty(action_tokens) else 0
 
+            # reward tokens
+
+            num_reward_tokens = 1 if not is_empty(reward_tokens) else 0
+
             # pack to tokens
             # [signal + step size embed] [latent space tokens] [register] [actions / agent]
 
@@ -2035,6 +2045,7 @@ class DynamicsWorldModel(Module):
             space_seq_len = (
                 + 1 # signal + step
                 + num_action_tokens # past action tokens - todo: account for multi-agent
+                + num_reward_tokens # maybe allow world model being fine-tuned in phase 3 to see rewards as state
                 + self.num_agents # action / agent tokens
                 + self.num_register_tokens
                 + num_spatial_tokens
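The extra num_reward_tokens term above just extends the per-timestep sequence-length bookkeeping that mirrors the packing order [signal + step size] [past action] [reward] [agent] [register] [spatial]. Illustrative arithmetic with made-up example counts:

# example values only, not the repo's defaults
num_action_tokens   = 1   # 0 when no action tokens are passed in
num_reward_tokens   = 1   # 0 unless rewards are allowed in as state (this commit)
num_agents          = 1
num_register_tokens = 8
num_spatial_tokens  = 64

space_seq_len = (
    1                     # signal + step size embed
    + num_action_tokens   # past action tokens
    + num_reward_tokens   # reward tokens
    + num_agents          # action / agent tokens
    + num_register_tokens
    + num_spatial_tokens
)

print(space_seq_len)  # 76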
@@ -2091,7 +2102,7 @@ class DynamicsWorldModel(Module):
 
         # forward the network
 
-        pred, encoded_agent_tokens = get_prediction(noised_latents, signal_levels, step_sizes_log2, action_tokens, agent_tokens, return_agent_tokens = True)
+        pred, encoded_agent_tokens = get_prediction(noised_latents, signal_levels, step_sizes_log2, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = True)
 
         if return_pred_only:
             if not return_agent_tokens:
@@ -2128,7 +2139,7 @@ class DynamicsWorldModel(Module):
         step_sizes_log2_minus_one = step_sizes_log2 - 1 # which equals d / 2
         half_step_size = 2 ** step_sizes_log2_minus_one
 
-        first_step_pred = get_prediction_no_grad(noised_latents, signal_levels, step_sizes_log2_minus_one, action_tokens, agent_tokens)
+        first_step_pred = get_prediction_no_grad(noised_latents, signal_levels, step_sizes_log2_minus_one, action_tokens, reward_tokens, agent_tokens)
 
         # first derive b'
 
@@ -2147,7 +2158,7 @@ class DynamicsWorldModel(Module):
         # get second prediction for b''
 
         signal_levels_plus_half_step = signal_levels + half_step_size[:, None]
-        second_step_pred = get_prediction_no_grad(denoised_latent, signal_levels_plus_half_step, step_sizes_log2_minus_one, action_tokens, agent_tokens)
+        second_step_pred = get_prediction_no_grad(denoised_latent, signal_levels_plus_half_step, step_sizes_log2_minus_one, action_tokens, reward_tokens, agent_tokens)
 
         if is_v_space_pred:
             second_step_pred_flow = second_step_pred
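Both no-grad calls above exist for the shortcut-model consistency loss: a prediction made with step size d is trained to match the composition of two predictions of size d / 2, and the reward tokens now have to be threaded through both calls so the half-step predictions see the same conditioning. The sketch below is schematic flow-style pseudomath under that reading, not the repo's exact update:

import torch

def consistency_target(model, latents, signal_levels, step_sizes_log2, *cond):
    # a single step of size d should match two chained half steps of size d / 2,
    # made under identical conditioning (action / reward / agent tokens)
    with torch.no_grad():
        half_log2 = step_sizes_log2 - 1       # step size d / 2
        half_step = 2. ** half_log2

        first_flow = model(latents, signal_levels, half_log2, *cond)                            # b'
        latents_mid = latents + first_flow * half_step[:, None, None]                           # advance half a step

        second_flow = model(latents_mid, signal_levels + half_step[:, None], half_log2, *cond)  # b''

    return (first_flow + second_flow) / 2     # target for the full-size step

# usage with a dummy model: the conditioning is passed unchanged to both half-step calls
model = lambda latents, signal, step_log2, *cond: torch.zeros_like(latents)

latents = torch.randn(2, 16, 8)
signal_levels = torch.rand(2, 16)
step_sizes_log2 = torch.full((2,), -2.)

target = consistency_target(model, latents, signal_levels, step_sizes_log2, None, None, None)
print(target.shape)  # torch.Size([2, 16, 8])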
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.20"
+version = "0.0.21"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }