From fe99efecbaa782a05d7d391b24bf49544922f0a4 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Sun, 5 Oct 2025 11:17:36 -0700
Subject: [PATCH] make a first pass through the shortcut training logic (Frans
 et al from Berkeley) maintaining both v-space and x-space

---
 dreamer4/dreamer4.py  | 227 ++++++++++++++++++++++++++++++------------
 tests/test_dreamer.py |  14 +--
 2 files changed, 170 insertions(+), 71 deletions(-)

diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index 68a5a08..655decf 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -1,7 +1,8 @@
 from __future__ import annotations

 import math
-from math import ceil
+from math import ceil, log2
+from random import random

 from collections import namedtuple
 from functools import partial
@@ -70,6 +71,9 @@ def first(arr):
 def divisible_by(num, den):
     return (num % den) == 0

+def is_power_two(num):
+    return log2(num).is_integer()
+
 # tensor helpers

 def pack_one(t, pattern):
@@ -345,7 +349,9 @@ def apply_rotations(

     # handle gqa for rotary

-    if rotations.ndim > 2 and heads < rotations.shape[0]:
+    if rotations.ndim == 3 and rotations.shape[0] < heads:
+        rotary_heads = rotations.shape[0]
+
         assert divisible_by(heads, rotary_heads)
         groups = heads // rotary_heads
         rotations = repeat(rotations, 'h ... -> (h g) ...', g = groups)
@@ -353,6 +359,8 @@ def apply_rotations(
     x1, x2 = t.chunk(2, dim = -1)
     rotated_half_t = cat((-x2, x1), dim = -1)

+    # rotate in the positions
+
     rotated = t * rotations.cos() + rotated_half_t * rotations.sin()

     return rotated.type(dtype)
@@ -971,8 +979,7 @@ class DynamicsModel(Module):
         dim,
         dim_latent,
         video_tokenizer: VideoTokenizer | None = None,
-        num_signal_levels = 500,
-        num_step_sizes = 32,
+        max_steps = 64, # K_max in paper
         num_spatial_tokens = 32, # latents were projected into spatial tokens, and presumably pooled back for the final prediction (or one special one does the x-prediction)
         num_register_tokens = 8, # they claim register tokens led to better temporal consistency
         depth = 4,
@@ -985,7 +992,8 @@ class DynamicsModel(Module):
         attn_softclamp_value = 50.,
         ff_kwargs: dict = dict(),
         loss_weight_fn: Callable = ramp_weight,
-        num_future_predictions = 8 # they do multi-token prediction of 8 steps forward
+        num_future_predictions = 8, # they do multi-token prediction of 8 steps forward
+        prob_no_shortcut_train = None # probability of no shortcut training, defaults to 1 / num_step_sizes_log2
     ):
         super().__init__()
@@ -1007,14 +1015,18 @@ class DynamicsModel(Module):
         assert divisible_by(dim, 2)
         dim_half = dim // 2

-        self.num_signal_levels = num_signal_levels
-        self.num_step_sizes = num_step_sizes
+        assert is_power_two(max_steps), '`max_steps` must be a power of 2'
+        self.max_steps = max_steps
+        self.num_step_sizes_log2 = int(log2(max_steps))

-        self.signal_levels_embed = nn.Embedding(num_signal_levels, dim_half)
-        self.step_sizes_embed = nn.Embedding(num_step_sizes, dim_half)
+        self.signal_levels_embed = nn.Embedding(max_steps, dim_half)
+        self.step_size_embed = nn.Embedding(self.num_step_sizes_log2, dim_half) # power of 2, so 1/1, 1/2, 1/4, 1/8 ... 1/Kmax

-        self.pred_orig_latent = pred_orig_latent
+        self.prob_no_shortcut_train = default(prob_no_shortcut_train, self.num_step_sizes_log2 ** -1.)
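+
+        # e.g. the default max_steps = 64 gives num_step_sizes_log2 = 6, so plain flow
+        # matching (no shortcut bootstrapping) is trained roughly 1/6 of the time by default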
+
+        # loss related
+
+        self.pred_orig_latent = pred_orig_latent # x-space or v-space
         self.loss_weight_fn = loss_weight_fn

         # they sum all the actions into a single token
@@ -1083,7 +1095,8 @@ class DynamicsModel(Module):
         video = None,
         latents = None, # (b t d)
         signal_levels = None, # (b t)
-        step_sizes = None # (b t)
+        step_sizes_log2 = None, # (b)
+        return_pred_only = False
     ):

         # handle video or latents

@@ -1094,114 +1107,198 @@ class DynamicsModel(Module):
             latents = self.video_tokenizer.tokenize(video)

-        time, device = latents.shape[1], latents.device
+        batch, time, device = *latents.shape[:2], latents.device

         # flow related

-        assert not (exists(signal_levels) ^ exists(step_sizes))
+        assert not (exists(signal_levels) ^ exists(step_sizes_log2))

-        flow_matching = exists(signal_levels)
+        # if neither signal levels nor step sizes are passed in,
+        # generate them randomly for training

-        # flow matching if `signal_levels` passed in
+        no_shortcut_train = random() < self.prob_no_shortcut_train

-        if flow_matching:
+        if no_shortcut_train:
+            # if no shortcut training, step sizes are just 1 and noising is all steps, where each step is 1 / d_min
+            # in original shortcut paper, they actually set d = 0 for some reason, look into that later, as there is no mention in the dreamer paper of doing this

-            noise = torch.randn_like(latents)
+            step_sizes_log2 = torch.zeros((batch,), device = device).long() # zero because zero is equivalent to step size of 1
+            signal_levels = torch.randint(0, self.max_steps, (batch, time), device = device)
+        else:
-            times = rearrange(signal_levels.float() / self.num_signal_levels, 'b t -> b t 1')
+            # now we follow eq (4)

-            orig_latents = latents
+            step_sizes_log2 = torch.randint(1, self.num_step_sizes_log2, (batch,), device = device)
+            num_step_sizes = 2 ** step_sizes_log2

-            latents = noise.lerp(latents, times)
+            signal_levels = torch.randint(0, self.max_steps, (batch, time), device = device) // num_step_sizes[:, None] * num_step_sizes[:, None] # times are discretized to step sizes

-        # allow for original velocity pred
-        # x-space as in paper is in else clause
+        # get the noise

-        if not self.pred_orig_latent:
-            pred_target = flow = latents - noise
-        else:
-            pred_target = latents
+        noise = torch.randn_like(latents)

-        # latents to spatial tokens
+        # times are from 0 to 1

-        space_tokens = self.latents_to_spatial_tokens(latents)
+        times = rearrange(signal_levels.float() / self.max_steps, 'b t -> b t 1')

-        # pack to tokens
-        # [signal + step size embed] [latent space tokens] [register] [actions / agent]
+        # noise from 0 as noise to 1 as data

-        registers = repeat(self.register_tokens, 's d -> b t s d', b = latents.shape[0], t = latents.shape[1])
+        noised_latents = noise.lerp(latents, times)

-        agent_token = repeat(self.action_learned_embed, 'd -> b t d', b = latents.shape[0], t = latents.shape[1])
-
-        # determine signal + step size embed for their diffusion forcing + shortcut
+        def get_prediction(noised_latents, signal_levels, step_sizes_log2):
+            # latents to spatial tokens
+
+            space_tokens = self.latents_to_spatial_tokens(noised_latents)
+
+            # pack to tokens
+            # [signal + step size embed] [latent space tokens] [register] [actions / agent]
+
+            registers = repeat(self.register_tokens, 's d -> b t s d', b = batch, t = time)
+
+            agent_token = repeat(self.action_learned_embed, 'd -> b t d', b = batch, t = time)
+
+            # determine signal + step size embed for their diffusion forcing + shortcut

-        if exists(signal_levels):
             signal_embed = self.signal_levels_embed(signal_levels)
-            step_size_embed = self.step_sizes_embed(step_sizes)
+
+            step_size_embed = self.step_size_embed(step_sizes_log2)
+            step_size_embed = repeat(step_size_embed, 'b ... -> b t ...', t = time)
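+
+            # step_sizes_log2 holds k for each batch sample, indexing dyadic step sizes
+            # d = 2 ** -k, e.g. k = 0 -> full step, k = 3 -> 1/8 step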

             flow_token = cat((signal_embed, step_size_embed), dim = -1)
             flow_token = rearrange(flow_token, 'b t d -> b t d')
-        else:
-            flow_token = registers[..., 0:0, :]
+            # pack to tokens for attending

-        # pack to tokens for attending
+            tokens, packed_tokens_shape = pack([flow_token, space_tokens, registers, agent_token], 'b t * d')

-        tokens, packed_tokens_shape = pack([flow_token, space_tokens, registers, agent_token], 'b t * d')
+            # attend functions for space and time

-        # attend functions for space and time
+            seq_len = tokens.shape[1]

-        seq_len = tokens.shape[1]
+            use_flex = exists(flex_attention) and tokens.is_cuda

-        use_flex = exists(flex_attention) and tokens.is_cuda
+            attend_kwargs = dict(use_flex = use_flex, softclamp_value = self.attn_softclamp_value, device = device)

-        attend_kwargs = dict(use_flex = use_flex, softclamp_value = self.attn_softclamp_value, device = device)
+            space_attend = get_attend_fn(causal = False, seq_len = self.space_seq_len, k_seq_len = self.space_seq_len, num_special_tokens = 1, **attend_kwargs) # space has an agent token on the right-hand side for reinforcement learning - cannot be attended to by modality

-        space_attend = get_attend_fn(causal = False, seq_len = self.space_seq_len, k_seq_len = self.space_seq_len, num_special_tokens = 1, **attend_kwargs) # space has an agent token on the right-hand side for reinforcement learning - cannot be attended to by modality
+            time_attend = get_attend_fn(causal = True, seq_len = time, k_seq_len = time, **attend_kwargs)

-        time_attend = get_attend_fn(causal = True, seq_len = time, k_seq_len = time, **attend_kwargs)
+            # rotary

-        # rotary
+            rotary_pos_emb = self.time_rotary(time)

-        rotary_pos_emb = self.time_rotary(time)
+            # attention

-        # attention
+            for (pre_attn_rearrange, post_attn_rearrange, attn, ff), layer_is_time in zip(self.layers, self.is_time):

-        for (pre_attn_rearrange, post_attn_rearrange, attn, ff), layer_is_time in zip(self.layers, self.is_time):
+                tokens = pre_attn_rearrange(tokens)

-            tokens = pre_attn_rearrange(tokens)
+                # when it is an axial time attention block, it should be causal

-            # when is a axial time attention block, should be causal
+                attend_fn = time_attend if layer_is_time else space_attend

-            attend_fn = time_attend if layer_is_time else space_attend
+                layer_rotary_pos_emb = rotary_pos_emb if layer_is_time else None

-            layer_rotary_pos_emb = rotary_pos_emb if layer_is_time else None
+                # attention layer

-            # attention layer
+                tokens = attn(tokens, rotary_pos_emb = layer_rotary_pos_emb, attend_fn = attend_fn) + tokens

-            tokens = attn(tokens, rotary_pos_emb = layer_rotary_pos_emb, attend_fn = attend_fn) + tokens
+                tokens = post_attn_rearrange(tokens)

-            tokens = post_attn_rearrange(tokens)
+                # feedforward layer

-            # feedforward layer
+                tokens = ff(tokens) + tokens

-            tokens = ff(tokens) + tokens
+            # unpack

-        # unpack
+            flow_token, space_tokens, register_tokens, agent_token = unpack(tokens, packed_tokens_shape, 'b t * d')

-        flow_token, space_tokens, register_tokens, agent_token = unpack(tokens, packed_tokens_shape, 'b t * d')
+            # pooling

-        # pooling
+            pooled = reduce(space_tokens, 'b t s d -> b t d', 'mean')

-        pooled = reduce(space_tokens, 'b t s d -> b t d', 'mean')
+            pred = self.to_pred(pooled)

-        pred = self.to_pred(pooled)
-
-        if not flow_matching:
-            return pred
+            return pred
+
+        # forward the network
+
+        pred = get_prediction(noised_latents, signal_levels, step_sizes_log2)
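+
+        # pred is (b t d) like the latents - signal levels vary per (b t) position
+        # while the step size is shared across all timesteps of a batch sample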
+
+        if return_pred_only:
+            return pred
+
+        # determine the target for the loss
+
+        pred_target = None
+
+        is_x_space = self.pred_orig_latent
+        is_v_space_pred = not self.pred_orig_latent
+
+        maybe_shortcut_loss_weight = 1.
+
+        if no_shortcut_train:
+
+            # allow for original velocity pred
+            # x-space as in paper is in else clause
+
+            if is_v_space_pred:
+                pred_target = flow = latents - noise
+            else:
+                pred_target = latents
+        else:
+            # shortcut training - Frans et al. https://arxiv.org/abs/2410.12557
+
+            # basically a consistency loss where you ensure two half steps equal one full step
+            # dreamer then makes it work for x-space with some math
+
+            get_prediction_no_grad = torch.no_grad()(get_prediction)
+
+            step_sizes_log2_minus_one = step_sizes_log2 - 1 # which equals d / 2
+            half_step_size = 2 ** step_sizes_log2_minus_one
+
+            first_step_pred = get_prediction_no_grad(noised_latents, signal_levels, step_sizes_log2_minus_one)
+
+            # first derive b'
+
+            if is_v_space_pred:
+                first_step_pred_flow = first_step_pred
+            else:
+                first_times = signal_levels[..., None].float() / self.max_steps
+                first_step_pred_flow = (first_step_pred - noised_latents) / (1. - first_times)
+
+            # take a half step
+
+            denoised_latent = noised_latents + first_step_pred_flow * (half_step_size[:, None, None] / self.max_steps)
+
+            # get second prediction for b'' - the second half step sits at signal level t + d / 2
+
+            second_signal_levels = signal_levels + half_step_size[:, None]
+
+            second_step_pred = get_prediction_no_grad(denoised_latent, second_signal_levels, step_sizes_log2_minus_one)
+
+            if is_v_space_pred:
+                second_step_pred_flow = second_step_pred
+            else:
+                second_times = second_signal_levels[..., None].float() / self.max_steps
+                second_step_pred_flow = (second_step_pred - denoised_latent) / (1. - second_times)
+
+            # pred target is sg(b' + b'') / 2
+
+            pred_target = (first_step_pred_flow + second_step_pred_flow).detach() / 2
+
+            # need to convert x-space to v-space
+
+            if is_x_space:
+                pred = (pred - noised_latents) / (1. - first_times)
+                maybe_shortcut_loss_weight = (1. - first_times) ** 2
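+
+                # x and v relate by x = x_t + (1 - t) * v, so the division above scales
+                # the error by 1 / (1 - t)^2 - the weight undoes it, per eq (7)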
+
+        # mse loss
+
+        losses = F.mse_loss(pred, pred_target, reduction = 'none')
+        losses = losses * maybe_shortcut_loss_weight # handle the (1-t)^2 in eq(7)
+
+        # loss weighting with their ramp function
+
         if exists(self.loss_weight_fn):
             loss_weight = self.loss_weight_fn(times)
             losses = losses * loss_weight
diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py
index 6b22786..275a10f 100644
--- a/tests/test_dreamer.py
+++ b/tests/test_dreamer.py
@@ -5,10 +5,12 @@ import torch
 @param('pred_orig_latent', (False, True))
 @param('grouped_query_attn', (False, True))
 @param('dynamics_with_video_input', (False, True))
+@param('prob_no_shortcut_train', (None, 0., 1.))
 def test_e2e(
     pred_orig_latent,
     grouped_query_attn,
-    dynamics_with_video_input
+    dynamics_with_video_input,
+    prob_no_shortcut_train
 ):
     from dreamer4.dreamer4 import VideoTokenizer, DynamicsModel

@@ -27,24 +29,24 @@ def test_e2e(
         512,
         video_tokenizer = tokenizer,
         dim_latent = 32,
-        num_signal_levels = 500,
-        num_step_sizes = 32,
+        max_steps = 64,
         pred_orig_latent = pred_orig_latent,
         attn_kwargs = dict(
             heads = heads,
             query_heads = query_heads
-        )
+        ),
+        prob_no_shortcut_train = prob_no_shortcut_train
     )

     signal_levels = torch.randint(0, 500, (2, 4))
-    step_sizes = torch.randint(0, 32, (2, 4))
+    step_sizes_log2 = torch.randint(1, 6, (2,))

     if dynamics_with_video_input:
         dynamics_input = dict(video = video)
     else:
         dynamics_input = dict(latents = latents)

-    flow_loss = dynamics(**dynamics_input, signal_levels = signal_levels, step_sizes = step_sizes)
+    flow_loss = dynamics(**dynamics_input, signal_levels = signal_levels, step_sizes_log2 = step_sizes_log2)

     assert flow_loss.numel() == 1

 def test_symexp_two_hot():