make a first pass through the shortcut training logic (Frans et al. from Berkeley), maintaining both v-space and x-space
parent 971637673b
commit fe99efecba
@@ -1,7 +1,8 @@
 from __future__ import annotations

 import math
-from math import ceil
+from math import ceil, log2
+from random import random
 from collections import namedtuple
 from functools import partial

@@ -70,6 +71,9 @@ def first(arr):
 def divisible_by(num, den):
     return (num % den) == 0

+def is_power_two(num):
+    return log2(num).is_integer()
+
 # tensor helpers

 def pack_one(t, pattern):
@@ -345,7 +349,9 @@ def apply_rotations(
     # handle gqa for rotary

-    if rotations.ndim > 2 and heads < rotations.shape[0]:
+    if rotations.ndim == 3 and rotations.shape[0] < heads:
         rotary_heads = rotations.shape[0]

+        assert divisible_by(heads, rotary_heads)
+
         groups = heads // rotary_heads
         rotations = repeat(rotations, 'h ... -> (h g) ...', g = groups)
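A standalone sketch of what the corrected branch does (shapes here are assumed for illustration, not taken from the repo): under grouped query attention the rotations are computed for the fewer key/value heads, then repeated up to the full query head count.

import torch
from einops import repeat

heads = 8         # attention (query) heads
rotary_heads = 2  # rotations computed per kv head under gqa

rotations = torch.randn(rotary_heads, 16, 64)  # (heads, seq, dim)

# the corrected condition: rotations carry fewer heads than attention
if rotations.ndim == 3 and rotations.shape[0] < heads:
    assert heads % rotary_heads == 0
    groups = heads // rotary_heads
    rotations = repeat(rotations, 'h ... -> (h g) ...', g = groups)

assert rotations.shape[0] == heads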
@@ -353,6 +359,8 @@ def apply_rotations(
     x1, x2 = t.chunk(2, dim = -1)
     rotated_half_t = cat((-x2, x1), dim = -1)

+    # rotate in the positions
+
     rotated = t * rotations.cos() + rotated_half_t * rotations.sin()
     return rotated.type(dtype)

@@ -971,8 +979,7 @@ class DynamicsModel(Module):
         dim,
         dim_latent,
         video_tokenizer: VideoTokenizer | None = None,
-        num_signal_levels = 500,
-        num_step_sizes = 32,
+        max_steps = 64, # K_max in paper
         num_spatial_tokens = 32, # latents were projected into spatial tokens, and presumably pooled back for the final prediction (or one special one does the x-prediction)
         num_register_tokens = 8, # they claim register tokens led to better temporal consistency
         depth = 4,
@@ -985,7 +992,8 @@ class DynamicsModel(Module):
         attn_softclamp_value = 50.,
         ff_kwargs: dict = dict(),
         loss_weight_fn: Callable = ramp_weight,
-        num_future_predictions = 8 # they do multi-token prediction of 8 steps forward
+        num_future_predictions = 8, # they do multi-token prediction of 8 steps forward
+        prob_no_shortcut_train = None # probability of no shortcut training, defaults to 1 / log2(max_steps)
     ):
         super().__init__()

@@ -1007,14 +1015,18 @@ class DynamicsModel(Module):
         assert divisible_by(dim, 2)
         dim_half = dim // 2

-        self.num_signal_levels = num_signal_levels
-        self.num_step_sizes = num_step_sizes
+        assert is_power_two(max_steps), '`max_steps` must be a power of 2'
+        self.max_steps = max_steps
+        self.num_step_sizes_log2 = int(log2(max_steps))

-        self.signal_levels_embed = nn.Embedding(num_signal_levels, dim_half)
-        self.step_sizes_embed = nn.Embedding(num_step_sizes, dim_half)
+        self.signal_levels_embed = nn.Embedding(max_steps, dim_half)
+        self.step_size_embed = nn.Embedding(self.num_step_sizes_log2, dim_half) # powers of 2, so 1/1, 1/2, 1/4, 1/8 ... 1/Kmax

-        self.pred_orig_latent = pred_orig_latent
+        self.prob_no_shortcut_train = default(prob_no_shortcut_train, self.num_step_sizes_log2 ** -1.)
+
+        # loss related
+
+        self.pred_orig_latent = pred_orig_latent # x-space or v-space
         self.loss_weight_fn = loss_weight_fn

         # they sum all the actions into a single token
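As a sanity check on the new hyperparameter scheme, a small sketch outside the module (same names, toy values): `max_steps` plays K_max, signal levels live in [0, max_steps), and step sizes are the powers of two indexed by their log2.

import torch
from math import log2

max_steps = 64                              # K_max, number of discrete signal levels
num_step_sizes_log2 = int(log2(max_steps))  # 6 -> step sizes of 1, 2, 4, 8, 16, 32 levels

# index k in the step size embedding stands for stepping 2 ** k levels at once,
# i.e. a fraction (2 ** k) / max_steps of the unit time interval

step_sizes_log2 = torch.randint(1, num_step_sizes_log2, (4,))
levels_per_step = 2 ** step_sizes_log2

# eq (4) discretization - signal levels snap to multiples of the step size

signal_levels = torch.randint(0, max_steps, (4, 10))
signal_levels = signal_levels // levels_per_step[:, None] * levels_per_step[:, None]

assert (signal_levels % levels_per_step[:, None] == 0).all()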
@@ -1083,7 +1095,8 @@ class DynamicsModel(Module):
         video = None,
         latents = None, # (b t d)
         signal_levels = None, # (b t)
-        step_sizes = None # (b t)
+        step_sizes_log2 = None, # (b)
+        return_pred_only = False
     ):
         # handle video or latents

@@ -1094,114 +1107,198 @@ class DynamicsModel(Module):
             latents = self.video_tokenizer.tokenize(video)

-        time, device = latents.shape[1], latents.device
+        batch, time, device = *latents.shape[:2], latents.device

         # flow related

-        assert not (exists(signal_levels) ^ exists(step_sizes))
+        assert not (exists(signal_levels) ^ exists(step_sizes_log2))

-        flow_matching = exists(signal_levels)
+        # if neither signal levels or step sizes passed in
+        # generate them randomly for training

-        if flow_matching:
+        # flow matching if `signal_levels` passed in

-            times = rearrange(signal_levels.float() / self.num_signal_levels, 'b t -> b t 1')
+        no_shortcut_train = random() < self.prob_no_shortcut_train

-            orig_latents = latents
-            noise = torch.randn_like(latents)
+        if no_shortcut_train:
+            # if no shortcut training, step sizes are just 1 and noising is all steps, where each step is 1 / d_min
+            # in original shortcut paper, they actually set d = 0 for some reason, look into that later, as there is no mention in the dreamer paper of doing this

-            latents = noise.lerp(latents, times)
+            step_sizes_log2 = torch.zeros((batch,), device = device).long() # zero because zero is equivalent to step size of 1
+            signal_levels = torch.randint(0, self.max_steps, (batch, time), device = device)
+        else:
+            # now we follow eq (4)

-            # allow for original velocity pred
-            # x-space as in paper is in else clause
+            step_sizes_log2 = torch.randint(1, self.num_step_sizes_log2, (batch,), device = device)
+            num_step_sizes = 2 ** step_sizes_log2

-            if not self.pred_orig_latent:
-                pred_target = flow = latents - noise
-            else:
-                pred_target = latents
+            signal_levels = torch.randint(0, self.max_steps, (batch, time), device = device) // num_step_sizes[:, None] * num_step_sizes[:, None] # times are discretized to step sizes

-        # latents to spatial tokens
+        # get the noise

-        space_tokens = self.latents_to_spatial_tokens(latents)
+        noise = torch.randn_like(latents)

-        # pack to tokens
-        # [signal + step size embed] [latent space tokens] [register] [actions / agent]
+        # times is from 0 to 1

-        registers = repeat(self.register_tokens, 's d -> b t s d', b = latents.shape[0], t = latents.shape[1])
+        times = rearrange(signal_levels.float() / self.max_steps, 'b t -> b t 1')

-        agent_token = repeat(self.action_learned_embed, 'd -> b t d', b = latents.shape[0], t = latents.shape[1])
+        # noise from 0 as noise to 1 as data

-        # determine signal + step size embed for their diffusion forcing + shortcut
+        noised_latents = noise.lerp(latents, times)

-        if exists(signal_levels):
-            signal_embed = self.signal_levels_embed(signal_levels)
-            step_size_embed = self.step_sizes_embed(step_sizes)
+        def get_prediction(noised_latents, signal_levels, step_sizes_log2):

-            flow_token = cat((signal_embed, step_size_embed), dim = -1)
-            flow_token = rearrange(flow_token, 'b t d -> b t 1 d')
-        else:
-            flow_token = registers[..., 0:0, :]
+            # latents to spatial tokens

+            space_tokens = self.latents_to_spatial_tokens(noised_latents)

+            # pack to tokens
+            # [signal + step size embed] [latent space tokens] [register] [actions / agent]

+            registers = repeat(self.register_tokens, 's d -> b t s d', b = batch, t = time)

+            agent_token = repeat(self.action_learned_embed, 'd -> b t d', b = batch, t = time)

+            # determine signal + step size embed for their diffusion forcing + shortcut

+            signal_embed = self.signal_levels_embed(signal_levels)

+            step_size_embed = self.step_size_embed(step_sizes_log2)
+            step_size_embed = repeat(step_size_embed, 'b ... -> b t ...', t = time)

+            flow_token = cat((signal_embed, step_size_embed), dim = -1)

-        # pack to tokens for attending
+            # pack to tokens for attending

-        tokens, packed_tokens_shape = pack([flow_token, space_tokens, registers, agent_token], 'b t * d')
+            tokens, packed_tokens_shape = pack([flow_token, space_tokens, registers, agent_token], 'b t * d')

-        # attend functions for space and time
+            # attend functions for space and time

-        seq_len = tokens.shape[1]
+            seq_len = tokens.shape[1]

-        use_flex = exists(flex_attention) and tokens.is_cuda
+            use_flex = exists(flex_attention) and tokens.is_cuda

-        attend_kwargs = dict(use_flex = use_flex, softclamp_value = self.attn_softclamp_value, device = device)
+            attend_kwargs = dict(use_flex = use_flex, softclamp_value = self.attn_softclamp_value, device = device)

-        space_attend = get_attend_fn(causal = False, seq_len = self.space_seq_len, k_seq_len = self.space_seq_len, num_special_tokens = 1, **attend_kwargs) # space has an agent token on the right-hand side for reinforcement learning - cannot be attended to by modality
-        time_attend = get_attend_fn(causal = True, seq_len = time, k_seq_len = time, **attend_kwargs)
+            space_attend = get_attend_fn(causal = False, seq_len = self.space_seq_len, k_seq_len = self.space_seq_len, num_special_tokens = 1, **attend_kwargs) # space has an agent token on the right-hand side for reinforcement learning - cannot be attended to by modality
+            time_attend = get_attend_fn(causal = True, seq_len = time, k_seq_len = time, **attend_kwargs)

-        # rotary
+            # rotary

-        rotary_pos_emb = self.time_rotary(time)
+            rotary_pos_emb = self.time_rotary(time)

-        # attention
+            # attention

-        for (pre_attn_rearrange, post_attn_rearrange, attn, ff), layer_is_time in zip(self.layers, self.is_time):
-            tokens = pre_attn_rearrange(tokens)
+            for (pre_attn_rearrange, post_attn_rearrange, attn, ff), layer_is_time in zip(self.layers, self.is_time):
+                tokens = pre_attn_rearrange(tokens)

-            # when is a axial time attention block, should be causal
+                # when it is an axial time attention block, should be causal

-            attend_fn = time_attend if layer_is_time else space_attend
-            layer_rotary_pos_emb = rotary_pos_emb if layer_is_time else None
+                attend_fn = time_attend if layer_is_time else space_attend
+                layer_rotary_pos_emb = rotary_pos_emb if layer_is_time else None

-            # attention layer
+                # attention layer

-            tokens = attn(tokens, rotary_pos_emb = layer_rotary_pos_emb, attend_fn = attend_fn) + tokens
-            tokens = post_attn_rearrange(tokens)
+                tokens = attn(tokens, rotary_pos_emb = layer_rotary_pos_emb, attend_fn = attend_fn) + tokens
+                tokens = post_attn_rearrange(tokens)

-            # feedforward layer
+                # feedforward layer

-            tokens = ff(tokens) + tokens
+                tokens = ff(tokens) + tokens

-        # unpack
+            # unpack

-        flow_token, space_tokens, register_tokens, agent_token = unpack(tokens, packed_tokens_shape, 'b t * d')
+            flow_token, space_tokens, register_tokens, agent_token = unpack(tokens, packed_tokens_shape, 'b t * d')

-        # pooling
+            # pooling

-        pooled = reduce(space_tokens, 'b t s d -> b t d', 'mean')
+            pooled = reduce(space_tokens, 'b t s d -> b t d', 'mean')

-        pred = self.to_pred(pooled)
+            pred = self.to_pred(pooled)
+            return pred

-        if not flow_matching:
-            return pred
+        # forward the network

+        pred = get_prediction(noised_latents, signal_levels, step_sizes_log2)

+        if return_pred_only:
+            return pred

+        # determine the target for the loss

+        pred_target = None

+        is_x_space = self.pred_orig_latent
+        is_v_space_pred = not self.pred_orig_latent

+        maybe_shortcut_loss_weight = 1.

+        if no_shortcut_train:
+            # allow for original velocity pred
+            # x-space as in paper is in else clause

+            if is_v_space_pred:
+                pred_target = flow = latents - noise
+            else:
+                pred_target = latents
+        else:
+            # shortcut training - Frans et al. https://arxiv.org/abs/2410.12557

+            # basically a consistency loss where you ensure two half steps match one full step
+            # dreamer then makes it work for x-space with some math

+            get_prediction_no_grad = torch.no_grad()(get_prediction)

+            step_sizes_log2_minus_one = step_sizes_log2 - 1 # which equals d / 2
+            half_step_size = 2 ** step_sizes_log2_minus_one

+            first_step_pred = get_prediction_no_grad(noised_latents, signal_levels, step_sizes_log2_minus_one)

+            # first derive b'

+            first_times = signal_levels[..., None].float() / self.max_steps

+            if is_v_space_pred:
+                first_step_pred_flow = first_step_pred
+            else:
+                first_step_pred_flow = (first_step_pred - noised_latents) / (1. - first_times)

+            # take a half step

+            denoised_latent = noised_latents + first_step_pred_flow * (half_step_size[:, None, None] / self.max_steps)

+            # get second prediction for b''

+            second_signal_levels = signal_levels + half_step_size[:, None]

+            second_step_pred = get_prediction_no_grad(denoised_latent, second_signal_levels, step_sizes_log2_minus_one)

+            if is_v_space_pred:
+                second_step_pred_flow = second_step_pred
+            else:
+                second_times = second_signal_levels[..., None].float() / self.max_steps
+                second_step_pred_flow = (second_step_pred - denoised_latent) / (1. - second_times)

+            # pred target is sg(b' + b'') / 2

+            pred_target = (first_step_pred_flow + second_step_pred_flow).detach() / 2

+            # need to convert x-space pred to v-space

+            if is_x_space:
+                pred = (pred - noised_latents) / (1. - first_times)
+                maybe_shortcut_loss_weight = (1. - first_times) ** 2

         # mse loss

         losses = F.mse_loss(pred, pred_target, reduction = 'none')

+        losses = losses * maybe_shortcut_loss_weight # handle the (1-t)^2 in eq(7)

         # loss weighting with their ramp function

         if exists(self.loss_weight_fn):
             loss_weight = self.loss_weight_fn(times)
             losses = losses * loss_weight
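The heart of the hunk above is the shortcut self-consistency target: two half steps of size d/2 must match one step of size d. A minimal v-space sketch, with a hypothetical `predict(z, t, d)` standing in for `get_prediction` and continuous times in place of discrete signal levels:

import torch

def shortcut_target(predict, z_t, t, d):
    # predict(z, t, d) returns a velocity estimate (hypothetical stand-in)
    with torch.no_grad():
        b1 = predict(z_t, t, d / 2)             # b' at (z_t, t) with half the step size
        z_mid = z_t + b1 * (d / 2)              # advance half a step along b'
        b2 = predict(z_mid, t + d / 2, d / 2)   # b'' at the midpoint
    return (b1 + b2) / 2                        # sg((b' + b'') / 2)

# toy usage with a stand-in model
predict = lambda z, t, d: torch.zeros_like(z)
z_t = torch.randn(2, 4, 32)
target = shortcut_target(predict, z_t, t = 0.25, d = 0.5)

# training regresses the full-step prediction onto this target, e.g.
# loss = F.mse_loss(predict(z_t, t, d), target)
# for x-space, convert first: v = (x_pred - z_t) / (1 - t), weighted by (1 - t) ** 2 per eq (7)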
@@ -5,10 +5,12 @@ import torch
 @param('pred_orig_latent', (False, True))
 @param('grouped_query_attn', (False, True))
 @param('dynamics_with_video_input', (False, True))
+@param('prob_no_shortcut_train', (None, 0., 1.))
 def test_e2e(
     pred_orig_latent,
     grouped_query_attn,
-    dynamics_with_video_input
+    dynamics_with_video_input,
+    prob_no_shortcut_train
 ):
     from dreamer4.dreamer4 import VideoTokenizer, DynamicsModel

@@ -27,24 +29,24 @@ def test_e2e(
         512,
         video_tokenizer = tokenizer,
         dim_latent = 32,
-        num_signal_levels = 500,
-        num_step_sizes = 32,
+        max_steps = 64,
         pred_orig_latent = pred_orig_latent,
         attn_kwargs = dict(
             heads = heads,
             query_heads = query_heads
-        )
+        ),
+        prob_no_shortcut_train = prob_no_shortcut_train
     )

-    signal_levels = torch.randint(0, 500, (2, 4))
-    step_sizes = torch.randint(0, 32, (2, 4))
+    signal_levels = torch.randint(0, 64, (2, 4))
+    step_sizes_log2 = torch.randint(1, 6, (2,))

     if dynamics_with_video_input:
         dynamics_input = dict(video = video)
     else:
         dynamics_input = dict(latents = latents)

-    flow_loss = dynamics(**dynamics_input, signal_levels = signal_levels, step_sizes = step_sizes)
+    flow_loss = dynamics(**dynamics_input, signal_levels = signal_levels, step_sizes_log2 = step_sizes_log2)
     assert flow_loss.numel() == 1

 def test_symexp_two_hot():
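For orientation, a heavily hedged sketch of how the discretized signal levels and log2 step sizes would drive sampling. The commit itself adds no sampling path, and as of this first pass the forward still draws signal levels internally, so this only illustrates the intended plumbing of `return_pred_only`:

import torch

@torch.no_grad()
def sample(dynamics, shape, max_steps = 64, step_size_log2 = 3, device = 'cpu'):
    # hypothetical sampler, not part of this commit
    batch, time, dim_latent = shape
    step = 2 ** step_size_log2  # signal levels covered per network call
    latents = torch.randn(shape, device = device)

    for level in range(0, max_steps, step):
        signal_levels = torch.full((batch, time), level, device = device, dtype = torch.long)
        step_sizes_log2 = torch.full((batch,), step_size_log2, device = device, dtype = torch.long)

        pred = dynamics(latents = latents, signal_levels = signal_levels, step_sizes_log2 = step_sizes_log2, return_pred_only = True)

        flow = pred  # v-space; for x-space use (pred - latents) / (1. - level / max_steps)
        latents = latents + flow * (step / max_steps)

    return latents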