make a first pass through the shortcut training logic (Frans et al. from Berkeley), maintaining both v-space and x-space
This commit is contained in:
parent 971637673b
commit fe99efecba
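For orientation before the diff: a minimal sketch of the shortcut consistency target from Frans et al. (https://arxiv.org/abs/2410.12557) in v-space, stripped of the Dreamer-specific details below. `model` is a stand-in callable with the same (noised latents, signal levels, log2 step size) signature as the `get_prediction` closure in the diff; it is not the repo's API.

import torch

def shortcut_target(model, noised, signal_levels, step_sizes_log2, max_steps):
    # one step of size d must match the composition of two steps of size d / 2
    half_log2 = step_sizes_log2 - 1     # log2(d) - 1 == log2(d / 2)
    half_size = 2 ** half_log2          # d / 2, in discrete signal levels

    with torch.no_grad():
        # b' - velocity at the current point, queried at half the step size
        b1 = model(noised, signal_levels, half_log2)

        # move along b' for half a step
        stepped = noised + b1 * (half_size[:, None, None].float() / max_steps)

        # b'' - velocity after the half step
        b2 = model(stepped, signal_levels + half_size[:, None], half_log2)

    # the full step of size d regresses onto the stop-gradient average sg(b' + b'') / 2
    return (b1 + b2) / 2

The mse between the model's full-step prediction and this stop-gradient average is what the new `else` branch in `forward` computes, with an extra conversion when the network predicts in x-space.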
@@ -1,7 +1,8 @@
 from __future__ import annotations
 
 import math
-from math import ceil
+from math import ceil, log2
+from random import random
 from collections import namedtuple
 from functools import partial
 
@@ -70,6 +71,9 @@ def first(arr):
 def divisible_by(num, den):
     return (num % den) == 0
 
+def is_power_two(num):
+    return log2(num).is_integer()
+
 # tensor helpers
 
 def pack_one(t, pattern):
@@ -345,7 +349,9 @@ def apply_rotations(
 
     # handle gqa for rotary
 
-    if rotations.ndim > 2 and heads < rotations.shape[0]:
+    if rotations.ndim == 3 and rotations.shape[0] < heads:
+        rotary_heads = rotations.shape[0]
+
         assert divisible_by(heads, rotary_heads)
         groups = heads // rotary_heads
         rotations = repeat(rotations, 'h ... -> (h g) ...', g = groups)
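A toy check of the corrected gqa branch above, with made-up shapes: 2 rotary heads shared across 8 query heads means each rotation is repeated over 4 groups.

import torch
from einops import repeat

heads = 8
rotations = torch.randn(2, 16, 64)  # (rotary heads, seq, rotary dim)

if rotations.ndim == 3 and rotations.shape[0] < heads:
    rotary_heads = rotations.shape[0]
    groups = heads // rotary_heads  # 4
    rotations = repeat(rotations, 'h ... -> (h g) ...', g = groups)

assert rotations.shape[0] == heads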
@@ -353,6 +359,8 @@ def apply_rotations(
     x1, x2 = t.chunk(2, dim = -1)
     rotated_half_t = cat((-x2, x1), dim = -1)
 
+    # rotate in the positions
+
     rotated = t * rotations.cos() + rotated_half_t * rotations.sin()
     return rotated.type(dtype)
 
@@ -971,8 +979,7 @@ class DynamicsModel(Module):
         dim,
         dim_latent,
         video_tokenizer: VideoTokenizer | None = None,
-        num_signal_levels = 500,
-        num_step_sizes = 32,
+        max_steps = 64, # K_max in paper
         num_spatial_tokens = 32, # latents were projected into spatial tokens, and presumably pooled back for the final prediction (or one special one does the x-prediction)
         num_register_tokens = 8, # they claim register tokens led to better temporal consistency
         depth = 4,
@@ -985,7 +992,8 @@ class DynamicsModel(Module):
         attn_softclamp_value = 50.,
         ff_kwargs: dict = dict(),
         loss_weight_fn: Callable = ramp_weight,
-        num_future_predictions = 8 # they do multi-token prediction of 8 steps forward
+        num_future_predictions = 8, # they do multi-token prediction of 8 steps forward
+        prob_no_shortcut_train = None # probability of no shortcut training, defaults to 1 / log2(max_steps)
     ):
         super().__init__()
 
@@ -1007,14 +1015,18 @@ class DynamicsModel(Module):
         assert divisible_by(dim, 2)
         dim_half = dim // 2
 
-        self.num_signal_levels = num_signal_levels
-        self.num_step_sizes = num_step_sizes
+        assert is_power_two(max_steps), '`max_steps` must be a power of 2'
+        self.max_steps = max_steps
+        self.num_step_sizes_log2 = int(log2(max_steps))
 
-        self.signal_levels_embed = nn.Embedding(num_signal_levels, dim_half)
-        self.step_sizes_embed = nn.Embedding(num_step_sizes, dim_half)
+        self.signal_levels_embed = nn.Embedding(max_steps, dim_half)
+        self.step_size_embed = nn.Embedding(self.num_step_sizes_log2, dim_half) # power of 2, so 1/1, 1/2, 1/4, 1/8 ... 1/Kmax
 
-        self.pred_orig_latent = pred_orig_latent
+        self.prob_no_shortcut_train = default(prob_no_shortcut_train, self.num_step_sizes_log2 ** -1.)
+
+        # loss related
+
+        self.pred_orig_latent = pred_orig_latent # x-space or v-space
         self.loss_weight_fn = loss_weight_fn
 
         # they sum all the actions into a single token
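The step-size embedding above is indexed by the exponent rather than the step size itself: a step of d = 1 / 2^k looks up row k. A small sanity check, with an assumed K_max of 64:

import torch
import torch.nn as nn
from math import log2

max_steps = 64
num_step_sizes_log2 = int(log2(max_steps))  # 6 rows

step_size_embed = nn.Embedding(num_step_sizes_log2, 8)
step_sizes_log2 = torch.tensor([0, 1, 5])   # step sizes 1, 1/2, 1/32

assert step_size_embed(step_sizes_log2).shape == (3, 8)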
@@ -1083,7 +1095,8 @@ class DynamicsModel(Module):
         video = None,
         latents = None, # (b t d)
         signal_levels = None, # (b t)
-        step_sizes = None # (b t)
+        step_sizes_log2 = None, # (b)
+        return_pred_only = False
     ):
         # handle video or latents
 
@@ -1094,57 +1107,66 @@ class DynamicsModel(Module):
 
             latents = self.video_tokenizer.tokenize(video)
 
-        time, device = latents.shape[1], latents.device
+        batch, time, device = *latents.shape[:2], latents.device
 
         # flow related
 
-        assert not (exists(signal_levels) ^ exists(step_sizes))
-
-        flow_matching = exists(signal_levels)
-
-        # flow matching if `signal_levels` passed in
-
-        if flow_matching:
+        assert not (exists(signal_levels) ^ exists(step_sizes_log2))
+
+        # if neither signal levels nor step sizes are passed in,
+        # generate them randomly for training
+
+        no_shortcut_train = random() < self.prob_no_shortcut_train
+
+        if no_shortcut_train:
+            # if no shortcut training, step sizes are just 1 and noising is all steps, where each step is 1 / d_min
+            # in the original shortcut paper, they actually set d = 0 for some reason, look into that later, as there is no mention in the dreamer paper of doing this
+
+            step_sizes_log2 = torch.zeros((batch,), device = device).long() # zero, because zero is equivalent to a step size of 1
+            signal_levels = torch.randint(0, self.max_steps, (batch, time), device = device)
+        else:
+            # now we follow eq (4)
+
+            step_sizes_log2 = torch.randint(1, self.num_step_sizes_log2, (batch,), device = device)
+            num_step_sizes = 2 ** step_sizes_log2
+
+            signal_levels = torch.randint(0, self.max_steps, (batch, time), device = device) // num_step_sizes[:, None] * num_step_sizes[:, None] # times are discretized to step sizes
+
+        # get the noise
 
         noise = torch.randn_like(latents)
 
-        times = rearrange(signal_levels.float() / self.num_signal_levels, 'b t -> b t 1')
-
-        orig_latents = latents
-
-        latents = noise.lerp(latents, times)
-
-        # allow for original velocity pred
-        # x-space as in paper is in else clause
-
-        if not self.pred_orig_latent:
-            pred_target = flow = latents - noise
-        else:
-            pred_target = latents
+        # times is from 0 to 1
+
+        times = rearrange(signal_levels.float() / self.max_steps, 'b t -> b t 1')
+
+        # noise from 0 as noise to 1 as data
+
+        noised_latents = noise.lerp(latents, times)
 
+        def get_prediction(noised_latents, signal_levels, step_sizes_log2):
             # latents to spatial tokens
 
-            space_tokens = self.latents_to_spatial_tokens(latents)
+            space_tokens = self.latents_to_spatial_tokens(noised_latents)
 
             # pack to tokens
             # [signal + step size embed] [latent space tokens] [register] [actions / agent]
 
-            registers = repeat(self.register_tokens, 's d -> b t s d', b = latents.shape[0], t = latents.shape[1])
+            registers = repeat(self.register_tokens, 's d -> b t s d', b = batch, t = time)
 
-            agent_token = repeat(self.action_learned_embed, 'd -> b t d', b = latents.shape[0], t = latents.shape[1])
+            agent_token = repeat(self.action_learned_embed, 'd -> b t d', b = batch, t = time)
 
             # determine signal + step size embed for their diffusion forcing + shortcut
 
-            if exists(signal_levels):
             signal_embed = self.signal_levels_embed(signal_levels)
-            step_size_embed = self.step_sizes_embed(step_sizes)
+            step_size_embed = self.step_size_embed(step_sizes_log2)
+            step_size_embed = repeat(step_size_embed, 'b ... -> b t ...', t = time)
 
             flow_token = cat((signal_embed, step_size_embed), dim = -1)
             flow_token = rearrange(flow_token, 'b t d -> b t d')
 
-            else:
-                flow_token = registers[..., 0:0, :]
 
             # pack to tokens for attending
 
             tokens, packed_tokens_shape = pack([flow_token, space_tokens, registers, agent_token], 'b t * d')
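The discretization in the `else` branch above (eq (4)) floor-divides then multiplies, snapping each sampled signal level onto the grid of its sample's step size. A toy run with an assumed max_steps of 8:

import torch

max_steps = 8
step_sizes_log2 = torch.tensor([1, 2])  # step sizes of 2 and 4 signal levels
num_step_sizes = 2 ** step_sizes_log2

signal_levels = torch.randint(0, max_steps, (2, 4))
snapped = signal_levels // num_step_sizes[:, None] * num_step_sizes[:, None]

# row 0 is snapped to multiples of 2, row 1 to multiples of 4
assert (snapped % num_step_sizes[:, None] == 0).all()
assert (snapped <= signal_levels).all()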
@@ -1197,11 +1219,86 @@ class DynamicsModel(Module):
 
             pred = self.to_pred(pooled)
 
-        if not flow_matching:
             return pred
 
+        # forward the network
+
+        pred = get_prediction(noised_latents, signal_levels, step_sizes_log2)
+
+        if return_pred_only:
+            return pred
+
+        # determine the target for the loss
+
+        pred_target = None
+
+        is_x_space = self.pred_orig_latent
+        is_v_space_pred = not self.pred_orig_latent
+
+        maybe_shortcut_loss_weight = 1.
+
+        if no_shortcut_train:
+            # allow for original velocity pred
+            # x-space as in paper is in else clause
+
+            if is_v_space_pred:
+                pred_target = flow = latents - noise
+            else:
+                pred_target = latents
+        else:
+            # shortcut training - Frans et al. https://arxiv.org/abs/2410.12557
+
+            # basically a consistency loss, ensuring the composition of two half steps matches one full step
+            # dreamer then makes it work for x-space with some math
+
+            get_prediction_no_grad = torch.no_grad()(get_prediction)
+
+            step_sizes_log2_minus_one = step_sizes_log2 - 1 # which equals d / 2
+            half_step_size = 2 ** step_sizes_log2_minus_one
+
+            first_step_pred = get_prediction_no_grad(noised_latents, signal_levels, step_sizes_log2_minus_one)
+
+            # first derive b'
+
+            if is_v_space_pred:
+                first_step_pred_flow = first_step_pred
+            else:
+                first_times = signal_levels[..., None].float() / self.max_steps
+                first_step_pred_flow = (first_step_pred - noised_latents) / (1. - first_times)
+
+            # take a half step
+
+            denoised_latent = noised_latents + first_step_pred_flow * (half_step_size[:, None, None] / self.max_steps)
+
+            # get second prediction for b''
+
+            second_step_pred = get_prediction_no_grad(denoised_latent, signal_levels + half_step_size[:, None], step_sizes_log2_minus_one)
+
+            if is_v_space_pred:
+                second_step_pred_flow = second_step_pred
+            else:
+                second_times = (signal_levels + half_step_size[:, None])[..., None].float() / self.max_steps
+                second_step_pred_flow = (second_step_pred - denoised_latent) / (1. - second_times)
+
+            # pred target is sg(b' + b'') / 2
+
+            pred_target = (first_step_pred_flow + second_step_pred_flow).detach() / 2
+
+            # need to convert x-space pred to v-space
+
+            if is_x_space:
+                pred = (pred - noised_latents) / (1. - first_times)
+                maybe_shortcut_loss_weight = (1. - first_times) ** 2
+
+        # mse loss
 
         losses = F.mse_loss(pred, pred_target, reduction = 'none')
 
+        losses = losses * maybe_shortcut_loss_weight # handle the (1 - t)^2 in eq (7)
+
+        # loss weighting with their ramp function
+
         if exists(self.loss_weight_fn):
             loss_weight = self.loss_weight_fn(times)
             losses = losses * loss_weight
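The (1 - t)^2 factor applied to the loss above follows from the x-space to v-space conversion: with z_t = t * x + (1 - t) * eps the flow is v = x - eps, so v_hat - v = (x_hat - x) / (1 - t), and squaring recovers the x-space error. A quick numerical check of that identity, with made-up shapes:

import torch

t = torch.rand(2, 4, 1) * 0.9     # keep times away from 1
x = torch.randn(2, 4, 16)         # clean latent
eps = torch.randn_like(x)         # noise
z_t = eps.lerp(x, t)              # same lerp as in the diff

x_hat = torch.randn_like(x)       # a pretend x-space prediction
v_hat = (x_hat - z_t) / (1. - t)  # converted to v-space, as in the diff
v = x - eps

assert torch.allclose((1. - t) ** 2 * (v_hat - v) ** 2, (x_hat - x) ** 2, atol = 1e-4)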
@@ -5,10 +5,12 @@ import torch
 @param('pred_orig_latent', (False, True))
 @param('grouped_query_attn', (False, True))
 @param('dynamics_with_video_input', (False, True))
+@param('prob_no_shortcut_train', (None, 0., 1.))
 def test_e2e(
     pred_orig_latent,
     grouped_query_attn,
-    dynamics_with_video_input
+    dynamics_with_video_input,
+    prob_no_shortcut_train
 ):
     from dreamer4.dreamer4 import VideoTokenizer, DynamicsModel
 
@@ -27,24 +29,24 @@ def test_e2e(
         512,
         video_tokenizer = tokenizer,
         dim_latent = 32,
-        num_signal_levels = 500,
-        num_step_sizes = 32,
+        max_steps = 64,
         pred_orig_latent = pred_orig_latent,
         attn_kwargs = dict(
             heads = heads,
             query_heads = query_heads
-        )
+        ),
+        prob_no_shortcut_train = prob_no_shortcut_train
     )
 
-    signal_levels = torch.randint(0, 500, (2, 4))
-    step_sizes = torch.randint(0, 32, (2, 4))
+    signal_levels = torch.randint(0, 64, (2, 4))
+    step_sizes_log2 = torch.randint(1, 6, (2,))
 
     if dynamics_with_video_input:
         dynamics_input = dict(video = video)
     else:
         dynamics_input = dict(latents = latents)
 
-    flow_loss = dynamics(**dynamics_input, signal_levels = signal_levels, step_sizes = step_sizes)
+    flow_loss = dynamics(**dynamics_input, signal_levels = signal_levels, step_sizes_log2 = step_sizes_log2)
     assert flow_loss.numel() == 1
 
 def test_symexp_two_hot():