make a first pass through the shortcut training logic (Frans et al. from Berkeley), maintaining both v-space and x-space

lucidrains 2025-10-05 11:17:36 -07:00
parent 971637673b
commit fe99efecba
2 changed files with 170 additions and 71 deletions
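
For context, the shortcut objective (Frans et al., https://arxiv.org/abs/2410.12557) trains a single network conditioned on a step size d so that one step of size d matches the composition of two half steps. A minimal v-space sketch of the bootstrap target, assuming a hypothetical model(z, t, d) that returns a velocity:

    import torch

    def shortcut_target(model, z_t, t, d):
        # one step of size d should equal two consecutive steps of size d / 2
        with torch.no_grad():
            b1 = model(z_t, t, d / 2)            # b'  : velocity at (z_t, t)
            z_mid = z_t + b1 * (d / 2)           # take the first half step
            b2 = model(z_mid, t + d / 2, d / 2)  # b'' : velocity at the midpoint
        return (b1 + b2) / 2                     # target is sg((b' + b'') / 2)

The diff below implements this with discrete signal levels and power-of-two step sizes, and additionally supports targets in x-space.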

View File

@@ -1,7 +1,8 @@
 from __future__ import annotations

 import math
-from math import ceil
+from math import ceil, log2
+from random import random

 from collections import namedtuple
 from functools import partial
@@ -70,6 +71,9 @@ def first(arr):
 def divisible_by(num, den):
     return (num % den) == 0

+def is_power_two(num):
+    return log2(num).is_integer()
+
 # tensor helpers

 def pack_one(t, pattern):
@@ -345,7 +349,9 @@ def apply_rotations(
     # handle gqa for rotary

-    if rotations.ndim > 2 and heads < rotations.shape[0]:
+    if rotations.ndim == 3 and rotations.shape[0] < heads:
+        rotary_heads = rotations.shape[0]
+
         assert divisible_by(heads, rotary_heads)
         groups = heads // rotary_heads
         rotations = repeat(rotations, 'h ... -> (h g) ...', g = groups)
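
To illustrate the grouped-query handling above, a small self-contained sketch (the shapes are assumptions for illustration):

    import torch
    from einops import repeat

    heads, rotary_heads = 8, 2
    rotations = torch.randn(rotary_heads, 16, 32)  # (h, seq, dim_head)

    groups = heads // rotary_heads                 # 4 query heads share each rotary head
    rotations = repeat(rotations, 'h ... -> (h g) ...', g = groups)

    assert rotations.shape[0] == heads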
@@ -353,6 +359,8 @@
     x1, x2 = t.chunk(2, dim = -1)
     rotated_half_t = cat((-x2, x1), dim = -1)

+    # rotate in the positions
+
     rotated = t * rotations.cos() + rotated_half_t * rotations.sin()

     return rotated.type(dtype)
@@ -971,8 +979,7 @@ class DynamicsModel(Module):
         dim,
         dim_latent,
         video_tokenizer: VideoTokenizer | None = None,
-        num_signal_levels = 500,
-        num_step_sizes = 32,
+        max_steps = 64, # K_max in paper
         num_spatial_tokens = 32, # latents were projected into spatial tokens, and presumably pooled back for the final prediction (or one special one does the x-prediction)
         num_register_tokens = 8, # they claim register tokens led to better temporal consistency
         depth = 4,
@@ -985,7 +992,8 @@
         attn_softclamp_value = 50.,
         ff_kwargs: dict = dict(),
         loss_weight_fn: Callable = ramp_weight,
-        num_future_predictions = 8 # they do multi-token prediction of 8 steps forward
+        num_future_predictions = 8, # they do multi-token prediction of 8 steps forward
+        prob_no_shortcut_train = None # probability of no shortcut training, defaults to 1 / log2(max_steps)
     ):
         super().__init__()
@@ -1007,14 +1015,18 @@
         assert divisible_by(dim, 2)
         dim_half = dim // 2

-        self.num_signal_levels = num_signal_levels
-        self.num_step_sizes = num_step_sizes
+        assert is_power_two(max_steps), '`max_steps` must be a power of 2'
+        self.max_steps = max_steps
+        self.num_step_sizes_log2 = int(log2(max_steps))

-        self.signal_levels_embed = nn.Embedding(num_signal_levels, dim_half)
-        self.step_sizes_embed = nn.Embedding(num_step_sizes, dim_half)
+        self.signal_levels_embed = nn.Embedding(max_steps, dim_half)
+        self.step_size_embed = nn.Embedding(self.num_step_sizes_log2, dim_half) # power of 2, so 1/1, 1/2, 1/4, 1/8 ... 1/Kmax

-        self.pred_orig_latent = pred_orig_latent
+        self.prob_no_shortcut_train = default(prob_no_shortcut_train, self.num_step_sizes_log2 ** -1.)

+        # loss related

+        self.pred_orig_latent = pred_orig_latent # x-space or v-space
         self.loss_weight_fn = loss_weight_fn

         # they sum all the actions into a single token
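
As a concrete reading of the new embeddings (values inferred from the code above, for illustration): with max_steps = 64 there are log2(64) = 6 step size entries, where index k corresponds to a stride of 2^k of the 64 signal levels, i.e. a fraction 2^k / 64 of the denoising trajectory per step:

    # k          : 0     1     2     3     4     5
    # step size d: 1     2     4     8     16    32
    # fraction   : 1/64  1/32  1/16  1/8   1/4   1/2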
@@ -1083,7 +1095,8 @@
         video = None,
         latents = None, # (b t d)
         signal_levels = None, # (b t)
-        step_sizes = None # (b t)
+        step_sizes_log2 = None, # (b)
+        return_pred_only = False
     ):
         # handle video or latents
@@ -1094,57 +1107,66 @@
             latents = self.video_tokenizer.tokenize(video)

-        time, device = latents.shape[1], latents.device
+        batch, time, device = *latents.shape[:2], latents.device

         # flow related

-        assert not (exists(signal_levels) ^ exists(step_sizes))
-        flow_matching = exists(signal_levels)
+        assert not (exists(signal_levels) ^ exists(step_sizes_log2))

+        # if neither signal levels nor step sizes passed in
+        # generate them randomly for training

-        # flow matching if `signal_levels` passed in
+        no_shortcut_train = random() < self.prob_no_shortcut_train

-        if flow_matching:
+        if no_shortcut_train:
+            # if no shortcut training, step sizes are just 1 and noising is all steps, where each step is 1 / d_min
+            # in original shortcut paper, they actually set d = 0 for some reason, look into that later, as there is no mention in the dreamer paper of doing this
+
+            step_sizes_log2 = torch.zeros((batch,), device = device).long() # zero because zero is equivalent to step size of 1
+            signal_levels = torch.randint(0, self.max_steps, (batch, time), device = device)
+        else:
+            # now we follow eq (4)
+
+            step_sizes_log2 = torch.randint(1, self.num_step_sizes_log2, (batch,), device = device)
+            num_step_sizes = 2 ** step_sizes_log2
+
+            signal_levels = torch.randint(0, self.max_steps, (batch, time), device = device) // num_step_sizes[:, None] * num_step_sizes[:, None] # times are discretized to step sizes

+        # get the noise

         noise = torch.randn_like(latents)

+        # times is from 0 to 1

-        times = rearrange(signal_levels.float() / self.num_signal_levels, 'b t -> b t 1')
+        times = rearrange(signal_levels.float() / self.max_steps, 'b t -> b t 1')

-        orig_latents = latents
-        latents = noise.lerp(latents, times)
+        # noise from 0 as noise to 1 as data

-        # allow for original velocity pred
-        # x-space as in paper is in else clause
-
-        if not self.pred_orig_latent:
-            pred_target = flow = latents - noise
-        else:
-            pred_target = latents
+        noised_latents = noise.lerp(latents, times)

+        def get_prediction(noised_latents, signal_levels, step_sizes_log2):
             # latents to spatial tokens

-            space_tokens = self.latents_to_spatial_tokens(latents)
+            space_tokens = self.latents_to_spatial_tokens(noised_latents)

             # pack to tokens
             # [signal + step size embed] [latent space tokens] [register] [actions / agent]

-            registers = repeat(self.register_tokens, 's d -> b t s d', b = latents.shape[0], t = latents.shape[1])
-            agent_token = repeat(self.action_learned_embed, 'd -> b t d', b = latents.shape[0], t = latents.shape[1])
+            registers = repeat(self.register_tokens, 's d -> b t s d', b = batch, t = time)
+            agent_token = repeat(self.action_learned_embed, 'd -> b t d', b = batch, t = time)

             # determine signal + step size embed for their diffusion forcing + shortcut

-            if exists(signal_levels):
-                signal_embed = self.signal_levels_embed(signal_levels)
-                step_size_embed = self.step_sizes_embed(step_sizes)
+            signal_embed = self.signal_levels_embed(signal_levels)

-                flow_token = cat((signal_embed, step_size_embed), dim = -1)
-                flow_token = rearrange(flow_token, 'b t d -> b t d')
-            else:
-                flow_token = registers[..., 0:0, :]
+            step_size_embed = self.step_size_embed(step_sizes_log2)
+            step_size_embed = repeat(step_size_embed, 'b ... -> b t ...', t = time)

+            flow_token = cat((signal_embed, step_size_embed), dim = -1)
+            flow_token = rearrange(flow_token, 'b t d -> b t d')

             # pack to tokens for attending

             tokens, packed_tokens_shape = pack([flow_token, space_tokens, registers, agent_token], 'b t * d')
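
A quick numeric check of the eq (4) discretization above (values chosen for illustration):

    import torch

    max_steps = 64                          # K_max
    step_sizes_log2 = torch.tensor([3])     # d = 2^3 = 8 of the 64 signal levels
    num_step_sizes = 2 ** step_sizes_log2

    raw = torch.tensor([[37, 5, 63, 40]])
    signal_levels = raw // num_step_sizes[:, None] * num_step_sizes[:, None]
    # tensor([[32,  0, 56, 40]]) - every level floored to a multiple of 8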
@@ -1197,11 +1219,86 @@
             pred = self.to_pred(pooled)

-        if not flow_matching:
             return pred

+        # forward the network

+        pred = get_prediction(noised_latents, signal_levels, step_sizes_log2)

+        if return_pred_only:
+            return pred

+        # determine the target for the loss

+        pred_target = None

+        is_x_space = self.pred_orig_latent
+        is_v_space_pred = not self.pred_orig_latent

+        maybe_shortcut_loss_weight = 1.

+        if no_shortcut_train:
+            # allow for original velocity pred
+            # x-space as in paper is in else clause
+
+            if is_v_space_pred:
+                pred_target = flow = latents - noise
+            else:
+                pred_target = latents
+        else:
+            # shortcut training - Frans et al. https://arxiv.org/abs/2410.12557
+            # basically a consistency loss where you ensure the composition of two half steps equals one full step
+            # dreamer then makes it work for x-space with some math
+
+            get_prediction_no_grad = torch.no_grad()(get_prediction)
+
+            step_sizes_log2_minus_one = step_sizes_log2 - 1 # which equals d / 2
+            half_step_size = 2 ** step_sizes_log2_minus_one
+
+            first_step_pred = get_prediction_no_grad(noised_latents, signal_levels, step_sizes_log2_minus_one)

+            # first derive b'

+            if is_v_space_pred:
+                first_step_pred_flow = first_step_pred
+            else:
+                first_times = signal_levels[..., None].float() / self.max_steps
+                first_step_pred_flow = (first_step_pred - noised_latents) / (1. - first_times)

+            # take a half step

+            denoised_latent = noised_latents + first_step_pred_flow * (half_step_size[:, None, None] / self.max_steps)

+            # get second prediction for b''

+            second_step_pred = get_prediction_no_grad(denoised_latent, signal_levels + half_step_size[:, None], step_sizes_log2_minus_one)

+            if is_v_space_pred:
+                second_step_pred_flow = second_step_pred
+            else:
+                second_times = (signal_levels + half_step_size[:, None])[..., None].float() / self.max_steps # time of the latent after the half step
+                second_step_pred_flow = (second_step_pred - denoised_latent) / (1. - second_times)

+            # pred target is sg(b' + b'') / 2

+            pred_target = (first_step_pred_flow + second_step_pred_flow).detach() / 2

+            # need to convert x-space to v-space

+            if is_x_space:
+                pred = (pred - noised_latents) / (1. - first_times)
+                maybe_shortcut_loss_weight = (1. - first_times) ** 2

+        # mse loss

         losses = F.mse_loss(pred, pred_target, reduction = 'none')

+        losses = losses * maybe_shortcut_loss_weight # handle the (1-t)^2 in eq (7)

+        # loss weighting with their ramp function

         if exists(self.loss_weight_fn):
             loss_weight = self.loss_weight_fn(times)
             losses = losses * loss_weight
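
A quick sanity check of the x-space to v-space conversion and the extra loss weight, derived from the interpolation used in the forward pass:

    # noised_latents = noise.lerp(latents, t) = (1 - t) * noise + t * latents
    # with velocity v = latents - noise this rearranges to
    #   latents - noised_latents = (1 - t) * v
    # so an x-space prediction converts to a velocity via
    #   v_pred = (x_pred - noised_latents) / (1 - t)
    # and an x-space MSE equals (1 - t)^2 times the v-space MSE,
    # which is the maybe_shortcut_loss_weight handling eq (7)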

View File

@@ -5,10 +5,12 @@ import torch
 @param('pred_orig_latent', (False, True))
 @param('grouped_query_attn', (False, True))
 @param('dynamics_with_video_input', (False, True))
+@param('prob_no_shortcut_train', (None, 0., 1.))
 def test_e2e(
     pred_orig_latent,
     grouped_query_attn,
-    dynamics_with_video_input
+    dynamics_with_video_input,
+    prob_no_shortcut_train
 ):
     from dreamer4.dreamer4 import VideoTokenizer, DynamicsModel
@@ -27,24 +29,24 @@ def test_e2e(
         512,
         video_tokenizer = tokenizer,
         dim_latent = 32,
-        num_signal_levels = 500,
-        num_step_sizes = 32,
+        max_steps = 64,
         pred_orig_latent = pred_orig_latent,
         attn_kwargs = dict(
             heads = heads,
             query_heads = query_heads
-        )
+        ),
+        prob_no_shortcut_train = prob_no_shortcut_train
     )

-    signal_levels = torch.randint(0, 500, (2, 4))
-    step_sizes = torch.randint(0, 32, (2, 4))
+    signal_levels = torch.randint(0, 64, (2, 4)) # keep within the new max_steps = 64 range
+    step_sizes_log2 = torch.randint(1, 6, (2,))

     if dynamics_with_video_input:
         dynamics_input = dict(video = video)
     else:
         dynamics_input = dict(latents = latents)

-    flow_loss = dynamics(**dynamics_input, signal_levels = signal_levels, step_sizes = step_sizes)
+    flow_loss = dynamics(**dynamics_input, signal_levels = signal_levels, step_sizes_log2 = step_sizes_log2)

     assert flow_loss.numel() == 1

 def test_symexp_two_hot():
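
A usage note on the new return_pred_only flag: with it set, the same forward should hand back the raw network prediction instead of the scalar loss, e.g. continuing from the test setup above:

    pred = dynamics(
        latents = latents,
        signal_levels = signal_levels,
        step_sizes_log2 = step_sizes_log2,
        return_pred_only = True
    )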