make a first pass through the shortcut training logic (Frans et al. from Berkeley), maintaining both v-space and x-space

lucidrains 2025-10-05 11:17:36 -07:00
parent 971637673b
commit fe99efecba
2 changed files with 170 additions and 71 deletions
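For orientation, the shortcut objective being ported here (Frans et al., "One Step Diffusion via Shortcut Models", arXiv 2410.12557) conditions the network on a step size d in addition to the signal level t, trains the d-sized step to be self-consistent with two chained d/2-sized steps, and falls back to plain flow matching at the smallest step size. Below is a minimal sketch of that consistency target in plain PyTorch, assuming a toy velocity model `net(z_t, t, log2_step)` and a hypothetical `k_max`; it is not this repo's API.

# toy sketch of the shortcut consistency target, not this repository's code
# `net(z_t, t, log2_step)` is any velocity model; all names here are illustrative

import torch
import torch.nn.functional as F

def shortcut_target(net, z_t, t, log2_step, k_max = 64):
    # the target for a step of size d = 2 ** log2_step / k_max is the
    # stop-gradient average velocity of two chained steps of size d / 2
    with torch.no_grad():
        d_half = (2 ** (log2_step - 1)) / k_max
        v1 = net(z_t, t, log2_step - 1)
        v2 = net(z_t + v1 * d_half, t + d_half, log2_step - 1)
    return (v1 + v2) / 2

def shortcut_loss(net, z_t, t, log2_step):
    pred = net(z_t, t, log2_step)
    return F.mse_loss(pred, shortcut_target(net, z_t, t, log2_step))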


@@ -1,7 +1,8 @@
 from __future__ import annotations
 
 import math
-from math import ceil
+from math import ceil, log2
+from random import random
 
 from collections import namedtuple
 from functools import partial
@@ -70,6 +71,9 @@ def first(arr):
 def divisible_by(num, den):
     return (num % den) == 0
 
+def is_power_two(num):
+    return log2(num).is_integer()
+
 # tensor helpers
 
 def pack_one(t, pattern):
@@ -345,7 +349,9 @@ def apply_rotations(
     # handle gqa for rotary
 
-    if rotations.ndim > 2 and heads < rotations.shape[0]:
+    if rotations.ndim == 3 and rotations.shape[0] < heads:
+        rotary_heads = rotations.shape[0]
+
         assert divisible_by(heads, rotary_heads)
         groups = heads // rotary_heads
         rotations = repeat(rotations, 'h ... -> (h g) ...', g = groups)
@@ -353,6 +359,8 @@ def apply_rotations(
     x1, x2 = t.chunk(2, dim = -1)
     rotated_half_t = cat((-x2, x1), dim = -1)
 
+    # rotate in the positions
+
     rotated = t * rotations.cos() + rotated_half_t * rotations.sin()
 
     return rotated.type(dtype)
@@ -971,8 +979,7 @@ class DynamicsModel(Module):
         dim,
         dim_latent,
         video_tokenizer: VideoTokenizer | None = None,
-        num_signal_levels = 500,
-        num_step_sizes = 32,
+        max_steps = 64, # K_max in paper
         num_spatial_tokens = 32, # latents were projected into spatial tokens, and presumably pooled back for the final prediction (or one special one does the x-prediction)
         num_register_tokens = 8, # they claim register tokens led to better temporal consistency
         depth = 4,
@@ -985,7 +992,8 @@ class DynamicsModel(Module):
         attn_softclamp_value = 50.,
         ff_kwargs: dict = dict(),
         loss_weight_fn: Callable = ramp_weight,
-        num_future_predictions = 8 # they do multi-token prediction of 8 steps forward
+        num_future_predictions = 8, # they do multi-token prediction of 8 steps forward
+        prob_no_shortcut_train = None # probability of no shortcut training, defaults to 1 / num_step_sizes
     ):
         super().__init__()
@@ -1007,14 +1015,18 @@ class DynamicsModel(Module):
         assert divisible_by(dim, 2)
         dim_half = dim // 2
 
-        self.num_signal_levels = num_signal_levels
-        self.num_step_sizes = num_step_sizes
+        assert is_power_two(max_steps), '`max_steps` must be a power of 2'
+
+        self.max_steps = max_steps
+        self.num_step_sizes_log2 = int(log2(max_steps))
 
-        self.signal_levels_embed = nn.Embedding(num_signal_levels, dim_half)
-        self.step_sizes_embed = nn.Embedding(num_step_sizes, dim_half)
+        self.signal_levels_embed = nn.Embedding(max_steps, dim_half)
+        self.step_size_embed = nn.Embedding(self.num_step_sizes_log2, dim_half) # power of 2, so 1/1, 1/2, 1/4, 1/8 ... 1/Kmax
 
-        self.pred_orig_latent = pred_orig_latent
+        self.prob_no_shortcut_train = default(prob_no_shortcut_train, self.num_step_sizes_log2 ** -1.)
+
+        # loss related
+
+        self.pred_orig_latent = pred_orig_latent # x-space or v-space
         self.loss_weight_fn = loss_weight_fn
 
         # they sum all the actions into a single token
@@ -1083,7 +1095,8 @@ class DynamicsModel(Module):
         video = None,
         latents = None, # (b t d)
         signal_levels = None, # (b t)
-        step_sizes = None # (b t)
+        step_sizes_log2 = None, # (b)
+        return_pred_only = False
     ):
         # handle video or latents
@@ -1094,114 +1107,198 @@
             latents = self.video_tokenizer.tokenize(video)
 
-        time, device = latents.shape[1], latents.device
+        batch, time, device = *latents.shape[:2], latents.device
 
         # flow related
 
-        assert not (exists(signal_levels) ^ exists(step_sizes))
-
-        flow_matching = exists(signal_levels)
-
-        # flow matching if `signal_levels` passed in
-
-        if flow_matching:
-            noise = torch.randn_like(latents)
-
-            times = rearrange(signal_levels.float() / self.num_signal_levels, 'b t -> b t 1')
-
-            orig_latents = latents
-
-            latents = noise.lerp(latents, times)
-
-            # allow for original velocity pred
-            # x-space as in paper is in else clause
-
-            if not self.pred_orig_latent:
-                pred_target = flow = latents - noise
-            else:
-                pred_target = latents
+        assert not (exists(signal_levels) ^ exists(step_sizes_log2))
+
+        # if neither signal levels or step sizes passed in
+        # generate them randomly for training
+
+        no_shortcut_train = random() < self.prob_no_shortcut_train
+
+        if no_shortcut_train:
+            # if no shortcut training, step sizes are just 1 and noising is all steps, where each step is 1 / d_min
+            # in original shortcut paper, they actually set d = 0 for some reason, look into that later, as there is no mention in the dreamer paper of doing this
+
+            step_sizes_log2 = torch.zeros((batch,), device = device).long() # zero because zero is equivalent to step size of 1
+            signal_levels = torch.randint(0, self.max_steps, (batch, time), device = device)
+        else:
+            # now we follow eq (4)
+
+            step_sizes_log2 = torch.randint(1, self.num_step_sizes_log2, (batch,), device = device)
+            num_step_sizes = 2 ** step_sizes_log2
+
+            signal_levels = torch.randint(0, self.max_steps, (batch, time)) // num_step_sizes[:, None] * num_step_sizes[:, None] # times are discretized to step sizes
+
+        # get the noise
+
+        noise = torch.randn_like(latents)
+
+        # times is from 0 to 1
+
+        times = rearrange(signal_levels.float() / self.max_steps, 'b t -> b t 1')
+
+        # noise from 0 as noise to 1 as data
+
+        noised_latents = noise.lerp(latents, times)
+
+        def get_prediction(noised_latents, signal_levels, step_sizes_log2):
 
             # latents to spatial tokens
 
-        space_tokens = self.latents_to_spatial_tokens(latents)
+            space_tokens = self.latents_to_spatial_tokens(noised_latents)
 
             # pack to tokens
             # [signal + step size embed] [latent space tokens] [register] [actions / agent]
 
-        registers = repeat(self.register_tokens, 's d -> b t s d', b = latents.shape[0], t = latents.shape[1])
-        agent_token = repeat(self.action_learned_embed, 'd -> b t d', b = latents.shape[0], t = latents.shape[1])
+            registers = repeat(self.register_tokens, 's d -> b t s d', b = batch, t = time)
+            agent_token = repeat(self.action_learned_embed, 'd -> b t d', b = batch, t = time)
 
             # determine signal + step size embed for their diffusion forcing + shortcut
 
-        if exists(signal_levels):
             signal_embed = self.signal_levels_embed(signal_levels)
-            step_size_embed = self.step_sizes_embed(step_sizes)
+
+            step_size_embed = self.step_size_embed(step_sizes_log2)
+            step_size_embed = repeat(step_size_embed, 'b ... -> b t ...', t = time)
 
             flow_token = cat((signal_embed, step_size_embed), dim = -1)
             flow_token = rearrange(flow_token, 'b t d -> b t d')
-        else:
-            flow_token = registers[..., 0:0, :]
 
             # pack to tokens for attending
 
             tokens, packed_tokens_shape = pack([flow_token, space_tokens, registers, agent_token], 'b t * d')
 
             # attend functions for space and time
 
             seq_len = tokens.shape[1]
             use_flex = exists(flex_attention) and tokens.is_cuda
 
             attend_kwargs = dict(use_flex = use_flex, softclamp_value = self.attn_softclamp_value, device = device)
 
             space_attend = get_attend_fn(causal = False, seq_len = self.space_seq_len, k_seq_len = self.space_seq_len, num_special_tokens = 1, **attend_kwargs) # space has an agent token on the right-hand side for reinforcement learning - cannot be attended to by modality
             time_attend = get_attend_fn(causal = True, seq_len = time, k_seq_len = time, **attend_kwargs)
 
             # rotary
 
             rotary_pos_emb = self.time_rotary(time)
 
             # attention
 
             for (pre_attn_rearrange, post_attn_rearrange, attn, ff), layer_is_time in zip(self.layers, self.is_time):
                 tokens = pre_attn_rearrange(tokens)
 
                 # when is a axial time attention block, should be causal
 
                 attend_fn = time_attend if layer_is_time else space_attend
                 layer_rotary_pos_emb = rotary_pos_emb if layer_is_time else None
 
                 # attention layer
 
                 tokens = attn(tokens, rotary_pos_emb = layer_rotary_pos_emb, attend_fn = attend_fn) + tokens
 
                 tokens = post_attn_rearrange(tokens)
 
                 # feedforward layer
 
                 tokens = ff(tokens) + tokens
 
             # unpack
 
             flow_token, space_tokens, register_tokens, agent_token = unpack(tokens, packed_tokens_shape, 'b t * d')
 
             # pooling
 
             pooled = reduce(space_tokens, 'b t s d -> b t d', 'mean')
 
             pred = self.to_pred(pooled)
 
-        if not flow_matching:
             return pred
+
+        # forward the network
+
+        pred = get_prediction(noised_latents, signal_levels, step_sizes_log2)
+
+        if return_pred_only:
+            return pred
+
+        # determine the target for the loss
+
+        pred_target = None
+
+        is_x_space = self.pred_orig_latent
+        is_v_space_pred = not self.pred_orig_latent
+
+        maybe_shortcut_loss_weight = 1.
+
+        if no_shortcut_train:
+            # allow for original velocity pred
+            # x-space as in paper is in else clause
+
+            if is_v_space_pred:
+                pred_target = flow = latents - noise
+            else:
+                pred_target = latents
+        else:
+            # shortcut training - Frans et al. https://arxiv.org/abs/2410.12557
+            # basically a consistency loss where you ensure quantity of two half steps equals one step
+            # dreamer then makes it works for x-space with some math
+
+            get_prediction_no_grad = torch.no_grad()(get_prediction)
+
+            step_sizes_log2_minus_one = step_sizes_log2 - 1 # which equals d / 2
+            half_step_size = 2 ** step_sizes_log2_minus_one
+
+            first_step_pred = get_prediction_no_grad(noised_latents, signal_levels, step_sizes_log2_minus_one)
+
+            # first derive b'
+
+            if is_v_space_pred:
+                first_step_pred_flow = first_step_pred
+            else:
+                first_times = signal_levels[..., None].float() / self.max_steps
+                first_step_pred_flow = (first_step_pred - noised_latents) / (1. - first_times)
+
+            # take a half step
+
+            denoised_latent = noised_latents + first_step_pred_flow * (half_step_size[:, None, None] / self.max_steps)
+
+            # get second prediction for b''
+
+            second_step_pred = get_prediction_no_grad(denoised_latent, signal_levels + half_step_size[:, None], step_sizes_log2_minus_one)
+
+            if is_v_space_pred:
+                second_step_pred_flow = second_step_pred
+            else:
+                second_times = signal_levels[..., None].float() / self.max_steps
+                second_step_pred_flow = (second_step_pred - denoised_latent) / (1. - second_times)
+
+            # pred target is sg(b' + b'') / 2
+
+            pred_target = (first_step_pred_flow + second_step_pred_flow).detach() / 2
+
+            # need to convert x-space to v-space
+
+            if is_x_space:
+                pred = (pred - noised_latents) / (1. - first_times)
+                maybe_shortcut_loss_weight = (1. - first_times) ** 2
 
         # mse loss
 
         losses = F.mse_loss(pred, pred_target, reduction = 'none')
+
+        losses = losses * maybe_shortcut_loss_weight # handle the (1-t)^2 in eq(7)
 
         # loss weighting with their ramp function
 
         if exists(self.loss_weight_fn):
             loss_weight = self.loss_weight_fn(times)
             losses = losses * loss_weight
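The x-space branch above leans on the identity between an x-prediction and its implied velocity: with z_t = (1 - t) * noise + t * x, a prediction x_hat implies v_hat = (x_hat - z_t) / (1 - t), so a squared error measured in v-space carries an extra (1 - t)^2 factor, which is what `maybe_shortcut_loss_weight` compensates for (the eq (7) referenced in the comments). A standalone numeric check of that equivalence, illustrative only and not part of the commit:

# standalone numeric check, not part of this commit
import torch

t = torch.rand(8, 1) * 0.9                     # keep away from t = 1 to avoid dividing by ~0
noise, data = torch.randn(8, 4), torch.randn(8, 4)
z_t = noise.lerp(data, t)                      # (1 - t) * noise + t * data

x_pred, x_target = torch.randn(8, 4), torch.randn(8, 4)
v_pred = (x_pred - z_t) / (1. - t)             # implied velocities
v_target = (x_target - z_t) / (1. - t)

x_space_err = (x_pred - x_target) ** 2
v_space_err = (1. - t) ** 2 * (v_pred - v_target) ** 2

assert torch.allclose(x_space_err, v_space_err, atol = 1e-5)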


@@ -5,10 +5,12 @@ import torch
 @param('pred_orig_latent', (False, True))
 @param('grouped_query_attn', (False, True))
 @param('dynamics_with_video_input', (False, True))
+@param('prob_no_shortcut_train', (None, 0., 1.))
 def test_e2e(
     pred_orig_latent,
     grouped_query_attn,
-    dynamics_with_video_input
+    dynamics_with_video_input,
+    prob_no_shortcut_train
 ):
     from dreamer4.dreamer4 import VideoTokenizer, DynamicsModel
@@ -27,24 +29,24 @@ def test_e2e(
         512,
         video_tokenizer = tokenizer,
         dim_latent = 32,
-        num_signal_levels = 500,
-        num_step_sizes = 32,
+        max_steps = 64,
         pred_orig_latent = pred_orig_latent,
         attn_kwargs = dict(
             heads = heads,
             query_heads = query_heads
-        )
+        ),
+        prob_no_shortcut_train = prob_no_shortcut_train
     )
 
     signal_levels = torch.randint(0, 500, (2, 4))
-    step_sizes = torch.randint(0, 32, (2, 4))
+    step_sizes_log2 = torch.randint(1, 6, (2,))
 
     if dynamics_with_video_input:
         dynamics_input = dict(video = video)
     else:
         dynamics_input = dict(latents = latents)
 
-    flow_loss = dynamics(**dynamics_input, signal_levels = signal_levels, step_sizes = step_sizes)
+    flow_loss = dynamics(**dynamics_input, signal_levels = signal_levels, step_sizes_log2 = step_sizes_log2)
 
     assert flow_loss.numel() == 1
 
 def test_symexp_two_hot():