diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index 655decf..659577c 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -982,6 +982,7 @@ class DynamicsModel(Module):
         max_steps = 64, # K_max in paper
         num_spatial_tokens = 32, # latents were projected into spatial tokens, and presumably pooled back for the final prediction (or one special one does the x-prediction)
         num_register_tokens = 8, # they claim register tokens led to better temporal consistency
+        num_tasks = 0,
         depth = 4,
         pred_orig_latent = True, # directly predicting the original x0 data yield better results, rather than velocity (x-space vs v-space)
         time_block_every = 4, # every 4th block is time
@@ -1029,10 +1030,15 @@ class DynamicsModel(Module):
         self.pred_orig_latent = pred_orig_latent # x-space or v-space
         self.loss_weight_fn = loss_weight_fn
 
+        # reinforcement learning related
+        # they sum all the actions into a single token
+
         self.action_learned_embed = Parameter(torch.randn(dim) * 1e-2)
 
+        self.num_tasks = num_tasks
+        self.task_embed = nn.Embedding(num_tasks, dim)
+
         # calculate "space" seq len
 
         self.space_seq_len = (
@@ -1096,6 +1102,7 @@ class DynamicsModel(Module):
         latents = None,         # (b t d)
         signal_levels = None,   # (b t)
         step_sizes_log2 = None, # (b)
+        tasks = None,           # (b)
         return_pred_only = False
     ):
         # handle video or latents
@@ -1145,7 +1152,19 @@ class DynamicsModel(Module):
 
         noised_latents = noise.lerp(latents, times)
 
-        def get_prediction(noised_latents, signal_levels, step_sizes_log2):
+        # reinforcement learning related
+
+        agent_tokens = repeat(self.action_learned_embed, 'd -> b d', b = batch)
+
+        if exists(tasks):
+            assert self.num_tasks > 0
+
+            task_embeds = self.task_embed(tasks)
+            agent_tokens = agent_tokens + task_embeds
+
+        # main function, needs to be defined as such for shortcut training - additional calls for consistency loss
+
+        def get_prediction(noised_latents, signal_levels, step_sizes_log2, agent_tokens):
 
             # latents to spatial tokens
 
             space_tokens = self.latents_to_spatial_tokens(noised_latents)
@@ -1155,8 +1174,6 @@ class DynamicsModel(Module):
 
             registers = repeat(self.register_tokens, 's d -> b t s d', b = batch, t = time)
 
-            agent_token = repeat(self.action_learned_embed, 'd -> b t d', b = batch, t = time)
-
             # determine signal + step size embed for their diffusion forcing + shortcut
 
             signal_embed = self.signal_levels_embed(signal_levels)
@@ -1167,9 +1184,13 @@ class DynamicsModel(Module):
 
             flow_token = cat((signal_embed, step_size_embed), dim = -1)
             flow_token = rearrange(flow_token, 'b t d -> b t d')
 
+            # handle agent tokens w/ actions and task embeds
+
+            agent_tokens = repeat(agent_tokens, 'b d -> b t d', t = time)
+
             # pack to tokens for attending
 
-            tokens, packed_tokens_shape = pack([flow_token, space_tokens, registers, agent_token], 'b t * d')
+            tokens, packed_tokens_shape = pack([flow_token, space_tokens, registers, agent_tokens], 'b t * d')
 
             # attend functions for space and time
@@ -1211,7 +1232,7 @@ class DynamicsModel(Module):
 
             # unpack
 
-            flow_token, space_tokens, register_tokens, agent_token = unpack(tokens, packed_tokens_shape, 'b t * d')
+            flow_token, space_tokens, register_tokens, agent_tokens = unpack(tokens, packed_tokens_shape, 'b t * d')
 
             # pooling
 
@@ -1223,7 +1244,7 @@ class DynamicsModel(Module):
 
         # forward the network
 
-        pred = get_prediction(noised_latents, signal_levels, step_sizes_log2)
+        pred = get_prediction(noised_latents, signal_levels, step_sizes_log2, agent_tokens)
 
         if return_pred_only:
             return pred
@@ -1257,7 +1278,7 @@ class DynamicsModel(Module):
             step_sizes_log2_minus_one = step_sizes_log2 - 1 # which equals d / 2
             half_step_size = 2 ** step_sizes_log2_minus_one
 
-            first_step_pred = get_prediction_no_grad(noised_latents, signal_levels, step_sizes_log2_minus_one)
+            first_step_pred = get_prediction_no_grad(noised_latents, signal_levels, step_sizes_log2_minus_one, agent_tokens)
 
             # first derive b'
 
@@ -1273,7 +1294,7 @@ class DynamicsModel(Module):
 
             # get second prediction for b''
 
-            second_step_pred = get_prediction_no_grad(denoised_latent, signal_levels + half_step_size[:, None], step_sizes_log2_minus_one)
+            second_step_pred = get_prediction_no_grad(denoised_latent, signal_levels + half_step_size[:, None], step_sizes_log2_minus_one, agent_tokens)
 
             if is_v_space_pred:
                 second_step_pred_flow = second_step_pred
diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py
index 275a10f..bb0bb1b 100644
--- a/tests/test_dreamer.py
+++ b/tests/test_dreamer.py
@@ -6,11 +6,13 @@ import torch
 @param('grouped_query_attn', (False, True))
 @param('dynamics_with_video_input', (False, True))
 @param('prob_no_shortcut_train', (None, 0., 1.))
+@param('add_task_embeds', (False, True))
 def test_e2e(
     pred_orig_latent,
     grouped_query_attn,
     dynamics_with_video_input,
-    prob_no_shortcut_train
+    prob_no_shortcut_train,
+    add_task_embeds
 ):
     from dreamer4.dreamer4 import VideoTokenizer, DynamicsModel
@@ -30,6 +32,7 @@ def test_e2e(
         video_tokenizer = tokenizer,
         dim_latent = 32,
         max_steps = 64,
+        num_tasks = 4,
         pred_orig_latent = pred_orig_latent,
         attn_kwargs = dict(
             heads = heads,
@@ -46,7 +49,17 @@ def test_e2e(
     else:
         dynamics_input = dict(latents = latents)
 
-    flow_loss = dynamics(**dynamics_input, signal_levels = signal_levels, step_sizes_log2 = step_sizes_log2)
+    tasks = None
+    if add_task_embeds:
+        tasks = torch.randint(0, 4, (2,))
+
+    flow_loss = dynamics(
+        **dynamics_input,
+        tasks = tasks,
+        signal_levels = signal_levels,
+        step_sizes_log2 = step_sizes_log2
+    )
+
     assert flow_loss.numel() == 1
 
 def test_symexp_two_hot():
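
Below is a minimal usage sketch of the task conditioning added in this diff. Constructor arguments and tensor shapes follow the test above (`latents` is `(b t d)`, `signal_levels` is `(b t)`, `step_sizes_log2` is `(b)`, `tasks` is `(b)`); the concrete integer ranges chosen for `signal_levels` and `step_sizes_log2`, and omitting the `video_tokenizer` when passing `latents` directly, are assumptions rather than confirmed API.

    # hypothetical sketch, not repo code - mirrors the parameters exercised in tests/test_dreamer.py
    import torch
    from dreamer4.dreamer4 import DynamicsModel

    # assumption: video_tokenizer may be omitted when latents are supplied directly
    dynamics = DynamicsModel(
        dim_latent = 32,
        max_steps = 64, # K_max
        num_tasks = 4   # new argument - allocates nn.Embedding(4, dim) for task conditioning
    )

    batch, time, dim_latent = 2, 4, 32

    latents = torch.randn(batch, time, dim_latent)      # (b t d)
    signal_levels = torch.randint(0, 64, (batch, time)) # (b t) - assumed integer noise levels, embedded internally
    step_sizes_log2 = torch.randint(0, 6, (batch,))     # (b)   - assumed integer log2 shortcut step sizes
    tasks = torch.randint(0, 4, (batch,))               # (b)   - task ids, added onto the learned agent token

    flow_loss = dynamics(
        latents = latents,
        tasks = tasks,
        signal_levels = signal_levels,
        step_sizes_log2 = step_sizes_log2
    )

    flow_loss.backward()

Note the design choice: the task embedding is summed into the single learned agent token (`agent_tokens = agent_tokens + task_embeds`) rather than appended as an extra token, so the packed sequence length and attention cost are unchanged whether or not `tasks` is passed.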