last commit for the day - take care of the task embed
This commit is contained in:
parent
fe99efecba
commit
f507afa0d3
@ -982,6 +982,7 @@ class DynamicsModel(Module):
|
|||||||
max_steps = 64, # K_max in paper
|
max_steps = 64, # K_max in paper
|
||||||
num_spatial_tokens = 32, # latents were projected into spatial tokens, and presumably pooled back for the final prediction (or one special one does the x-prediction)
|
num_spatial_tokens = 32, # latents were projected into spatial tokens, and presumably pooled back for the final prediction (or one special one does the x-prediction)
|
||||||
num_register_tokens = 8, # they claim register tokens led to better temporal consistency
|
num_register_tokens = 8, # they claim register tokens led to better temporal consistency
|
||||||
|
num_tasks = 0,
|
||||||
depth = 4,
|
depth = 4,
|
||||||
pred_orig_latent = True, # directly predicting the original x0 data yields better results, rather than velocity (x-space vs v-space)
|
pred_orig_latent = True, # directly predicting the original x0 data yields better results, rather than velocity (x-space vs v-space)
|
||||||
time_block_every = 4, # every 4th block is time
|
time_block_every = 4, # every 4th block is time
|
||||||
@ -1029,10 +1030,15 @@ class DynamicsModel(Module):
|
|||||||
self.pred_orig_latent = pred_orig_latent # x-space or v-space
|
self.pred_orig_latent = pred_orig_latent # x-space or v-space
|
||||||
self.loss_weight_fn = loss_weight_fn
|
self.loss_weight_fn = loss_weight_fn
|
||||||
|
|
||||||
|
# reinforcement related
|
||||||
|
|
||||||
# they sum all the actions into a single token
|
# they sum all the actions into a single token
|
||||||
|
|
||||||
self.action_learned_embed = Parameter(torch.randn(dim) * 1e-2)
|
self.action_learned_embed = Parameter(torch.randn(dim) * 1e-2)
|
||||||
|
|
||||||
|
self.num_task = num_tasks
|
||||||
|
self.task_embed = nn.Embedding(num_tasks, dim)
|
||||||
|
|
||||||
# calculate "space" seq len
|
# calculate "space" seq len
|
||||||
|
|
||||||
self.space_seq_len = (
|
self.space_seq_len = (
|
||||||
@ -1096,6 +1102,7 @@ class DynamicsModel(Module):
|
|||||||
latents = None, # (b t d)
|
latents = None, # (b t d)
|
||||||
signal_levels = None, # (b t)
|
signal_levels = None, # (b t)
|
||||||
step_sizes_log2 = None, # (b)
|
step_sizes_log2 = None, # (b)
|
||||||
|
tasks = None, # (b)
|
||||||
return_pred_only = False
|
return_pred_only = False
|
||||||
):
|
):
|
||||||
# handle video or latents
|
# handle video or latents
|
||||||
@ -1145,7 +1152,19 @@ class DynamicsModel(Module):
|
|||||||
|
|
||||||
noised_latents = noise.lerp(latents, times)
|
noised_latents = noise.lerp(latents, times)
|
||||||
|
|
||||||
def get_prediction(noised_latents, signal_levels, step_sizes_log2):
|
# reinforcement learning related
|
||||||
|
|
||||||
|
agent_tokens = repeat(self.action_learned_embed, 'd -> b d', b = batch)
|
||||||
|
|
||||||
|
if exists(tasks):
|
||||||
|
assert self.num_tasks > 0
|
||||||
|
|
||||||
|
task_embeds = self.task_embed(tasks)
|
||||||
|
agent_tokens = agent_tokens + task_embeds
|
||||||
|
|
||||||
|
# main function, needs to be defined as such for shortcut training - additional calls for consistency loss
|
||||||
|
|
||||||
|
def get_prediction(noised_latents, signal_levels, step_sizes_log2, agent_tokens):
|
||||||
# latents to spatial tokens
|
# latents to spatial tokens
|
||||||
|
|
||||||
space_tokens = self.latents_to_spatial_tokens(noised_latents)
|
space_tokens = self.latents_to_spatial_tokens(noised_latents)
|
||||||
@ -1155,8 +1174,6 @@ class DynamicsModel(Module):
|
|||||||
|
|
||||||
registers = repeat(self.register_tokens, 's d -> b t s d', b = batch, t = time)
|
registers = repeat(self.register_tokens, 's d -> b t s d', b = batch, t = time)
|
||||||
|
|
||||||
agent_token = repeat(self.action_learned_embed, 'd -> b t d', b = batch, t = time)
|
|
||||||
|
|
||||||
# determine signal + step size embed for their diffusion forcing + shortcut
|
# determine signal + step size embed for their diffusion forcing + shortcut
|
||||||
|
|
||||||
signal_embed = self.signal_levels_embed(signal_levels)
|
signal_embed = self.signal_levels_embed(signal_levels)
|
||||||
@ -1167,9 +1184,13 @@ class DynamicsModel(Module):
|
|||||||
flow_token = cat((signal_embed, step_size_embed), dim = -1)
|
flow_token = cat((signal_embed, step_size_embed), dim = -1)
|
||||||
flow_token = rearrange(flow_token, 'b t d -> b t d')
|
flow_token = rearrange(flow_token, 'b t d -> b t d')
|
||||||
|
|
||||||
|
# handle agent tokens w/ actions and task embeds
|
||||||
|
|
||||||
|
agent_tokens = repeat(agent_tokens, 'b d -> b t d', t = time)
|
||||||
|
|
||||||
# pack to tokens for attending
|
# pack to tokens for attending
|
||||||
|
|
||||||
tokens, packed_tokens_shape = pack([flow_token, space_tokens, registers, agent_token], 'b t * d')
|
tokens, packed_tokens_shape = pack([flow_token, space_tokens, registers, agent_tokens], 'b t * d')
|
||||||
|
|
||||||
# attend functions for space and time
|
# attend functions for space and time
|
||||||
|
|
||||||
@ -1211,7 +1232,7 @@ class DynamicsModel(Module):
|
|||||||
|
|
||||||
# unpack
|
# unpack
|
||||||
|
|
||||||
flow_token, space_tokens, register_tokens, agent_token = unpack(tokens, packed_tokens_shape, 'b t * d')
|
flow_token, space_tokens, register_tokens, agent_tokens = unpack(tokens, packed_tokens_shape, 'b t * d')
|
||||||
|
|
||||||
# pooling
|
# pooling
|
||||||
|
|
||||||
@ -1223,7 +1244,7 @@ class DynamicsModel(Module):
|
|||||||
|
|
||||||
# forward the network
|
# forward the network
|
||||||
|
|
||||||
pred = get_prediction(noised_latents, signal_levels, step_sizes_log2)
|
pred = get_prediction(noised_latents, signal_levels, step_sizes_log2, agent_tokens)
|
||||||
|
|
||||||
if return_pred_only:
|
if return_pred_only:
|
||||||
return pred
|
return pred
|
||||||
@ -1257,7 +1278,7 @@ class DynamicsModel(Module):
|
|||||||
step_sizes_log2_minus_one = step_sizes_log2 - 1 # which equals d / 2
|
step_sizes_log2_minus_one = step_sizes_log2 - 1 # which equals d / 2
|
||||||
half_step_size = 2 ** step_sizes_log2_minus_one
|
half_step_size = 2 ** step_sizes_log2_minus_one
|
||||||
|
|
||||||
first_step_pred = get_prediction_no_grad(noised_latents, signal_levels, step_sizes_log2_minus_one)
|
first_step_pred = get_prediction_no_grad(noised_latents, signal_levels, step_sizes_log2_minus_one, agent_tokens)
|
||||||
|
|
||||||
# first derive b'
|
# first derive b'
|
||||||
|
|
||||||
@ -1273,7 +1294,7 @@ class DynamicsModel(Module):
|
|||||||
|
|
||||||
# get second prediction for b''
|
# get second prediction for b''
|
||||||
|
|
||||||
second_step_pred = get_prediction_no_grad(denoised_latent, signal_levels + half_step_size[:, None], step_sizes_log2_minus_one)
|
second_step_pred = get_prediction_no_grad(denoised_latent, signal_levels + half_step_size[:, None], step_sizes_log2_minus_one, agent_tokens)
|
||||||
|
|
||||||
if is_v_space_pred:
|
if is_v_space_pred:
|
||||||
second_step_pred_flow = second_step_pred
|
second_step_pred_flow = second_step_pred
|
||||||
|
|||||||
@ -6,11 +6,13 @@ import torch
|
|||||||
@param('grouped_query_attn', (False, True))
|
@param('grouped_query_attn', (False, True))
|
||||||
@param('dynamics_with_video_input', (False, True))
|
@param('dynamics_with_video_input', (False, True))
|
||||||
@param('prob_no_shortcut_train', (None, 0., 1.))
|
@param('prob_no_shortcut_train', (None, 0., 1.))
|
||||||
|
@param('add_task_embeds', (False, True))
|
||||||
def test_e2e(
|
def test_e2e(
|
||||||
pred_orig_latent,
|
pred_orig_latent,
|
||||||
grouped_query_attn,
|
grouped_query_attn,
|
||||||
dynamics_with_video_input,
|
dynamics_with_video_input,
|
||||||
prob_no_shortcut_train
|
prob_no_shortcut_train,
|
||||||
|
add_task_embeds
|
||||||
):
|
):
|
||||||
from dreamer4.dreamer4 import VideoTokenizer, DynamicsModel
|
from dreamer4.dreamer4 import VideoTokenizer, DynamicsModel
|
||||||
|
|
||||||
@ -30,6 +32,7 @@ def test_e2e(
|
|||||||
video_tokenizer = tokenizer,
|
video_tokenizer = tokenizer,
|
||||||
dim_latent = 32,
|
dim_latent = 32,
|
||||||
max_steps = 64,
|
max_steps = 64,
|
||||||
|
num_tasks = 4,
|
||||||
pred_orig_latent = pred_orig_latent,
|
pred_orig_latent = pred_orig_latent,
|
||||||
attn_kwargs = dict(
|
attn_kwargs = dict(
|
||||||
heads = heads,
|
heads = heads,
|
||||||
@ -46,7 +49,17 @@ def test_e2e(
|
|||||||
else:
|
else:
|
||||||
dynamics_input = dict(latents = latents)
|
dynamics_input = dict(latents = latents)
|
||||||
|
|
||||||
flow_loss = dynamics(**dynamics_input, signal_levels = signal_levels, step_sizes_log2 = step_sizes_log2)
|
tasks = None
|
||||||
|
if add_task_embeds:
|
||||||
|
tasks = torch.randint(0, 4, (2,))
|
||||||
|
|
||||||
|
flow_loss = dynamics(
|
||||||
|
**dynamics_input,
|
||||||
|
tasks = tasks,
|
||||||
|
signal_levels = signal_levels,
|
||||||
|
step_sizes_log2 = step_sizes_log2
|
||||||
|
)
|
||||||
|
|
||||||
assert flow_loss.numel() == 1
|
assert flow_loss.numel() == 1
|
||||||
|
|
||||||
def test_symexp_two_hot():
|
def test_symexp_two_hot():
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user