allow reward tokens to be optionally attended to as state, DT (Decision Transformer) style; the multi-agent scenario is left as a todo
commit 2a902eaaf7 (parent d28251e9f9)
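
In effect, the per-timestep reward embedding, which previously was added directly onto the agent token, now becomes its own token that the transformer can optionally attend to as state, in the spirit of Decision Transformer's return conditioning. A minimal sketch of the difference, with illustrative shapes and names rather than the repo's actual API:

import torch

b, t, d = 2, 8, 16                       # batch, time, model dimension
agent_tokens  = torch.randn(b, t, d)     # one agent token per timestep
reward_embeds = torch.randn(b, t, d)     # embedded rewards (e.g. from a two-hot encoding)

# before this commit: reward information is folded into the agent token itself
agent_with_reward = agent_tokens + reward_embeds

# after (DT-esque): the reward gets its own slot in the attended sequence
tokens_per_step = torch.stack((reward_embeds, agent_tokens), dim = 2)   # (b, t, 2, d)
attended_sequence = tokens_per_step.reshape(b, t * 2, d)                # flattened for attention

The extra reward token is also reflected in the sequence-length bookkeeping further down (the new num_reward_tokens term in space_seq_len).
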
@@ -1564,6 +1564,8 @@ class DynamicsWorldModel(Module):
         self.agent_learned_embed = Parameter(randn(self.num_agents, dim) * 1e-2)
         self.action_learned_embed = Parameter(randn(self.num_agents, dim) * 1e-2)
 
+        self.reward_learned_embed = Parameter(randn(self.num_agents, dim) * 1e-2)
+
         self.num_tasks = num_tasks
         self.task_embed = nn.Embedding(num_tasks, dim)
 
@@ -1950,7 +1952,9 @@ class DynamicsWorldModel(Module):
 
         agent_tokens = repeat(agent_tokens, 'b ... d -> b t ... d', t = time)
 
-        # maybe add a reward embedding to agent tokens
+        # maybe reward tokens
 
+        reward_tokens = agent_tokens[:, :, 0:0]
+
         if exists(rewards):
             two_hot_encoding = self.reward_encoder(rewards)
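
agent_tokens[:, :, 0:0] is a zero-length slice, so reward_tokens defaults to an empty tensor and downstream token packing still works when no rewards are passed in. The reward_encoder(rewards) call produces a two-hot encoding of the scalar rewards; the sketch below shows one common way such an encoding works (fixed bins, linear interpolation between the two nearest bins). This is an assumption for illustration, not the repo's exact implementation, which may add a symlog transform or different bin spacing:

import torch

def two_hot(reward, bins):
    # spread each scalar across the two nearest bins, weighted by proximity
    reward = reward.clamp(bins[0].item(), bins[-1].item())
    idx = torch.searchsorted(bins, reward).clamp(1, len(bins) - 1)
    lo, hi = bins[idx - 1], bins[idx]
    weight_hi = (reward - lo) / (hi - lo)
    encoding = torch.zeros(*reward.shape, len(bins))
    encoding.scatter_(-1, (idx - 1).unsqueeze(-1), (1. - weight_hi).unsqueeze(-1))
    encoding.scatter_(-1, idx.unsqueeze(-1), weight_hi.unsqueeze(-1))
    return encoding

bins = torch.linspace(-1., 1., 9)        # 9 evenly spaced bins on [-1, 1]
rewards = torch.tensor([[0.3, -0.55]])   # (batch, time)
encoding = two_hot(rewards, bins)        # (batch, time, 9), each row sums to 1
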
@@ -1959,13 +1963,15 @@ class DynamicsWorldModel(Module):
                 self.add_reward_embed_to_agent_token and
                 (not self.training or not sample_prob(self.add_reward_embed_dropout)) # a bit of noise goes a long way
             ):
-                reward_embeds = self.reward_encoder.embed(two_hot_encoding)
+                assert self.num_agents == 1
 
-                pop_last_reward = int(reward_embeds.shape[1] == agent_tokens.shape[1]) # the last reward is popped off during training, during inference, it is not known yet, so need to handle this edge case
+                reward_tokens = self.reward_encoder.embed(two_hot_encoding)
 
-                reward_embeds = pad_at_dim(reward_embeds, (1, -pop_last_reward), dim = -2, value = 0.) # shift as each agent token predicts the next reward
+                pop_last_reward = int(reward_tokens.shape[1] == agent_tokens.shape[1]) # the last reward is popped off during training, during inference, it is not known yet, so need to handle this edge case
 
-                agent_tokens = add('b t ... d, b t d', agent_tokens, reward_embeds)
+                reward_tokens = pad_at_dim(reward_tokens, (1, -pop_last_reward), dim = -2, value = 0.) # shift as each agent token predicts the next reward
 
+                reward_tokens = add('1 d, b t d', self.reward_learned_embed, reward_tokens)
+
         # maybe create the action tokens
 
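
The pop_last_reward / pad_at_dim bookkeeping shifts the rewards one step to the right: each agent token should only see the reward from the previous step (the current step's reward is what it predicts, and at inference time it is not known yet, hence the edge case where nothing is popped). A toy version of that shift, using plain torch.nn.functional.pad instead of the repo's pad_at_dim helper:

import torch
import torch.nn.functional as F

b, t, d = 1, 4, 2
reward_tokens = torch.arange(1., t + 1.).view(1, t, 1).expand(b, t, d)   # stand-ins for rewards r1..r4
num_agent_steps = t                                                      # training: rewards cover every step

pop_last_reward = int(reward_tokens.shape[1] == num_agent_steps)         # 1 during training, 0 at inference
shifted = F.pad(reward_tokens, (0, 0, 1, -pop_last_reward), value = 0.)  # one zero step in front, maybe drop the last

# shifted[:, 0] is all zeros; shifted[:, i] holds r_i for i >= 1,
# so the agent token at step i attends to the reward that preceded it
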
@@ -1992,7 +1998,7 @@ class DynamicsWorldModel(Module):
 
         # main function, needs to be defined as such for shortcut training - additional calls for consistency loss
 
-        def get_prediction(noised_latents, signal_levels, step_sizes_log2, action_tokens, agent_tokens, return_agent_tokens = False):
+        def get_prediction(noised_latents, signal_levels, step_sizes_log2, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = False):
             # latents to spatial tokens
 
             space_tokens = self.latents_to_spatial_tokens(noised_latents)
@@ -2005,6 +2011,10 @@ class DynamicsWorldModel(Module):
 
             num_action_tokens = 1 if not is_empty(action_tokens) else 0
 
+            # reward tokens
+
+            num_reward_tokens = 1 if not is_empty(reward_tokens) else 0
+
             # pack to tokens
             # [signal + step size embed] [latent space tokens] [register] [actions / agent]
 
@@ -2035,6 +2045,7 @@ class DynamicsWorldModel(Module):
             space_seq_len = (
                 + 1 # signal + step
                 + num_action_tokens # past action tokens - todo: account for multi-agent
+                + num_reward_tokens # maybe allow world model being fine-tuned in phase 3 to see rewards as state
                 + self.num_agents # action / agent tokens
                 + self.num_register_tokens
                 + num_spatial_tokens
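
With the reward token included, the per-frame sequence that attention runs over is: the signal/step-size embed, the past action token(s), optionally one reward token, the agent token(s), the register tokens, and the spatial latent tokens. A quick sanity check of the count with assumed (not actual) config values:

# illustrative numbers only; the real counts come from the model config
num_action_tokens   = 1     # 1 if past actions are passed in, else 0
num_reward_tokens   = 1     # 1 if rewards are passed in and attended as state, else 0 (new in this commit)
num_agents          = 1
num_register_tokens = 4
num_spatial_tokens  = 64    # e.g. an 8 x 8 grid of latent space tokens

space_seq_len = (
    1                       # signal + step size embed
    + num_action_tokens
    + num_reward_tokens
    + num_agents
    + num_register_tokens
    + num_spatial_tokens
)

assert space_seq_len == 72
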
@@ -2091,7 +2102,7 @@ class DynamicsWorldModel(Module):
 
         # forward the network
 
-        pred, encoded_agent_tokens = get_prediction(noised_latents, signal_levels, step_sizes_log2, action_tokens, agent_tokens, return_agent_tokens = True)
+        pred, encoded_agent_tokens = get_prediction(noised_latents, signal_levels, step_sizes_log2, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = True)
 
         if return_pred_only:
             if not return_agent_tokens:
@@ -2128,7 +2139,7 @@ class DynamicsWorldModel(Module):
         step_sizes_log2_minus_one = step_sizes_log2 - 1 # which equals d / 2
         half_step_size = 2 ** step_sizes_log2_minus_one
 
-        first_step_pred = get_prediction_no_grad(noised_latents, signal_levels, step_sizes_log2_minus_one, action_tokens, agent_tokens)
+        first_step_pred = get_prediction_no_grad(noised_latents, signal_levels, step_sizes_log2_minus_one, action_tokens, reward_tokens, agent_tokens)
 
         # first derive b'
 
@@ -2147,7 +2158,7 @@ class DynamicsWorldModel(Module):
         # get second prediction for b''
 
         signal_levels_plus_half_step = signal_levels + half_step_size[:, None]
-        second_step_pred = get_prediction_no_grad(denoised_latent, signal_levels_plus_half_step, step_sizes_log2_minus_one, action_tokens, agent_tokens)
+        second_step_pred = get_prediction_no_grad(denoised_latent, signal_levels_plus_half_step, step_sizes_log2_minus_one, action_tokens, reward_tokens, agent_tokens)
 
         if is_v_space_pred:
             second_step_pred_flow = second_step_pred
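
These last hunks simply thread reward_tokens through the extra shortcut-model calls: the consistency target compares one full denoising step against two half-size steps, and both half-step predictions must see the same conditioning (actions, rewards, agent tokens). A sketch of the half-step arithmetic only, with made-up numbers and no model call:

import torch

signal_levels   = torch.tensor([0.25])           # current position along the noise-to-signal path
step_sizes_log2 = torch.tensor([-2.])            # full step d = 2 ** -2 = 0.25

step_sizes_log2_minus_one = step_sizes_log2 - 1  # log2 of d / 2
half_step_size = 2. ** step_sizes_log2_minus_one # 0.125

# the first half-step prediction is made at signal_levels,
# the second on the partially denoised latent at signal_levels + half_step_size
signal_levels_plus_half_step = signal_levels + half_step_size
assert torch.allclose(signal_levels_plus_half_step, torch.tensor([0.375]))
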
pyproject.toml:
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.20"
+version = "0.0.21"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }