from __future__ import annotations
|
|
|
|
import math
|
|
from math import ceil, log2
|
|
from random import random
|
|
from contextlib import nullcontext
|
|
from collections import namedtuple
|
|
from functools import partial, wraps
from typing import Callable
|
|
from dataclasses import dataclass, asdict
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from torch.nested import nested_tensor
|
|
from torch.distributions import Normal
|
|
from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity
|
|
from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, full, zeros, ones, randint, rand, randn, randn_like, empty, linspace
|
|
from torch.utils._pytree import tree_flatten, tree_unflatten
|
|
|
|
import torchvision
|
|
from torchvision.models import VGG16_Weights
|
|
|
|
from torch.optim import Optimizer
|
|
from adam_atan2_pytorch import MuonAdamAtan2
|
|
|
|
from x_mlps_pytorch.ensemble import Ensemble
|
|
from x_mlps_pytorch.normed_mlp import create_mlp
|
|
|
|
from hyper_connections import get_init_and_expand_reduce_stream_functions
|
|
|
|
from assoc_scan import AssocScan
|
|
|
|
# ein related
|
|
|
|
# b - batch
|
|
# n - sequence
|
|
# h - attention heads
|
|
# d - feature dimension
|
|
# f - frequencies (rotary)
|
|
# l - logit / predicted bins
|
|
# y - layers of transformer
|
|
# p - positions (3 for spacetime in this work)
|
|
# t - time
|
|
# na - action dimension (number of discrete and continuous actions)
|
|
# g - groups of query heads to key heads (gqa)
|
|
# vc - video channels
|
|
# vh, vw - video height and width
|
|
# mtp - multi token prediction length
|
|
# v - video viewpoints
|
|
|
|
import einx
|
|
from einx import add, multiply
|
|
from einops import einsum, rearrange, repeat, reduce, pack, unpack
|
|
from einops.layers.torch import Rearrange, Reduce
|
|
|
|
# flex attention - but will make sure it works if it is not available
|
|
# may also end up crafting own custom flash attention kernel for this work
|
|
|
|
flex_attention = None
|
|
|
|
try:
|
|
from torch.nn.attention.flex_attention import flex_attention, create_block_mask
|
|
if torch.cuda.is_available():
|
|
flex_attention = torch.compile(flex_attention)
|
|
except ImportError:
|
|
pass
|
|
|
|
# constants
|
|
|
|
LinearNoBias = partial(Linear, bias = False)
|
|
|
|
TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips'))
|
|
|
|
WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'rewards', 'discrete_actions', 'continuous_actions'))
|
|
|
|
@dataclass
|
|
class Experience:
|
|
latents: Tensor
|
|
video: Tensor | None = None
|
|
proprio: Tensor | None = None
|
|
    agent_embed: Tensor | None = None
|
|
rewards: Tensor | None = None
|
|
actions: tuple[Tensor, Tensor] | None = None
|
|
log_probs: tuple[Tensor, Tensor] | None = None
|
|
values: Tensor | None = None
|
|
step_size: int | None = None
|
|
lens: Tensor | None = None
|
|
is_truncated: Tensor | None = None
|
|
agent_index: int = 0
|
|
is_from_world_model: bool = True
|
|
|
|
def combine_experiences(
|
|
    exps: list[Experience]
|
|
) -> Experience:
|
|
|
|
assert len(exps) > 0
|
|
|
|
# set lens if not there
|
|
|
|
for exp in exps:
|
|
latents = exp.latents
|
|
batch, time, device = *latents.shape[:2], latents.device
|
|
|
|
if not exists(exp.lens):
|
|
exp.lens = full((batch,), time, device = device)
|
|
|
|
if not exists(exp.is_truncated):
|
|
exp.is_truncated = full((batch,), True, device = device)
|
|
|
|
# convert to dictionary
|
|
|
|
exps_dict = [asdict(exp) for exp in exps]
|
|
|
|
values, tree_specs = zip(*[tree_flatten(exp_dict) for exp_dict in exps_dict])
|
|
|
|
tree_spec = first(tree_specs)
|
|
|
|
all_field_values = list(zip(*values))
|
|
|
|
# an assert to make sure all fields are either all tensors, or a single matching value (for step size, agent index etc) - can change this later
|
|
|
|
assert all([
|
|
all([is_tensor(v) for v in field_values]) or len(set(field_values)) == 1
|
|
for field_values in all_field_values
|
|
])
|
|
|
|
concatted = []
|
|
|
|
for field_values in all_field_values:
|
|
|
|
if is_tensor(first(field_values)):
|
|
|
|
field_values = pad_tensors_at_dim_to_max_len(field_values, dims = (1, 2))
|
|
|
|
new_field_value = cat(field_values)
|
|
else:
|
|
new_field_value = first(list(set(field_values)))
|
|
|
|
concatted.append(new_field_value)
|
|
|
|
# return experience
|
|
|
|
concat_exp_dict = tree_unflatten(concatted, tree_spec)
|
|
|
|
return Experience(**concat_exp_dict)
|
|
|
|
# helpers
|
|
|
|
def exists(v):
|
|
return v is not None
|
|
|
|
def default(v, d):
|
|
return v if exists(v) else d
|
|
|
|
def first(arr):
|
|
return arr[0]
|
|
|
|
def xnor(x, y):
|
|
return not (x ^ y)
|
|
|
|
def has_at_least_one(*bools):
|
|
return sum([*map(int, bools)]) > 0
|
|
|
|
def ensure_tuple(t):
|
|
return (t,) if not isinstance(t, tuple) else t
|
|
|
|
def divisible_by(num, den):
|
|
return (num % den) == 0
|
|
|
|
def sample_prob(prob):
|
|
return random() < prob
|
|
|
|
def is_power_two(num):
|
|
return log2(num).is_integer()
|
|
|
|
# tensor helpers
|
|
|
|
def is_empty(t):
|
|
return t.numel() == 0
|
|
|
|
def lens_to_mask(t, max_len = None):
|
|
if not exists(max_len):
|
|
max_len = t.amax().item()
|
|
|
|
device = t.device
|
|
seq = torch.arange(max_len, device = device)
|
|
|
|
return einx.less('j, i -> i j', seq, t)
|
|
|
|
def log(t, eps = 1e-20):
|
|
return t.clamp(min = eps).log()
|
|
|
|
def safe_cat(tensors, dim):
|
|
tensors = [*filter(exists, tensors)]
|
|
|
|
if len(tensors) == 0:
|
|
return None
|
|
elif len(tensors) == 1:
|
|
return tensors[0]
|
|
|
|
return cat(tensors, dim = dim)
|
|
|
|
def safe_squeeze_first(t):
|
|
if not exists(t):
|
|
return None
|
|
|
|
if t.shape[0] != 1:
|
|
return t
|
|
|
|
return rearrange(t, '1 ... -> ...')
|
|
|
|
def gumbel_noise(t):
|
|
noise = torch.rand_like(t)
|
|
return -log(-log(noise))
|
|
|
|
def gumbel_sample(
|
|
t,
|
|
temperature = 1.,
|
|
dim = -1,
|
|
keepdim = False,
|
|
eps = 1e-10
|
|
):
|
|
noised = (t / max(temperature, eps)) + gumbel_noise(t)
|
|
return noised.argmax(dim = dim, keepdim = keepdim)
|
|
|
|
def pack_one(t, pattern):
|
|
packed, packed_shape = pack([t], pattern)
|
|
|
|
def inverse(out, inv_pattern = None):
|
|
inv_pattern = default(inv_pattern, pattern)
|
|
return first(unpack(out, packed_shape, inv_pattern))
|
|
|
|
return packed, inverse
|
|
|
|
def pad_at_dim(
|
|
t,
|
|
pad: tuple[int, int],
|
|
dim = -1,
|
|
value = 0.
|
|
):
|
|
if pad == (0, 0):
|
|
return t
|
|
|
|
dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
|
|
zeros = ((0, 0) * dims_from_right)
|
|
return F.pad(t, (*zeros, *pad), value = value)
|
|
|
|
def pad_to_len(t, target_len, *, dim):
|
|
curr_len = t.shape[dim]
|
|
|
|
if curr_len >= target_len:
|
|
return t
|
|
|
|
return pad_at_dim(t, (0, target_len - curr_len), dim = dim)
|
|
|
|
def pad_tensors_at_dim_to_max_len(
|
|
tensors: list[Tensor],
|
|
dims: tuple[int, ...]
|
|
):
|
|
for dim in dims:
|
|
if dim >= first(tensors).ndim:
|
|
continue
|
|
|
|
max_time = max([t.shape[dim] for t in tensors])
|
|
tensors = [pad_to_len(t, max_time, dim = dim) for t in tensors]
|
|
|
|
return tensors
|
|
|
|
def align_dims_left(t, aligned_to):
|
|
shape = t.shape
|
|
num_right_dims = aligned_to.ndim - t.ndim
|
|
|
|
if num_right_dims < 0:
|
|
return
|
|
|
|
return t.reshape(*shape, *((1,) * num_right_dims))
|
|
|
|
def l2norm(t):
|
|
return F.normalize(t, dim = -1, p = 2)
|
|
|
|
def softclamp(t, value = 50.):
|
|
return (t / value).tanh() * value
|
|
|
|
def create_multi_token_prediction_targets(
|
|
t, # (b t ...)
|
|
steps_future,
|
|
|
|
): # (b t steps ...), (b t steps) - targets and the mask, where mask is False for padding
|
|
|
|
batch, seq_len, device = *t.shape[:2], t.device
|
|
|
|
batch_arange = arange(batch, device = device)
|
|
seq_arange = arange(seq_len, device = device)
|
|
steps_arange = arange(steps_future, device = device)
|
|
|
|
indices = add('t, steps -> t steps', seq_arange, steps_arange)
|
|
mask = indices < seq_len
|
|
|
|
batch_arange = rearrange(batch_arange, 'b -> b 1 1')
|
|
|
|
indices[~mask] = 0
|
|
mask = repeat(mask, 't steps -> b t steps', b = batch)
|
|
|
|
out = t[batch_arange, indices]
|
|
|
|
return out, mask
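# illustrative note: with seq_len = 3 and steps_future = 2, out[:, i, s] = t[:, i + s]
# and mask[:, i, s] is False wherever i + s >= seq_len (those gather indices are zeroed out as padding)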
|
|
|
|
# loss related
|
|
|
|
class LossNormalizer(Module):
|
|
|
|
# the authors mentioned the need for loss normalization in the dynamics transformer
|
|
|
|
def __init__(
|
|
self,
|
|
num_losses: int,
|
|
beta = 0.95,
|
|
eps = 1e-6
|
|
):
|
|
super().__init__()
|
|
self.register_buffer('exp_avg_sq', torch.ones(num_losses))
|
|
self.beta = beta
|
|
self.eps = eps
|
|
|
|
def forward(
|
|
self,
|
|
losses: Tensor | list[Tensor] | dict[str, Tensor],
|
|
update_ema = None
|
|
):
|
|
exp_avg_sq, beta = self.exp_avg_sq, self.beta
|
|
update_ema = default(update_ema, self.training)
|
|
|
|
# get the rms value - as mentioned at the end of section 3 in the paper
|
|
|
|
rms = exp_avg_sq.sqrt()
|
|
|
|
if update_ema:
|
|
decay = 1. - beta
|
|
|
|
# update the ema
|
|
|
|
exp_avg_sq.lerp_(losses.detach().square(), decay)
|
|
|
|
# then normalize
|
|
|
|
assert losses.numel() == rms.numel()
|
|
|
|
normed_losses = losses / rms.clamp(min = self.eps)
|
|
|
|
return normed_losses
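# usage sketch (illustrative names, assuming the losses are stacked into a single tensor):
#
#   normalizer = LossNormalizer(num_losses = 2)
#   losses = stack((flow_loss, reward_loss))
#   normed = normalizer(losses)   # each loss divided by its running rms (ema of squared loss)
#   total = normed.sum()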
|
|
|
|
class LPIPSLoss(Module):
|
|
def __init__(
|
|
self,
|
|
vgg: Module | None = None,
|
|
vgg_weights: VGG16_Weights = VGG16_Weights.DEFAULT,
|
|
sampled_frames = 1
|
|
):
|
|
super().__init__()
|
|
|
|
if not exists(vgg):
|
|
vgg = torchvision.models.vgg16(weights = vgg_weights)
|
|
vgg.classifier = Sequential(*vgg.classifier[:-2])
|
|
|
|
self.vgg = [vgg]
|
|
self.sampled_frames = sampled_frames
|
|
|
|
def forward(
|
|
self,
|
|
pred,
|
|
data,
|
|
):
|
|
batch, device, is_video = pred.shape[0], pred.device, pred.ndim == 5
|
|
|
|
vgg, = self.vgg
|
|
vgg = vgg.to(data.device)
|
|
|
|
# take care of sampling random frames of the video
|
|
|
|
if is_video:
|
|
pred, data = tuple(rearrange(t, 'b c t ... -> b t c ...') for t in (pred, data))
|
|
|
|
# batch randperm
|
|
|
|
batch_randperm = randn(pred.shape[:2], device = pred.device).argsort(dim = -1)
|
|
rand_frames = batch_randperm[..., :self.sampled_frames]
|
|
|
|
batch_arange = arange(batch, device = device)
|
|
batch_arange = rearrange(batch_arange, '... -> ... 1')
|
|
|
|
pred, data = tuple(t[batch_arange, rand_frames] for t in (pred, data))
|
|
|
|
# fold sampled frames into batch
|
|
|
|
pred, data = tuple(rearrange(t, 'b t c ... -> (b t) c ...') for t in (pred, data))
|
|
|
|
pred_embed, embed = tuple(vgg(t) for t in (pred, data))
|
|
|
|
return F.mse_loss(embed, pred_embed)
|
|
|
|
def ramp_weight(times, slope = 0.9, intercept = 0.1):
|
|
# equation (8) paper, their "ramp" loss weighting
|
|
return slope * times + intercept
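# e.g. with the default slope and intercept, ramp_weight(0.) == 0.1 and ramp_weight(1.) == 1., increasing linearly in between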
|
|
|
|
# reinforcement learning related
|
|
|
|
# rewards
|
|
|
|
class SymExpTwoHot(Module):
|
|
def __init__(
|
|
self,
|
|
reward_range = (-20., 20.),
|
|
num_bins = 255,
|
|
learned_embedding = False,
|
|
dim_embed = None,
|
|
):
|
|
super().__init__()
|
|
|
|
min_value, max_value = reward_range
|
|
values = linspace(min_value, max_value, num_bins)
|
|
values = values.sign() * (torch.exp(values.abs()) - 1.)
|
|
|
|
self.reward_range = reward_range
|
|
self.num_bins = num_bins
|
|
self.register_buffer('bin_values', values)
|
|
|
|
# take care of a reward embedding
|
|
# for an improvisation where agent tokens can also see the past rewards - it makes sense that this information should not be thrown out, a la Decision Transformer
|
|
|
|
self.learned_embedding = learned_embedding
|
|
|
|
if learned_embedding:
|
|
assert exists(dim_embed)
|
|
self.bin_embeds = nn.Embedding(num_bins, dim_embed)
|
|
|
|
@property
|
|
def device(self):
|
|
return self.bin_values.device
|
|
|
|
def embed(
|
|
self,
|
|
two_hot_encoding,
|
|
):
|
|
assert self.learned_embedding, f'can only embed if `learned_embedding` is True'
|
|
|
|
weights, bin_indices = two_hot_encoding.topk(k = 2, dim = -1)
|
|
|
|
two_embeds = self.bin_embeds(bin_indices)
|
|
|
|
return einsum(two_embeds, weights, '... two d, ... two -> ... d')
|
|
|
|
def bins_to_scalar_value(
|
|
self,
|
|
logits, # (... l)
|
|
normalize = False
|
|
):
|
|
two_hot_encoding = logits.softmax(dim = -1) if normalize else logits
|
|
return einsum(two_hot_encoding, self.bin_values, '... l, l -> ...')
|
|
|
|
def forward(
|
|
self,
|
|
values
|
|
):
|
|
bin_values = self.bin_values
|
|
min_bin_value, max_bin_value = self.bin_values[0], self.bin_values[-1]
|
|
|
|
values, inverse_pack = pack_one(values, '*')
|
|
num_values = values.shape[0]
|
|
|
|
values = values.clamp(min = min_bin_value, max = max_bin_value)
|
|
|
|
indices = torch.searchsorted(self.bin_values, values)
|
|
|
|
# fetch the closest two indices (two-hot encoding)
|
|
|
|
left_indices = (indices - 1).clamp(min = 0)
|
|
right_indices = left_indices + 1
|
|
|
|
left_indices, right_indices = tuple(rearrange(t, '... -> ... 1') for t in (left_indices, right_indices))
|
|
|
|
# fetch the left and right values for the consecutive indices
|
|
|
|
left_values = self.bin_values[left_indices]
|
|
right_values = self.bin_values[right_indices]
|
|
|
|
# calculate the left and right values by the distance to the left and right
|
|
|
|
values = rearrange(values, '... -> ... 1')
|
|
total_distance = right_values - left_values
|
|
|
|
left_logit_value = (right_values - values) / total_distance
|
|
right_logit_value = 1. - left_logit_value
|
|
|
|
# set the left and right values (two-hot)
|
|
|
|
encoded = torch.zeros((num_values, self.num_bins), device = self.device)
|
|
|
|
encoded.scatter_(-1, left_indices, left_logit_value)
|
|
encoded.scatter_(-1, right_indices, right_logit_value)
|
|
|
|
return inverse_pack(encoded, '* l')
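# usage sketch (illustrative values) - encode a scalar reward into a two-hot over symexp-spaced bins, then decode it back:
#
#   encoder = SymExpTwoHot(reward_range = (-20., 20.), num_bins = 255)
#   two_hot = encoder(tensor([1.5]))                    # (1, 255), two adjacent non-zero bins summing to 1
#   scalar  = encoder.bins_to_scalar_value(two_hot)     # approximately tensor([1.5])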
|
|
|
|
# action related
|
|
|
|
ActionEmbeds = namedtuple('ActionEmbeds', ('discrete', 'continuous'))
|
|
|
|
class ActionEmbedder(Module):
|
|
def __init__(
|
|
self,
|
|
dim,
|
|
*,
|
|
num_discrete_actions: int | tuple[int, ...] = 0,
|
|
num_continuous_actions = 0,
|
|
continuous_norm_stats: tuple[tuple[float, float], ...] | None = None,
|
|
can_unembed = False,
|
|
unembed_dim = None,
|
|
num_unembed_preds = 1,
|
|
squeeze_unembed_preds = True # will auto-squeeze if prediction is just 1
|
|
):
|
|
super().__init__()
|
|
|
|
# handle discrete actions
|
|
|
|
num_discrete_actions = tensor(ensure_tuple(num_discrete_actions))
|
|
total_discrete_actions = num_discrete_actions.sum().item()
|
|
|
|
self.num_discrete_action_types = len(num_discrete_actions)
|
|
self.discrete_action_embed = Embedding(total_discrete_actions, dim)
|
|
|
|
self.register_buffer('num_discrete_actions', num_discrete_actions, persistent = False)
|
|
|
|
# continuous actions
|
|
|
|
self.num_continuous_action_types = num_continuous_actions
|
|
self.continuous_action_embed = Embedding(num_continuous_actions, dim)
|
|
|
|
self.continuous_need_norm = exists(continuous_norm_stats)
|
|
|
|
if self.continuous_need_norm:
|
|
self.register_buffer('continuous_norm_stats', tensor(continuous_norm_stats))
|
|
|
|
# defaults
|
|
|
|
self.register_buffer('default_discrete_action_types', arange(self.num_discrete_action_types), persistent = False)
|
|
self.register_buffer('default_continuous_action_types', arange(self.num_continuous_action_types), persistent = False)
|
|
|
|
# calculate offsets
|
|
|
|
offsets = F.pad(num_discrete_actions.cumsum(dim = -1), (1, -1), value = 0)
|
|
self.register_buffer('discrete_action_offsets', offsets, persistent = False)
|
|
|
|
# unembedding
|
|
|
|
self.can_unembed = can_unembed
|
|
|
|
self.num_unembed_preds = num_unembed_preds
|
|
self.squeeze_unembed_preds = squeeze_unembed_preds
|
|
|
|
if not can_unembed:
|
|
return
|
|
|
|
unembed_dim = default(unembed_dim, dim)
|
|
self.discrete_action_unembed = Parameter(torch.randn(total_discrete_actions, num_unembed_preds, unembed_dim) * 1e-2)
|
|
|
|
discrete_action_index = arange(total_discrete_actions)
|
|
|
|
padded_num_discrete_actions = F.pad(num_discrete_actions, (1, 0), value = 0)
|
|
exclusive_cumsum = padded_num_discrete_actions.cumsum(dim = -1)
|
|
|
|
discrete_action_mask = (
|
|
einx.greater_equal('j, i -> i j', discrete_action_index, exclusive_cumsum[:-1]) &
|
|
einx.less('j, i -> i j', discrete_action_index, exclusive_cumsum[1:])
|
|
)
|
|
|
|
self.register_buffer('discrete_action_mask', discrete_action_mask, persistent = False)
|
|
|
|
self.continuous_action_unembed = Parameter(torch.randn(num_continuous_actions, num_unembed_preds, unembed_dim, 2) * 1e-2)
|
|
|
|
def embed_parameters(self):
|
|
return set([*self.discrete_action_embed.parameters(), *self.continuous_action_embed.parameters()])
|
|
|
|
def unembed_parameters(self):
|
|
return set([self.discrete_action_unembed, self.continuous_action_unembed])
|
|
|
|
@property
|
|
def device(self):
|
|
return self.discrete_action_offsets.device
|
|
|
|
@property
|
|
def has_actions(self):
|
|
return self.num_discrete_action_types > 0 or self.num_continuous_action_types > 0
|
|
|
|
def cast_action_types(
|
|
self,
|
|
action_types = None
|
|
):
|
|
if exists(action_types) and not is_tensor(action_types):
|
|
if isinstance(action_types, int):
|
|
action_types = (action_types,)
|
|
|
|
action_types = tensor(action_types, device = self.device)
|
|
|
|
return action_types
|
|
|
|
def unembed(
|
|
self,
|
|
embeds, # (... d)
|
|
discrete_action_types = None, # (na)
|
|
continuous_action_types = None, # (na)
|
|
return_split_discrete = False,
|
|
pred_head_index: int | Tensor | None = None
|
|
|
|
): # (... discrete_na), (... continuous_na 2)
|
|
|
|
device = embeds.device
|
|
|
|
assert self.can_unembed, 'can only unembed for predicted discrete and continuous actions if `can_unembed = True` is set on init'
|
|
|
|
# handle only one prediction head during inference
|
|
|
|
if exists(pred_head_index) and isinstance(pred_head_index, int):
|
|
pred_head_index = tensor(pred_head_index, device = device)
|
|
|
|
# if pred_head_index given as a solo int, just assume we want to squeeze out the prediction head dimension
|
|
|
|
squeeze_one_pred_head = exists(pred_head_index) and pred_head_index.ndim == 0
|
|
|
|
# get action types
|
|
|
|
discrete_action_types, continuous_action_types = tuple(self.cast_action_types(t) for t in (discrete_action_types, continuous_action_types))
|
|
|
|
# discrete actions
|
|
|
|
discrete_action_logits = None
|
|
|
|
if self.num_discrete_action_types > 0:
|
|
|
|
discrete_action_unembed = self.discrete_action_unembed
|
|
|
|
if exists(discrete_action_types):
|
|
discrete_action_mask = self.discrete_action_mask[discrete_action_types].any(dim = 0)
|
|
|
|
discrete_action_unembed = discrete_action_unembed[discrete_action_mask]
|
|
|
|
if exists(pred_head_index):
|
|
discrete_action_unembed = discrete_action_unembed.index_select(1, pred_head_index)
|
|
|
|
discrete_action_logits = einsum(embeds, discrete_action_unembed, '... d, na mtp d -> mtp ... na')
|
|
|
|
if self.squeeze_unembed_preds or squeeze_one_pred_head:
|
|
discrete_action_logits = safe_squeeze_first(discrete_action_logits)
|
|
|
|
# whether to split the discrete action logits by the number of actions per action type
|
|
|
|
if exists(discrete_action_logits) and return_split_discrete:
|
|
|
|
split_sizes = self.num_discrete_actions[discrete_action_types] if exists(discrete_action_types) else self.num_discrete_actions
|
|
|
|
discrete_action_logits = discrete_action_logits.split(split_sizes.tolist(), dim = -1)
|
|
|
|
# continuous actions
|
|
|
|
continuous_action_mean_log_var = None
|
|
|
|
if self.num_continuous_action_types > 0:
|
|
|
|
continuous_action_unembed = self.continuous_action_unembed
|
|
|
|
if exists(continuous_action_types):
|
|
continuous_action_unembed = continuous_action_unembed[continuous_action_types]
|
|
|
|
if exists(pred_head_index):
|
|
continuous_action_unembed = continuous_action_unembed.index_select(1, pred_head_index)
|
|
|
|
continuous_action_mean_log_var = einsum(embeds, continuous_action_unembed, '... d, na mtp d two -> mtp ... na two')
|
|
|
|
if self.squeeze_unembed_preds or squeeze_one_pred_head:
|
|
continuous_action_mean_log_var = safe_squeeze_first(continuous_action_mean_log_var)
|
|
|
|
return discrete_action_logits, continuous_action_mean_log_var
|
|
|
|
def sample(
|
|
self,
|
|
embed,
|
|
discrete_temperature = 1.,
|
|
continuous_temperature = 1.,
|
|
inverse_norm_continuous = None,
|
|
pred_head_index: int | Tensor | None = None,
|
|
squeeze = True,
|
|
**kwargs
|
|
):
|
|
inverse_norm_continuous = default(inverse_norm_continuous, self.continuous_need_norm)
|
|
|
|
discrete_logits, continuous_mean_log_var = self.unembed(embed, return_split_discrete = True, pred_head_index = pred_head_index, **kwargs)
|
|
|
|
sampled_discrete = sampled_continuous = None
|
|
|
|
if exists(discrete_logits):
|
|
sampled_discrete = []
|
|
|
|
for one_discrete_logits in discrete_logits:
|
|
sampled_discrete.append(gumbel_sample(one_discrete_logits, temperature = discrete_temperature, keepdim = True))
|
|
|
|
sampled_discrete = cat(sampled_discrete, dim = -1)
|
|
|
|
if exists(continuous_mean_log_var):
|
|
mean, log_var = continuous_mean_log_var.unbind(dim = -1)
|
|
std = (0.5 * log_var).exp()
|
|
|
|
sampled_continuous = mean + std * torch.randn_like(mean) * continuous_temperature
|
|
|
|
# maybe inverse norm
|
|
|
|
if inverse_norm_continuous:
|
|
norm_mean, norm_std = self.continuous_norm_stats.unbind(dim = -1)
|
|
sampled_continuous = (sampled_continuous * norm_std) + norm_mean
|
|
|
|
return sampled_discrete, sampled_continuous
|
|
|
|
def log_probs(
|
|
self,
|
|
embeds, # (... d)
|
|
discrete_targets = None, # (... na)
|
|
continuous_targets = None, # (... na)
|
|
discrete_action_types = None, # (na)
|
|
continuous_action_types = None, # (na)
|
|
pred_head_index: int | Tensor | None = None,
|
|
parallel_discrete_calc = None,
|
|
return_entropies = False
|
|
):
|
|
parallel_discrete_calc = default(parallel_discrete_calc, exists(discrete_targets) and discrete_targets.shape[-1] > 1)
|
|
|
|
discrete_action_logits, continuous_action_mean_log_var = self.unembed(embeds, pred_head_index = pred_head_index, discrete_action_types = discrete_action_types, continuous_action_types = continuous_action_types, return_split_discrete = True)
|
|
|
|
# discrete
|
|
|
|
discrete_log_probs = None
|
|
discrete_entropies = None
|
|
|
|
if exists(discrete_targets):
|
|
|
|
if parallel_discrete_calc:
|
|
# use nested tensors
|
|
|
|
jagged_dims = tuple(t.shape[-1] for t in discrete_action_logits)
|
|
|
|
discrete_action_logits = cat(discrete_action_logits, dim = -1)
|
|
|
|
discrete_action_logits, inverse_pack_lead_dims = pack_one(discrete_action_logits, '* l')
|
|
batch = discrete_action_logits.shape[0]
|
|
|
|
discrete_action_logits = rearrange(discrete_action_logits, 'b l -> (b l)')
|
|
|
|
# to nested tensor
|
|
|
|
nested_logits = nested_tensor(discrete_action_logits.split(jagged_dims * batch), layout = torch.jagged, device = self.device, requires_grad = True)
|
|
|
|
prob = nested_logits.softmax(dim = -1)
|
|
|
|
log_probs = log(prob)
|
|
|
|
# maybe entropy
|
|
|
|
if return_entropies:
|
|
discrete_entropies = (-prob * log_probs).sum(dim = -1, keepdim = True)
|
|
discrete_entropies = cat(discrete_entropies.unbind())
|
|
discrete_entropies = rearrange(discrete_entropies, '(b na) -> b na', b = batch)
|
|
|
|
discrete_entropies = inverse_pack_lead_dims(discrete_entropies, '* na')
|
|
|
|
# back to regular tensor
|
|
|
|
log_probs = cat(log_probs.unbind())
|
|
log_probs = rearrange(log_probs, '(b l) -> b l', b = batch)
|
|
|
|
log_probs = inverse_pack_lead_dims(log_probs)
|
|
|
|
# get indices to gather
|
|
|
|
discrete_action_types = default(discrete_action_types, self.default_discrete_action_types)
|
|
|
|
num_discrete_actions = self.num_discrete_actions[discrete_action_types]
|
|
|
|
offset = F.pad(num_discrete_actions.cumsum(dim = -1), (1, -1), value = 0)
|
|
log_prob_indices = discrete_targets + offset
|
|
|
|
# gather
|
|
|
|
discrete_log_probs = log_probs.gather(-1, log_prob_indices)
|
|
|
|
else:
|
|
discrete_log_probs = []
|
|
discrete_entropies = []
|
|
|
|
for one_discrete_action_logit, one_discrete_target in zip(discrete_action_logits, discrete_targets.unbind(dim = -1)):
|
|
|
|
one_discrete_probs = one_discrete_action_logit.softmax(dim = -1)
|
|
one_discrete_log_probs = log(one_discrete_probs)
|
|
one_discrete_target = rearrange(one_discrete_target, '... -> ... 1')
|
|
|
|
log_prob = one_discrete_log_probs.gather(-1, one_discrete_target)
|
|
discrete_log_probs.append(log_prob)
|
|
|
|
if return_entropies:
|
|
entropy = (-one_discrete_probs * one_discrete_log_probs).sum(dim = -1)
|
|
discrete_entropies.append(entropy)
|
|
|
|
discrete_log_probs = cat(discrete_log_probs, dim = -1)
|
|
|
|
if return_entropies:
|
|
discrete_entropies = stack(discrete_entropies, dim = -1)
|
|
|
|
# continuous
|
|
|
|
continuous_log_probs = None
|
|
continuous_entropies = None
|
|
|
|
if exists(continuous_targets):
|
|
mean, log_var = continuous_action_mean_log_var.unbind(dim = -1)
|
|
std = (0.5 * log_var).exp()
|
|
|
|
distr = Normal(mean, std)
|
|
continuous_log_probs = distr.log_prob(continuous_targets)
|
|
|
|
if return_entropies:
|
|
continuous_entropies = distr.entropy()
|
|
|
|
log_probs = (discrete_log_probs, continuous_log_probs)
|
|
|
|
if not return_entropies:
|
|
return log_probs
|
|
|
|
entropies = (discrete_entropies, continuous_entropies)
|
|
|
|
return log_probs, entropies
|
|
|
|
def forward(
|
|
self,
|
|
*,
|
|
discrete_actions = None, # (... na)
|
|
continuous_actions = None, # (... na)
|
|
discrete_action_types = None, # (na)
|
|
continuous_action_types = None, # (na)
|
|
return_sum_pooled_embeds = True
|
|
):
|
|
|
|
discrete_embeds = continuous_embeds = None
|
|
|
|
if exists(discrete_actions):
|
|
|
|
discrete_action_types = default(discrete_action_types, self.default_discrete_action_types)
|
|
|
|
discrete_action_types = self.cast_action_types(discrete_action_types)
|
|
|
|
offsets = self.discrete_action_offsets[discrete_action_types]
|
|
|
|
assert offsets.shape[-1] == discrete_actions.shape[-1], 'mismatched number of discrete actions'
|
|
|
|
# offset the discrete actions based on the action types passed in (by default all discrete actions) and the calculated offset
|
|
|
|
discrete_actions_offsetted = add('... na, na', discrete_actions, offsets)
|
|
discrete_embeds = self.discrete_action_embed(discrete_actions_offsetted)
|
|
|
|
if exists(continuous_actions):
|
|
continuous_action_types = default(continuous_action_types, self.default_continuous_action_types)
|
|
|
|
continuous_action_types = self.cast_action_types(continuous_action_types)
|
|
|
|
assert continuous_action_types.shape[-1] == continuous_actions.shape[-1], 'mismatched number of continuous actions'
|
|
|
|
continuous_action_embed = self.continuous_action_embed(continuous_action_types)
|
|
|
|
# maybe normalization
|
|
|
|
if self.continuous_need_norm:
|
|
norm_mean, norm_std = self.continuous_norm_stats.unbind(dim = -1)
|
|
continuous_actions = (continuous_actions - norm_mean) / norm_std.clamp(min = 1e-6)
|
|
|
|
# continuous embed is just the selected continuous action type with the scale
|
|
|
|
continuous_embeds = multiply('na d, ... na -> ... na d', continuous_action_embed, continuous_actions)
|
|
|
|
# return not pooled
|
|
|
|
if not return_sum_pooled_embeds:
|
|
return ActionEmbeds(discrete_embeds, continuous_embeds)
|
|
|
|
# handle sum pooling, which is what they did in the paper for all the actions
|
|
|
|
pooled = 0.
|
|
|
|
if exists(discrete_embeds):
|
|
pooled = pooled + reduce(discrete_embeds, '... na d -> ... d', 'sum')
|
|
|
|
if exists(continuous_embeds):
|
|
pooled = pooled + reduce(continuous_embeds, '... na d -> ... d', 'sum')
|
|
|
|
return pooled
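# usage sketch (shapes and hyperparameters are illustrative) - embed a batch of mixed actions into one summed action token:
#
#   embedder = ActionEmbedder(dim = 512, num_discrete_actions = (4, 6), num_continuous_actions = 2)
#
#   action_token = embedder(
#       discrete_actions = randint(0, 4, (2, 10, 2)),
#       continuous_actions = randn(2, 10, 2)
#   )   # (2, 10, 512)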
|
|
|
|
# generalized advantage estimate
|
|
|
|
@torch.no_grad()
|
|
def calc_gae(
|
|
rewards,
|
|
values,
|
|
masks = None,
|
|
gamma = 0.99,
|
|
lam = 0.95,
|
|
use_accelerated = None
|
|
):
|
|
assert values.shape[-1] == rewards.shape[-1]
|
|
use_accelerated = default(use_accelerated, rewards.is_cuda)
|
|
|
|
if not exists(masks):
|
|
masks = torch.ones_like(values)
|
|
|
|
values = F.pad(values, (0, 1), value = 0.)
|
|
values, values_next = values[..., :-1], values[..., 1:]
|
|
|
|
delta = rewards + gamma * values_next * masks - values
|
|
gates = gamma * lam * masks
|
|
|
|
scan = AssocScan(reverse = True, use_accelerated = use_accelerated)
|
|
|
|
gae = scan(gates, delta)
|
|
|
|
returns = gae + values
|
|
|
|
return returns
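# note: what is returned are the lambda-returns (gae + values), computed with a reverse associative scan
# over the time dimension. a minimal usage sketch:
#
#   returns = calc_gae(rewards, values, masks)   # all (b, t)
#   advantages = returns - values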
|
|
|
|
# rotary embeddings for time
|
|
|
|
class Rotary1D(Module):
|
|
def __init__(
|
|
self,
|
|
dim_head,
|
|
theta = 10000.
|
|
):
|
|
super().__init__()
|
|
inv_freq = 1.0 / (theta ** (arange(0, dim_head, 2).float() / dim_head))
|
|
self.register_buffer('inv_freq', inv_freq)
|
|
|
|
def forward(
|
|
self,
|
|
seq_len,
|
|
offset = 0
|
|
):
|
|
device, dtype = self.inv_freq.device, self.inv_freq.dtype
|
|
|
|
t = torch.arange(seq_len, device = device).type(dtype) + offset
|
|
freqs = einsum(t, self.inv_freq, 'i, j -> i j')
|
|
|
|
return cat((freqs, freqs), dim = -1)
|
|
|
|
|
|
def apply_rotations(
|
|
rotations, # (h n d) | (n d)
|
|
t # (b h n d)
|
|
):
|
|
|
|
heads, seq_len, dtype = *t.shape[1:3], t.dtype
|
|
|
|
rotations_seq_len = rotations.shape[-2]
|
|
|
|
# handle kv caching with rotations
|
|
|
|
if rotations_seq_len > seq_len:
|
|
rotations = rotations[-seq_len:]
|
|
|
|
# precision
|
|
|
|
t = t.float()
|
|
|
|
# handle gqa for rotary
|
|
|
|
if rotations.ndim == 3 and rotations.shape[0] < heads:
|
|
rotary_heads = rotations.shape[0]
|
|
|
|
assert divisible_by(heads, rotary_heads)
|
|
groups = heads // rotary_heads
|
|
rotations = repeat(rotations, 'h ... -> (h g) ...', g = groups)
|
|
|
|
x1, x2 = t.chunk(2, dim = -1)
|
|
rotated_half_t = cat((-x2, x1), dim = -1)
|
|
|
|
# rotate in the positions
|
|
|
|
rotated = t * rotations.cos() + rotated_half_t * rotations.sin()
|
|
return rotated.type(dtype)
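# the above is the standard rotary identity, rotate(x) = x * cos(theta) + rotate_half(x) * sin(theta),
# with cos / sin broadcast over batch and heads, and the rotation carried out in float32 before casting back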
|
|
|
|
# multi-head rmsnorm
|
|
|
|
class MultiHeadRMSNorm(Module):
|
|
def __init__(
|
|
self,
|
|
dim_head,
|
|
heads = 8
|
|
):
|
|
super().__init__()
|
|
self.scale = dim_head ** 0.5
|
|
self.gamma = Parameter(torch.zeros(heads, dim_head)) # weight decay friendly
|
|
|
|
def forward(
|
|
self,
|
|
x # (b h n d)
|
|
):
|
|
normed = l2norm(x)
|
|
scale = (self.gamma + 1.) * self.scale
|
|
return multiply('... h n d, h d', normed, scale)
|
|
|
|
# naive attend
|
|
|
|
def naive_attend(
|
|
q, k, v,
|
|
softclamp_value = None,
|
|
scale = None,
|
|
causal = False,
|
|
causal_block_size = 1,
|
|
mask = None
|
|
):
|
|
|
|
if not exists(scale):
|
|
scale = q.shape[-1] ** -0.5
|
|
|
|
# grouped query attention
|
|
|
|
groups = q.shape[1] // k.shape[1]
|
|
|
|
q = rearrange(q, 'b (h g) ... -> b h g ...', g = groups)
|
|
|
|
# similarity
|
|
|
|
sim = einsum(q, k, 'b h g i d, b h j d -> b h g i j')
|
|
|
|
# scale and attention
|
|
|
|
sim = sim * scale
|
|
|
|
# softclamping a la gemma 3
|
|
|
|
if exists(softclamp_value):
|
|
sim = softclamp(sim, softclamp_value)
|
|
|
|
# masking
|
|
|
|
mask_value = -torch.finfo(sim.dtype).max
|
|
|
|
if exists(mask):
|
|
sim = sim.masked_fill(~mask, mask_value)
|
|
|
|
if causal:
|
|
is_blocked_causal = causal_block_size > 1
|
|
i, j = sim.shape[-2:]
|
|
|
|
if is_blocked_causal:
|
|
i = ceil(i / causal_block_size)
|
|
j = ceil(j / causal_block_size)
|
|
|
|
causal_mask = torch.ones((i, j), dtype = torch.bool, device = sim.device).triu(j - i + 1)
|
|
|
|
if causal_block_size > 1:
|
|
causal_mask = repeat(causal_mask, 'i j -> (i b1) (j b2)', b1 = causal_block_size, b2 = causal_block_size)
|
|
causal_mask = causal_mask[:sim.shape[-2], :sim.shape[-1]]
|
|
|
|
sim = sim.masked_fill(causal_mask, mask_value)
|
|
|
|
# attend
|
|
|
|
attn = sim.softmax(dim = -1)
|
|
|
|
# aggregate
|
|
|
|
out = einsum(attn, v, 'b h g i j, b h j d -> b h g i d')
|
|
|
|
# merge the groups
|
|
|
|
return rearrange(out, 'b h g i d -> b (h g) i d')
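# note on the grouped query attention above: query heads are split into g groups per key / value head
# (e.g. 8 query heads with 2 kv heads gives g = 4), and the groups are merged back after aggregation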
|
|
|
|
# flex attention related and factory function for attend depending on whether on cuda + flex attention available
|
|
|
|
def block_mask_causal(block_size):
|
|
|
|
def inner(b, h, q, k):
|
|
bq = q // block_size
|
|
bk = k // block_size
|
|
return bq >= bk
|
|
|
|
return inner
|
|
|
|
def special_token_mask(q, k, seq_len, num_tokens, special_attend_only_itself = False):
|
|
bq = q % seq_len
|
|
bk = k % seq_len
|
|
|
|
is_special_start_index = seq_len - num_tokens
|
|
|
|
    q_is_special = bq >= is_special_start_index
|
|
    k_is_special = bk >= is_special_start_index
|
|
|
|
if special_attend_only_itself:
|
|
out = ~(q_is_special & ~k_is_special) # modality attends to everything, but latent can only attend to itself (proposed attention pattern for encoder of video tokenizer)
|
|
else:
|
|
out = ~(~q_is_special & k_is_special) # modality cannot attend to agent tokens
|
|
|
|
return out
|
|
|
|
def block_mask_special_tokens_right(
|
|
seq_len,
|
|
    num_tokens,
    special_attend_only_itself = False
|
|
):
|
|
def inner(b, h, q, k):
|
|
        return special_token_mask(q, k, seq_len, num_tokens, special_attend_only_itself)
|
|
return inner
|
|
|
|
def compose_mask(mask1, mask2):
|
|
def inner(b, h, q, k):
|
|
return mask1(b, h, q, k) & mask2(b, h, q, k)
|
|
|
|
return inner
|
|
|
|
def block_mask_noop(b, h, q, k):
|
|
return b >= 0
|
|
|
|
def score_mod_softclamp(value):
|
|
def inner(sim, b, h, q, k):
|
|
if not exists(value):
|
|
return sim
|
|
|
|
sim = sim / value
|
|
sim = torch.tanh(sim)
|
|
sim = sim * value
|
|
return sim
|
|
|
|
return inner
|
|
|
|
# factory for attend function
|
|
|
|
def get_attend_fn(
|
|
use_flex,
|
|
seq_len,
|
|
k_seq_len,
|
|
causal = False,
|
|
causal_block_size = 1,
|
|
softclamp_value = 50.,
|
|
num_special_tokens = 0, # special tokens are latents / agents
|
|
block_size_per_special = None, # defaults to k_seq_len
|
|
    special_attend_only_itself = False, # by default, modality tokens cannot attend to the special tokens while special tokens see everything, but if turned True, it is the inverse - special tokens can only attend to themselves while modality can attend to everything
|
|
device = None
|
|
):
|
|
block_size_per_special = default(block_size_per_special, k_seq_len)
|
|
|
|
if use_flex:
|
|
# flex pathway
|
|
|
|
block_mask_fn = block_mask_causal(causal_block_size) if causal else block_mask_noop
|
|
|
|
if num_special_tokens > 0:
|
|
special_block_mask = block_mask_special_tokens_right(block_size_per_special, num_special_tokens, special_attend_only_itself)
|
|
block_mask_fn = compose_mask(block_mask_fn, special_block_mask)
|
|
|
|
block_mask = create_block_mask(block_mask_fn, B = None, H = None, Q_LEN = seq_len, KV_LEN = k_seq_len)
|
|
|
|
score_mod = score_mod_softclamp(softclamp_value)
|
|
attend_fn = partial(flex_attention, block_mask = block_mask, score_mod = score_mod, enable_gqa = True)
|
|
else:
|
|
# naive pathway
|
|
|
|
mask = None
|
|
if num_special_tokens > 0:
|
|
q_seq = torch.arange(seq_len, device = device)[:, None]
|
|
k_seq = torch.arange(k_seq_len, device = device)[None, :]
|
|
|
|
mask = special_token_mask(q_seq, k_seq, block_size_per_special, num_special_tokens, special_attend_only_itself)
|
|
|
|
attend_fn = partial(naive_attend, causal = causal, causal_block_size = causal_block_size, mask = mask, softclamp_value = softclamp_value)
|
|
|
|
return attend_fn
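# usage sketch (names are illustrative) - the same factory covers both axial directions, mirroring how it is
# called in AxialSpaceTimeTransformer below, and the returned function is applied to (b h n d) queries / keys / values:
#
#   space_attend = get_attend_fn(use_flex, seq_len = s, k_seq_len = s, causal = False, num_special_tokens = 1, device = device)
#   time_attend  = get_attend_fn(use_flex, seq_len = t, k_seq_len = t, causal = True, device = device)
#
#   out = space_attend(q, k, v)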
|
|
|
|
# attention
|
|
|
|
class Attention(Module):
|
|
def __init__(
|
|
self,
|
|
dim,
|
|
dim_head = 64,
|
|
query_heads = None,
|
|
heads = 8,
|
|
pre_rmsnorm = True,
|
|
gate_values = True,
|
|
rmsnorm_query = False, # a paper claims that it is better to just norm only the keys https://openreview.net/forum?id=HkztQWZfl2
|
|
rmsnorm_key = True
|
|
):
|
|
super().__init__()
|
|
self.norm = RMSNorm(dim) if pre_rmsnorm else Identity()
|
|
|
|
# setup grouped query attention
|
|
|
|
query_heads = default(query_heads, heads)
|
|
assert query_heads >= heads and divisible_by(query_heads, heads)
|
|
|
|
# scaling, splitting and merging of heads
|
|
|
|
self.split_heads = Rearrange('b n (h d) -> b h n d', d = dim_head)
|
|
self.merge_heads = Rearrange('b h n d -> b n (h d)')
|
|
|
|
dim_q_inner = dim_head * query_heads
|
|
dim_kv_inner = dim_head * heads
|
|
|
|
self.to_q = LinearNoBias(dim, dim_q_inner)
|
|
self.to_k = LinearNoBias(dim, dim_kv_inner)
|
|
self.to_v = LinearNoBias(dim, dim_kv_inner)
|
|
self.to_out = LinearNoBias(dim_q_inner, dim)
|
|
|
|
# alphafold gating per head, for attending to nothing
|
|
|
|
self.to_gates = None
|
|
|
|
if gate_values:
|
|
self.to_gates = Sequential(
|
|
LinearNoBias(dim, query_heads),
|
|
Rearrange('b n h -> b h n 1'),
|
|
nn.Sigmoid()
|
|
)
|
|
|
|
# stability related
|
|
|
|
self.q_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = query_heads) if rmsnorm_query else nn.Identity()
|
|
self.k_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = heads) if rmsnorm_key else nn.Identity()
|
|
|
|
def muon_parameters(self):
|
|
# omit the queries and keys for now given what we learned from kimi 2 paper
|
|
|
|
return [
|
|
*self.to_v.parameters(),
|
|
*self.to_out.parameters(),
|
|
]
|
|
|
|
def forward(
|
|
self,
|
|
tokens, # (b n d)
|
|
kv_cache = None,
|
|
return_kv_cache = False,
|
|
rotary_pos_emb = None,
|
|
attend_fn: Callable | None = None
|
|
):
|
|
tokens, inverse_packed_batch = pack_one(tokens, '* n d')
|
|
|
|
tokens = self.norm(tokens)
|
|
|
|
q, k, v = (self.to_q(tokens), self.to_k(tokens), self.to_v(tokens))
|
|
|
|
# split heads
|
|
|
|
q, k, v = map(self.split_heads, (q, k, v))
|
|
|
|
# qk rmsnorm
|
|
|
|
q = self.q_heads_rmsnorm(q)
|
|
k = self.k_heads_rmsnorm(k)
|
|
|
|
# caching
|
|
|
|
if exists(kv_cache):
|
|
ck, cv = kv_cache
|
|
k = cat((ck, k), dim = -2)
|
|
v = cat((cv, v), dim = -2)
|
|
|
|
# rotary
|
|
|
|
if exists(rotary_pos_emb):
|
|
q = apply_rotations(rotary_pos_emb, q)
|
|
k = apply_rotations(rotary_pos_emb, k)
|
|
|
|
# attention
|
|
|
|
attend_fn = default(attend_fn, naive_attend)
|
|
|
|
out = attend_fn(q, k, v)
|
|
|
|
# gate values
|
|
|
|
if exists(self.to_gates):
|
|
gates = self.to_gates(tokens)
|
|
out = out * gates
|
|
|
|
# merge heads
|
|
|
|
out = self.merge_heads(out)
|
|
|
|
# combine heads
|
|
|
|
out = self.to_out(out)
|
|
|
|
out = inverse_packed_batch(out)
|
|
|
|
if not return_kv_cache:
|
|
return out
|
|
|
|
return out, stack((k, v))
|
|
|
|
# feedforward
|
|
|
|
class SwiGLUFeedforward(Module):
|
|
def __init__(
|
|
self,
|
|
dim,
|
|
expansion_factor = 4,
|
|
pre_rmsnorm = True
|
|
):
|
|
super().__init__()
|
|
self.norm = RMSNorm(dim) if pre_rmsnorm else Identity()
|
|
|
|
dim_inner = int(dim * expansion_factor * 2 / 3)
|
|
|
|
self.proj_in = Linear(dim, dim_inner * 2)
|
|
self.proj_out = Linear(dim_inner, dim)
|
|
|
|
def muon_parameters(self):
|
|
return [
|
|
self.proj_in.weight,
|
|
self.proj_out.weight,
|
|
]
|
|
|
|
def forward(self, x):
|
|
x = self.norm(x)
|
|
|
|
x, gates = self.proj_in(x).chunk(2, dim = -1)
|
|
x = x * F.gelu(gates)
|
|
|
|
return self.proj_out(x)
|
|
|
|
# axial space time transformer
|
|
|
|
class AxialSpaceTimeTransformer(Module):
|
|
def __init__(
|
|
self,
|
|
dim,
|
|
depth,
|
|
attn_dim_head = 64,
|
|
attn_softclamp_value = 50.,
|
|
time_block_every = 4,
|
|
attn_kwargs: dict = dict(),
|
|
ff_kwargs: dict = dict(),
|
|
num_residual_streams = 1,
|
|
num_special_spatial_tokens = 1,
|
|
special_attend_only_itself = False, # this is set to True for the video tokenizer decoder (latents can only attend to itself while spatial modalities attend to the latents and everything)
|
|
final_norm = True
|
|
):
|
|
super().__init__()
|
|
assert depth >= time_block_every, f'depth must be at least {time_block_every}'
|
|
|
|
# hyper connections
|
|
|
|
hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, dim = dim)
|
|
|
|
# attention
|
|
|
|
self.attn_softclamp_value = attn_softclamp_value
|
|
|
|
# attention masking
|
|
|
|
self.special_attend_only_itself = special_attend_only_itself
|
|
|
|
# time rotary embedding
|
|
|
|
self.time_rotary = Rotary1D(attn_dim_head)
|
|
|
|
# transformer
|
|
|
|
layers = []
|
|
is_time = []
|
|
|
|
for i in range(depth):
|
|
layer_index = i + 1
|
|
|
|
is_time_block = divisible_by(layer_index, time_block_every)
|
|
is_time.append(is_time_block)
|
|
|
|
rearrange_to_attend = Rearrange('b t s d -> b s t d') if is_time_block else Identity()
|
|
rearrange_from_attend = Rearrange('b s t d -> b t s d') if is_time_block else Identity()
|
|
|
|
layers.append(ModuleList([
|
|
rearrange_to_attend,
|
|
rearrange_from_attend,
|
|
hyper_conn(branch = Attention(dim = dim, dim_head = attn_dim_head, **attn_kwargs)),
|
|
hyper_conn(branch = SwiGLUFeedforward(dim = dim, **ff_kwargs))
|
|
]))
|
|
|
|
self.layers = ModuleList(layers)
|
|
self.is_time = is_time
|
|
|
|
# final norm
|
|
|
|
self.final_norm = nn.RMSNorm(dim) if final_norm else nn.Identity()
|
|
|
|
# special tokens
|
|
|
|
self.num_special_spatial_tokens = num_special_spatial_tokens
|
|
|
|
def muon_parameters(self):
|
|
muon_params = []
|
|
|
|
for m in self.modules():
|
|
if isinstance(m, (Attention, SwiGLUFeedforward)):
|
|
muon_params.extend(m.muon_parameters())
|
|
|
|
return muon_params
|
|
|
|
def forward(
|
|
self,
|
|
tokens, # (b t s d)
|
|
kv_cache: Tensor | None = None, # (y 2 b h t d)
|
|
return_kv_cache = False
|
|
|
|
): # (b t s d) | (y 2 b h t d)
|
|
|
|
batch, time, space_seq_len, _, device = *tokens.shape, tokens.device
|
|
|
|
assert tokens.ndim == 4
|
|
|
|
# attend functions for space and time
|
|
|
|
use_flex = exists(flex_attention) and tokens.is_cuda
|
|
|
|
attend_kwargs = dict(use_flex = use_flex, softclamp_value = self.attn_softclamp_value, special_attend_only_itself = self.special_attend_only_itself, device = device)
|
|
|
|
space_attend = get_attend_fn(causal = False, seq_len = space_seq_len, k_seq_len = space_seq_len, num_special_tokens = self.num_special_spatial_tokens, **attend_kwargs) # space has an agent token on the right-hand side for reinforcement learning - cannot be attended to by modality
|
|
|
|
time_attend = get_attend_fn(causal = True, seq_len = time, k_seq_len = time, **attend_kwargs)
|
|
|
|
# prepare cache
|
|
|
|
time_attn_kv_caches = []
|
|
|
|
has_kv_cache = exists(kv_cache)
|
|
|
|
|
|
if has_kv_cache:
|
|
past_tokens, tokens = tokens[:, :-1], tokens[:, -1:]
|
|
|
|
rotary_seq_len = 1
|
|
rotary_pos_offset = past_tokens.shape[-2]
|
|
else:
|
|
rotary_seq_len = time
|
|
rotary_pos_offset = 0
|
|
|
|
kv_cache = default(kv_cache, (None,))
|
|
|
|
iter_kv_cache = iter(kv_cache)
|
|
|
|
# rotary
|
|
|
|
rotary_pos_emb = self.time_rotary(rotary_seq_len, offset = rotary_pos_offset)
|
|
|
|
# attention
|
|
|
|
tokens = self.expand_streams(tokens)
|
|
|
|
for (pre_attn_rearrange, post_attn_rearrange, attn, ff), layer_is_time in zip(self.layers, self.is_time):
|
|
|
|
tokens = pre_attn_rearrange(tokens)
|
|
|
|
            # when it is an axial time attention block, attention should be causal
|
|
|
|
attend_fn = time_attend if layer_is_time else space_attend
|
|
|
|
layer_rotary_pos_emb = rotary_pos_emb if layer_is_time else None
|
|
|
|
# maybe past kv cache
|
|
|
|
maybe_kv_cache = next(iter_kv_cache, None) if layer_is_time else None
|
|
|
|
# attention layer
|
|
|
|
tokens, next_kv_cache = attn(
|
|
tokens,
|
|
rotary_pos_emb = layer_rotary_pos_emb,
|
|
attend_fn = attend_fn,
|
|
kv_cache = maybe_kv_cache,
|
|
return_kv_cache = True
|
|
)
|
|
|
|
tokens = post_attn_rearrange(tokens)
|
|
|
|
# feedforward layer
|
|
|
|
tokens = ff(tokens)
|
|
|
|
# save kv cache if is time layer
|
|
|
|
if layer_is_time:
|
|
time_attn_kv_caches.append(next_kv_cache)
|
|
|
|
tokens = self.reduce_streams(tokens)
|
|
|
|
out = self.final_norm(tokens)
|
|
|
|
if has_kv_cache:
|
|
# just concat the past tokens back on for now, todo - clean up the logic
|
|
out = cat((past_tokens, out), dim = 1)
|
|
|
|
if not return_kv_cache:
|
|
return out
|
|
|
|
return out, stack(time_attn_kv_caches)
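# minimal sketch (hyperparameters and shapes are illustrative) - tokens come in as (batch, time, space, dim)
# and the transformer alternates spatial attention with causal temporal attention every `time_block_every` layers:
#
#   transformer = AxialSpaceTimeTransformer(dim = 512, depth = 8, num_special_spatial_tokens = 1)
#
#   out = transformer(randn(2, 16, 64 + 1, 512))   # (2, 16, 65, 512), last spatial token treated as the special / agent token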
|
|
|
|
# video tokenizer
|
|
|
|
class VideoTokenizer(Module):
|
|
def __init__(
|
|
self,
|
|
dim,
|
|
dim_latent,
|
|
patch_size,
|
|
image_height = None,
|
|
image_width = None,
|
|
num_latent_tokens = 4,
|
|
encoder_depth = 4,
|
|
decoder_depth = 4,
|
|
time_block_every = 4,
|
|
attn_kwargs: dict = dict(),
|
|
attn_dim_head = 64,
|
|
attn_heads = 8,
|
|
attn_softclamp_value = 50.,
|
|
ff_kwargs: dict = dict(),
|
|
decoder_pos_mlp_depth = 2,
|
|
channels = 3,
|
|
        per_image_patch_mask_prob = (0., 0.9), # the per-image patch masking probability appears to be drawn uniformly between 0. and 0.9 - if you are a phd student and think i'm mistaken, please open an issue
|
|
lpips_loss_network: Module | None = None,
|
|
lpips_loss_weight = 0.2,
|
|
nd_rotary_kwargs: dict = dict(
|
|
rope_min_freq = 1.,
|
|
rope_max_freq = 10000.,
|
|
rope_p_zero_freqs = 0.
|
|
),
|
|
num_residual_streams = 1
|
|
):
|
|
super().__init__()
|
|
|
|
self.patch_size = patch_size
|
|
|
|
# special tokens
|
|
|
|
assert num_latent_tokens >= 1
|
|
self.num_latent_tokens = num_latent_tokens
|
|
self.latent_tokens = Parameter(randn(num_latent_tokens, dim) * 1e-2)
|
|
|
|
# hyper connections
|
|
|
|
hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, dim = dim)
|
|
|
|
# mae masking - Kaiming He paper from long ago
|
|
|
|
self.per_image_patch_mask_prob = per_image_patch_mask_prob
|
|
self.mask_token = Parameter(randn(dim) * 1e-2)
|
|
|
|
# patch and unpatch
|
|
|
|
dim_patch = channels * patch_size ** 2
|
|
|
|
self.patch_to_tokens = Sequential(
|
|
Rearrange('b c t (h p1) (w p2) -> b t h w (p1 p2 c)', p1 = patch_size, p2 = patch_size),
|
|
Linear(dim_patch, dim)
|
|
)
|
|
|
|
self.tokens_to_patch = Sequential(
|
|
Linear(dim, dim_patch),
|
|
Rearrange('b t h w (p1 p2 c) -> b c t (h p1) (w p2)', p1 = patch_size, p2 = patch_size),
|
|
)
|
|
|
|
# encoder space / time transformer
|
|
|
|
self.encoder_transformer = AxialSpaceTimeTransformer(
|
|
dim = dim,
|
|
depth = encoder_depth,
|
|
attn_dim_head = attn_dim_head,
|
|
attn_softclamp_value = attn_softclamp_value,
|
|
time_block_every = time_block_every,
|
|
num_special_spatial_tokens = num_latent_tokens,
|
|
num_residual_streams = num_residual_streams,
|
|
final_norm = True
|
|
)
|
|
|
|
# latents
|
|
|
|
self.encoded_to_latents = Sequential(
|
|
LinearNoBias(dim, dim_latent),
|
|
nn.Tanh(),
|
|
)
|
|
|
|
self.latents_to_decoder = LinearNoBias(dim_latent, dim)
|
|
|
|
# decoder
|
|
|
|
self.image_height = image_height
|
|
self.image_width = image_width
|
|
|
|
# parameterize the decoder positional embeddings for MAE style training so it can be resolution agnostic
|
|
|
|
self.to_decoder_pos_emb = create_mlp(
|
|
dim_in = 2,
|
|
dim = dim * 2,
|
|
dim_out = dim,
|
|
depth = decoder_pos_mlp_depth,
|
|
)
|
|
|
|
# decoder transformer
|
|
|
|
self.decoder_transformer = AxialSpaceTimeTransformer(
|
|
dim = dim,
|
|
depth = decoder_depth,
|
|
attn_dim_head = attn_dim_head,
|
|
attn_softclamp_value = attn_softclamp_value,
|
|
time_block_every = time_block_every,
|
|
num_special_spatial_tokens = num_latent_tokens,
|
|
num_residual_streams = num_residual_streams,
|
|
final_norm = True
|
|
)
|
|
|
|
# loss related
|
|
|
|
self.register_buffer('zero', tensor(0.), persistent = False)
|
|
|
|
self.has_lpips_loss = lpips_loss_weight > 0.
|
|
self.lpips_loss_weight = lpips_loss_weight
|
|
|
|
if self.has_lpips_loss:
|
|
self.lpips = LPIPSLoss(lpips_loss_network)
|
|
|
|
@property
|
|
def device(self):
|
|
return self.zero.device
|
|
|
|
def muon_parameters(self):
|
|
return [
|
|
*self.encoder_transformer.muon_parameters(),
|
|
*self.decoder_transformer.muon_parameters()
|
|
]
|
|
|
|
@torch.no_grad()
|
|
def tokenize(
|
|
self,
|
|
video
|
|
):
|
|
self.eval()
|
|
return self.forward(video, return_latents = True)
|
|
|
|
def decode(
|
|
self,
|
|
latents, # (b t n d)
|
|
height = None,
|
|
width = None,
|
|
): # (b c t h w)
|
|
|
|
height = default(height, self.image_height)
|
|
width = default(width, self.image_width)
|
|
|
|
assert exists(height) and exists(width), f'image height and width need to be passed in when decoding latents'
|
|
|
|
batch, time, device = *latents.shape[:2], latents.device
|
|
|
|
use_flex = latents.is_cuda and exists(flex_attention)
|
|
|
|
num_patch_height = height // self.patch_size
|
|
num_patch_width = width // self.patch_size
|
|
|
|
# latents to tokens
|
|
|
|
latent_tokens = self.latents_to_decoder(latents)
|
|
|
|
# generate decoder positional embedding and concat the latent token
|
|
|
|
spatial_pos_height = torch.linspace(-1., 1., num_patch_height, device = device)
|
|
spatial_pos_width = torch.linspace(-1., 1., num_patch_width, device = device)
|
|
|
|
space_height_width_coor = stack(torch.meshgrid(spatial_pos_height, spatial_pos_width, indexing = 'ij'), dim = -1)
|
|
|
|
decoder_pos_emb = self.to_decoder_pos_emb(space_height_width_coor)
|
|
decoder_pos_emb = repeat(decoder_pos_emb, '... -> b t ...', b = batch, t = time)
|
|
|
|
tokens, packed_latent_shape = pack((decoder_pos_emb, latent_tokens), 'b t * d')
|
|
|
|
# decoder attention
|
|
|
|
tokens = self.decoder_transformer(tokens)
|
|
|
|
# unpack latents
|
|
|
|
tokens, latent_tokens = unpack(tokens, packed_latent_shape, 'b t * d')
|
|
|
|
# project back to patches
|
|
|
|
recon_video = self.tokens_to_patch(tokens)
|
|
|
|
return recon_video
|
|
|
|
def forward(
|
|
self,
|
|
video, # (b c t h w)
|
|
return_latents = False,
|
|
mask_patches = None,
|
|
return_all_losses = False
|
|
):
|
|
batch, _, time, height, width = video.shape
|
|
patch_size, device = self.patch_size, video.device
|
|
|
|
assert divisible_by(height, patch_size) and divisible_by(width, patch_size)
|
|
|
|
# to tokens
|
|
|
|
tokens = self.patch_to_tokens(video)
|
|
|
|
# get some dimensions
|
|
|
|
num_patch_height, num_patch_width, _ = tokens.shape[-3:]
|
|
|
|
# masking
|
|
|
|
mask_patches = default(mask_patches, self.training)
|
|
|
|
if mask_patches:
|
|
min_mask_prob, max_mask_prob = self.per_image_patch_mask_prob
|
|
|
|
mask_prob = torch.empty(tokens.shape[:2], device = tokens.device).uniform_(min_mask_prob, max_mask_prob) # (b t)
|
|
|
|
mask_prob = repeat(mask_prob, 'b t -> b t vh vw', vh = tokens.shape[2], vw = tokens.shape[3])
|
|
mask_patch = torch.bernoulli(mask_prob) == 1.
|
|
|
|
tokens = einx.where('..., d, ... d', mask_patch, self.mask_token, tokens)
|
|
|
|
# pack space
|
|
|
|
tokens, inverse_pack_space = pack_one(tokens, 'b t * d')
|
|
|
|
# add the latent
|
|
|
|
latents = repeat(self.latent_tokens, 'n d -> b t n d', b = tokens.shape[0], t = tokens.shape[1])
|
|
|
|
tokens, packed_latent_shape = pack((tokens, latents), 'b t * d')
|
|
|
|
# encoder attention
|
|
|
|
tokens = self.encoder_transformer(tokens)
|
|
|
|
# latent bottleneck
|
|
|
|
tokens, latents = unpack(tokens, packed_latent_shape, 'b t * d')
|
|
|
|
latents = self.encoded_to_latents(latents)
|
|
|
|
if return_latents:
|
|
return latents
|
|
|
|
recon_video = self.decode(latents, height = height, width = width)
|
|
|
|
# losses
|
|
|
|
recon_loss = F.mse_loss(video, recon_video)
|
|
|
|
lpips_loss = self.zero
|
|
|
|
if self.has_lpips_loss:
|
|
lpips_loss = self.lpips(video, recon_video)
|
|
|
|
# losses
|
|
|
|
total_loss = (
|
|
recon_loss +
|
|
lpips_loss * self.lpips_loss_weight
|
|
)
|
|
|
|
if not return_all_losses:
|
|
return total_loss
|
|
|
|
losses = (recon_loss, lpips_loss)
|
|
|
|
        return total_loss, TokenizerLosses(*losses)
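# usage sketch (hyperparameters are illustrative; lpips weight is zeroed here only to avoid downloading vgg16 for the sketch):
#
#   tokenizer = VideoTokenizer(dim = 512, dim_latent = 32, patch_size = 16, lpips_loss_weight = 0.)
#
#   video = randn(1, 3, 4, 256, 256)        # (b c t h w)
#
#   loss = tokenizer(video)                 # reconstruction loss for training
#   latents = tokenizer.tokenize(video)     # (1, 4, 4, 32) latents for the dynamics model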
|
|
|
|
# dynamics model, axial space-time transformer
|
|
|
|
class DynamicsWorldModel(Module):
|
|
def __init__(
|
|
self,
|
|
dim,
|
|
dim_latent,
|
|
video_tokenizer: VideoTokenizer | None = None,
|
|
max_steps = 64, # K_max in paper
|
|
num_register_tokens = 8, # they claim register tokens led to better temporal consistency
|
|
num_spatial_tokens = 2, # latents projected to greater number of spatial tokens
|
|
num_latent_tokens = None,
|
|
num_agents = 1,
|
|
num_tasks = 0,
|
|
num_video_views = 1,
|
|
dim_proprio = None,
|
|
reward_encoder_kwargs: dict = dict(),
|
|
depth = 4,
|
|
        pred_orig_latent = True, # directly predicting the original x0 data yields better results, rather than velocity (x-space vs v-space)
|
|
time_block_every = 4, # every 4th block is time
|
|
attn_kwargs: dict = dict(
|
|
heads = 8,
|
|
),
|
|
transformer_kwargs: dict = dict(),
|
|
attn_dim_head = 64,
|
|
attn_softclamp_value = 50.,
|
|
ff_kwargs: dict = dict(),
|
|
loss_weight_fn: Callable = ramp_weight,
|
|
num_future_predictions = 8, # they do multi-token prediction of 8 steps forward
|
|
prob_no_shortcut_train = None, # probability of no shortcut training, defaults to 1 / num_step_sizes
|
|
add_reward_embed_to_agent_token = False,
|
|
add_reward_embed_dropout = 0.1,
|
|
num_discrete_actions: int | tuple[int, ...] = 0,
|
|
num_continuous_actions = 0,
|
|
continuous_norm_stats = None,
|
|
multi_token_pred_len = 8,
|
|
value_head_mlp_depth = 3,
|
|
policy_head_mlp_depth = 3,
|
|
latent_flow_loss_weight = 1.,
|
|
reward_loss_weight: float | list[float] = 1.,
|
|
discrete_action_loss_weight: float | list[float] = 1.,
|
|
continuous_action_loss_weight: float | list[float] = 1.,
|
|
num_latent_genes = 0, # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037
|
|
num_residual_streams = 1,
|
|
keep_reward_ema_stats = False,
|
|
reward_ema_decay = 0.998,
|
|
reward_quantile_filter = (0.05, 0.95),
|
|
gae_discount_factor = 0.997,
|
|
gae_lambda = 0.95,
|
|
ppo_eps_clip = 0.2,
|
|
value_clip = 0.4,
|
|
policy_entropy_weight = .01,
|
|
gae_use_accelerated = False
|
|
):
|
|
super().__init__()
|
|
|
|
# can accept raw video if tokenizer is passed in
|
|
|
|
self.video_tokenizer = video_tokenizer
|
|
|
|
if exists(video_tokenizer):
|
|
num_latent_tokens = default(num_latent_tokens, video_tokenizer.num_latent_tokens)
|
|
assert video_tokenizer.num_latent_tokens == num_latent_tokens, f'`num_latent_tokens` must be the same for the tokenizer and dynamics model'
|
|
|
|
assert exists(num_latent_tokens), '`num_latent_tokens` must be set'
|
|
|
|
# spatial
|
|
|
|
self.num_latent_tokens = num_latent_tokens
|
|
self.dim_latent = dim_latent
|
|
self.latent_shape = (num_latent_tokens, dim_latent)
|
|
|
|
if num_spatial_tokens >= num_latent_tokens:
|
|
assert divisible_by(num_spatial_tokens, num_latent_tokens)
|
|
|
|
expand_factor = num_spatial_tokens // num_latent_tokens
|
|
|
|
self.latents_to_spatial_tokens = Sequential(
|
|
Linear(dim_latent, dim * expand_factor),
|
|
Rearrange('... (s d) -> ... s d', s = expand_factor)
|
|
)
|
|
|
|
self.to_latent_pred = Sequential(
|
|
Reduce('b t v n s d -> b t v n d', 'mean'),
|
|
RMSNorm(dim),
|
|
LinearNoBias(dim, dim_latent)
|
|
)
|
|
|
|
else:
|
|
assert divisible_by(num_latent_tokens, num_spatial_tokens)
|
|
latent_tokens_to_space = num_latent_tokens // num_spatial_tokens
|
|
|
|
self.latents_to_spatial_tokens = Sequential(
|
|
Rearrange('... n d -> ... (n d)'),
|
|
Linear(num_latent_tokens * dim_latent, dim * num_spatial_tokens),
|
|
Rearrange('... (s d) -> ... s d', s = num_spatial_tokens)
|
|
)
|
|
|
|
self.to_latent_pred = Sequential(
|
|
RMSNorm(dim),
|
|
LinearNoBias(dim, dim_latent * latent_tokens_to_space),
|
|
Rearrange('b t v s (n d) -> b t v (s n) d', n = latent_tokens_to_space)
|
|
)
|
|
|
|
# number of video views, for robotics, which could have third person + wrist camera at least
|
|
|
|
assert num_video_views >= 1
|
|
self.video_has_multi_view = num_video_views > 1
|
|
|
|
self.num_video_views = num_video_views
|
|
|
|
if self.video_has_multi_view:
|
|
self.view_emb = nn.Parameter(torch.randn(num_video_views, dim) * 1e-2)
|
|
|
|
# proprioception
|
|
|
|
self.has_proprio = exists(dim_proprio)
|
|
self.dim_proprio = dim_proprio
|
|
|
|
if self.has_proprio:
|
|
self.to_proprio_token = nn.Linear(dim_proprio, dim)
|
|
|
|
self.to_proprio_pred = Sequential(
|
|
RMSNorm(dim),
|
|
nn.Linear(dim, dim_proprio)
|
|
)
|
|
|
|
# register tokens
|
|
|
|
self.num_register_tokens = num_register_tokens
|
|
self.register_tokens = Parameter(torch.randn(num_register_tokens, dim) * 1e-2)
|
|
|
|
# signal and step sizes
|
|
|
|
assert divisible_by(dim, 2)
|
|
dim_half = dim // 2
|
|
|
|
assert is_power_two(max_steps), '`max_steps` must be a power of 2'
|
|
self.max_steps = max_steps
|
|
self.num_step_sizes_log2 = int(log2(max_steps))
|
|
|
|
self.signal_levels_embed = nn.Embedding(max_steps, dim_half)
|
|
self.step_size_embed = nn.Embedding(self.num_step_sizes_log2, dim_half) # power of 2, so 1/1, 1/2, 1/4, 1/8 ... 1/Kmax
|
|
|
|
self.prob_no_shortcut_train = default(prob_no_shortcut_train, self.num_step_sizes_log2 ** -1.)
|
|
|
|
# loss related
|
|
|
|
self.pred_orig_latent = pred_orig_latent # x-space or v-space
|
|
self.loss_weight_fn = loss_weight_fn
|
|
|
|
# reinforcement related
|
|
|
|
# they sum all the actions into a single token
|
|
|
|
self.num_agents = num_agents
|
|
|
|
self.agent_learned_embed = Parameter(randn(self.num_agents, dim) * 1e-2)
|
|
self.action_learned_embed = Parameter(randn(self.num_agents, dim) * 1e-2)
|
|
|
|
self.reward_learned_embed = Parameter(randn(self.num_agents, dim) * 1e-2)
|
|
|
|
self.num_tasks = num_tasks
|
|
self.task_embed = nn.Embedding(num_tasks, dim)
|
|
|
|
# learned set of latent genes
|
|
|
|
self.agent_has_genes = num_latent_genes > 0
|
|
self.latent_genes = Parameter(randn(num_latent_genes, dim) * 1e-2)
|
|
|
|
# policy head
|
|
|
|
self.policy_head = create_mlp(
|
|
dim_in = dim,
|
|
dim = dim * 4,
|
|
dim_out = dim * 4,
|
|
depth = policy_head_mlp_depth
|
|
)
|
|
|
|
# action embedder
|
|
|
|
self.action_embedder = ActionEmbedder(
|
|
dim = dim,
|
|
num_discrete_actions = num_discrete_actions,
|
|
num_continuous_actions = num_continuous_actions,
|
|
continuous_norm_stats = continuous_norm_stats,
|
|
can_unembed = True,
|
|
unembed_dim = dim * 4,
|
|
num_unembed_preds = multi_token_pred_len,
|
|
squeeze_unembed_preds = False
|
|
)
|
|
|
|
# multi token prediction length
|
|
|
|
self.multi_token_pred_len = multi_token_pred_len
|
|
|
|
# each agent token will have the reward embedding of the previous time step - but could eventually just give reward its own token
|
|
|
|
self.add_reward_embed_to_agent_token = add_reward_embed_to_agent_token
|
|
self.add_reward_embed_dropout = add_reward_embed_dropout
|
|
|
|
self.reward_encoder = SymExpTwoHot(
|
|
**reward_encoder_kwargs,
|
|
dim_embed = dim,
|
|
learned_embedding = add_reward_embed_to_agent_token
|
|
)
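# `SymExpTwoHot` is assumed to follow the dreamer line of work: scalar rewards are encoded as a
# two-hot distribution over symexp-spaced bins, predictions are trained with cross entropy over
# those bins, and `bins_to_scalar_value` decodes a predicted bin distribution back to a scalar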
|
|
|
|
to_reward_pred = Sequential(
|
|
RMSNorm(dim),
|
|
LinearNoBias(dim, self.reward_encoder.num_bins)
|
|
)
|
|
|
|
self.to_reward_pred = Ensemble(
|
|
to_reward_pred,
|
|
multi_token_pred_len
|
|
)
|
|
|
|
# value head
|
|
|
|
self.value_head = create_mlp(
|
|
dim_in = dim,
|
|
dim = dim * 4,
|
|
dim_out = self.reward_encoder.num_bins,
|
|
depth = value_head_mlp_depth,
|
|
)
|
|
|
|
# efficient axial space / time transformer
|
|
|
|
self.transformer = AxialSpaceTimeTransformer(
|
|
dim = dim,
|
|
depth = depth,
|
|
attn_dim_head = attn_dim_head,
|
|
attn_softclamp_value = attn_softclamp_value,
|
|
attn_kwargs = attn_kwargs,
|
|
ff_kwargs = ff_kwargs,
|
|
num_residual_streams = num_residual_streams,
|
|
num_special_spatial_tokens = num_agents,
|
|
time_block_every = time_block_every,
|
|
final_norm = False,
|
|
**transformer_kwargs
|
|
)
|
|
|
|
# ppo related
|
|
|
|
self.gae_use_accelerated = gae_use_accelerated
|
|
self.gae_discount_factor = gae_discount_factor
|
|
self.gae_lambda = gae_lambda
|
|
|
|
self.ppo_eps_clip = ppo_eps_clip
|
|
self.value_clip = value_clip
|
|
self.policy_entropy_weight = policy_entropy_weight
|
|
|
|
# rewards related
|
|
|
|
self.keep_reward_ema_stats = keep_reward_ema_stats
|
|
self.reward_ema_decay = reward_ema_decay
|
|
|
|
self.register_buffer('reward_quantile_filter', tensor(reward_quantile_filter), persistent = False)
|
|
|
|
self.register_buffer('ema_returns_mean', tensor(0.))
|
|
self.register_buffer('ema_returns_var', tensor(1.))
|
|
|
|
# loss related
|
|
|
|
self.flow_loss_normalizer = LossNormalizer(1)
|
|
self.reward_loss_normalizer = LossNormalizer(multi_token_pred_len)
|
|
self.discrete_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_discrete_actions > 0 else None
|
|
self.continuous_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_continuous_actions > 0 else None
|
|
|
|
self.latent_flow_loss_weight = latent_flow_loss_weight
|
|
|
|
self.register_buffer('reward_loss_weight', tensor(reward_loss_weight))
|
|
self.register_buffer('discrete_action_loss_weight', tensor(discrete_action_loss_weight))
|
|
self.register_buffer('continuous_action_loss_weight', tensor(continuous_action_loss_weight))
|
|
|
|
assert self.reward_loss_weight.numel() in {1, multi_token_pred_len}
|
|
assert self.discrete_action_loss_weight.numel() in {1, multi_token_pred_len}
|
|
assert self.continuous_action_loss_weight.numel() in {1, multi_token_pred_len}
|
|
|
|
self.register_buffer('zero', tensor(0.), persistent = False)
|
|
|
|
@property
|
|
def device(self):
|
|
return self.zero.device
|
|
|
|
# types of parameters
|
|
|
|
def muon_parameters(self):
|
|
return self.transformer.muon_parameters()
|
|
|
|
def policy_head_parameters(self):
|
|
return [
|
|
*self.policy_head.parameters(),
|
|
*self.action_embedder.unembed_parameters() # includes the unembed from the action-embedder
|
|
]
|
|
|
|
def value_head_parameters(self):
|
|
return self.value_head.parameters()
|
|
|
|
def parameters(self):
|
|
params = super().parameters()
|
|
|
|
if not exists(self.video_tokenizer):
|
|
return params
|
|
|
|
return list(set(params) - set(self.video_tokenizer.parameters()))
|
|
|
|
# helpers for shortcut flow matching
|
|
|
|
def get_times_from_signal_level(
|
|
self,
|
|
signal_levels,
|
|
align_dims_left_to = None
|
|
):
|
|
times = signal_levels.float() / self.max_steps
|
|
|
|
if not exists(align_dims_left_to):
|
|
return times
|
|
|
|
return align_dims_left(times, align_dims_left_to)
|
|
|
|
# interacting with env for experience
|
|
|
|
@torch.no_grad()
|
|
def interact_with_env(
|
|
self,
|
|
env,
|
|
seed = None,
|
|
agent_index = 0,
|
|
step_size = 4,
|
|
max_timesteps = 16,
|
|
env_is_vectorized = False,
|
|
use_time_kv_cache = True,
|
|
store_agent_embed = False
|
|
):
|
|
assert exists(self.video_tokenizer)
|
|
|
|
init_frame = env.reset()
|
|
|
|
# frame to video
|
|
|
|
if env_is_vectorized:
|
|
video = rearrange(init_frame, 'b c vh vw -> b c 1 vh vw')
|
|
else:
|
|
video = rearrange(init_frame, 'c vh vw -> 1 c 1 vh vw')
|
|
|
|
batch, device = video.shape[0], video.device
|
|
|
|
# accumulate
|
|
|
|
rewards = None
|
|
discrete_actions = None
|
|
continuous_actions = None
|
|
discrete_log_probs = None
|
|
continuous_log_probs = None
|
|
values = None
|
|
latents = None
|
|
|
|
acc_agent_embed = None
|
|
|
|
# keep track of termination, for setting the `is_truncated` field on Experience and for early stopping interaction with env
|
|
|
|
is_terminated = full((batch,), False, device = device)
|
|
is_truncated = full((batch,), False, device = device)
|
|
|
|
episode_lens = full((batch,), 0, device = device)
|
|
|
|
# maybe time kv cache
|
|
|
|
time_kv_cache = None
|
|
|
|
step_index = 0
|
|
|
|
while not is_terminated.all():
|
|
step_index += 1
|
|
|
|
latents = self.video_tokenizer(video, return_latents = True)
|
|
|
|
_, (agent_embed, next_time_kv_cache) = self.forward(
|
|
latents = latents,
|
|
signal_levels = self.max_steps - 1,
|
|
step_sizes = step_size,
|
|
rewards = rewards,
|
|
discrete_actions = discrete_actions,
|
|
continuous_actions = continuous_actions,
|
|
time_kv_cache = time_kv_cache,
|
|
latent_is_noised = True,
|
|
return_pred_only = True,
|
|
return_intermediates = True
|
|
)
|
|
|
|
# time kv cache
|
|
|
|
if use_time_kv_cache:
|
|
time_kv_cache = next_time_kv_cache
|
|
|
|
# get one agent
|
|
|
|
one_agent_embed = agent_embed[..., -1:, agent_index, :]
|
|
|
|
# values
|
|
|
|
value_bins = self.value_head(one_agent_embed)
|
|
value = self.reward_encoder.bins_to_scalar_value(value_bins)
|
|
|
|
values = safe_cat((values, value), dim = 1)
|
|
|
|
# policy embed
|
|
|
|
policy_embed = self.policy_head(one_agent_embed)
|
|
|
|
# sample actions
|
|
|
|
sampled_discrete_actions, sampled_continuous_actions = self.action_embedder.sample(policy_embed, pred_head_index = 0, squeeze = True)
|
|
|
|
discrete_actions = safe_cat((discrete_actions, sampled_discrete_actions), dim = 1)
|
|
continuous_actions = safe_cat((continuous_actions, sampled_continuous_actions), dim = 1)
|
|
|
|
# get the log prob and values for policy optimization
|
|
|
|
one_discrete_log_probs, one_continuous_log_probs = self.action_embedder.log_probs(
|
|
policy_embed,
|
|
pred_head_index = 0,
|
|
discrete_targets = sampled_discrete_actions,
|
|
continuous_targets = sampled_continuous_actions,
|
|
)
|
|
|
|
discrete_log_probs = safe_cat((discrete_log_probs, one_discrete_log_probs), dim = 1)
|
|
continuous_log_probs = safe_cat((continuous_log_probs, one_continuous_log_probs), dim = 1)
|
|
|
|
# pass the sampled action to the environment and get back next state and reward
|
|
|
|
env_step_out = env.step((sampled_discrete_actions, sampled_continuous_actions))
|
|
|
|
if len(env_step_out) == 2:
|
|
next_frame, reward = env_step_out
|
|
terminated = full((batch,), False, device = device)
|
|
truncated = full((batch,), False, device = device)
|
|
|
|
elif len(env_step_out) == 3:
|
|
next_frame, reward, terminated = env_step_out
|
|
truncated = full((batch,), False, device = device)
|
|
|
|
elif len(env_step_out) == 4:
|
|
next_frame, reward, terminated, truncated = env_step_out
|
|
|
|
# update episode lens
|
|
|
|
episode_lens = torch.where(is_terminated, episode_lens, episode_lens + 1)
|
|
|
|
# update `is_terminated`
|
|
|
|
# (1) - environment says it is terminated
|
|
# (2) - previous step is truncated (this step is for bootstrap value)
|
|
|
|
is_terminated |= (terminated | is_truncated)
|
|
|
|
# update `is_truncated`
|
|
|
|
if step_index <= max_timesteps:
|
|
is_truncated |= truncated
|
|
|
|
if step_index == max_timesteps:
|
|
# if the step index is at the max time step allowed, set the truncated flag, if not already terminated
|
|
|
|
is_truncated |= ~is_terminated
|
|
|
|
# batch and time dimension
|
|
|
|
if env_is_vectorized:
|
|
next_frame = rearrange(next_frame, 'b c vh vw -> b c 1 vh vw')
|
|
reward = rearrange(reward, 'b -> b 1')
|
|
else:
|
|
next_frame = rearrange(next_frame, 'c vh vw -> 1 c 1 vh vw')
|
|
reward = rearrange(reward, ' -> 1 1')
|
|
|
|
# concat
|
|
|
|
video = cat((video, next_frame), dim = 2)
|
|
rewards = safe_cat((rewards, reward), dim = 1)
|
|
|
|
acc_agent_embed = safe_cat((acc_agent_embed, agent_embed), dim = 1)
|
|
|
|
# package up one experience for learning
|
|
|
|
batch, device = latents.shape[0], latents.device
|
|
|
|
one_experience = Experience(
|
|
latents = latents,
|
|
video = video[:, :, :-1],
|
|
rewards = rewards,
|
|
actions = (discrete_actions, continuous_actions),
|
|
log_probs = (discrete_log_probs, continuous_log_probs),
|
|
values = values,
|
|
agent_embed = acc_agent_embed if store_agent_embed else None,
|
|
step_size = step_size,
|
|
agent_index = agent_index,
|
|
is_truncated = is_truncated,
|
|
lens = episode_lens,
|
|
is_from_world_model = False
|
|
)
|
|
|
|
return one_experience
|
|
|
|
# ppo
|
|
|
|
def learn_from_experience(
|
|
self,
|
|
experience: Experience,
|
|
policy_optim: Optimizer | None = None,
|
|
value_optim: Optimizer | None = None,
|
|
only_learn_policy_value_heads = True, # in the paper, they do not finetune the entire dynamics model, they just learn the heads
|
|
use_signed_advantage = True,
|
|
eps = 1e-6
|
|
):
|
|
|
|
latents = experience.latents
|
|
actions = experience.actions
|
|
old_log_probs = experience.log_probs
|
|
old_values = experience.values
|
|
rewards = experience.rewards
|
|
|
|
step_size = experience.step_size
|
|
agent_index = experience.agent_index
|
|
|
|
assert all([*map(exists, (old_log_probs, actions, old_values, rewards, step_size))]), 'the generations need to contain the log probs, values, and rewards for policy optimization'
|
|
|
|
batch, time = latents.shape[0], latents.shape[1]
|
|
|
|
# calculate returns
|
|
|
|
# mask out anything after the `lens`, which may include a bootstrapped node at the very end if `is_truncated = True`
|
|
|
|
if not exists(experience.is_truncated):
|
|
experience.is_truncated = full((batch,), True, device = latents.device)
|
|
|
|
if exists(experience.lens):
|
|
mask_for_gae = lens_to_mask(experience.lens, time)
|
|
|
|
rewards = rewards.masked_fill(~mask_for_gae, 0.)
|
|
old_values = old_values.masked_fill(~mask_for_gae, 0.)
|
|
|
|
# calculate returns
|
|
|
|
returns = calc_gae(rewards, old_values, gamma = self.gae_discount_factor, lam = self.gae_lambda, use_accelerated = self.gae_use_accelerated)
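# `calc_gae` is assumed to return lambda-returns (generalized advantage estimates added back onto the values):
# A_t = delta_t + gamma * lam * A_{t+1} with delta_t = r_t + gamma * V_{t+1} - V_t, and R_t = A_t + V_t,
# which is why the advantage further below is simply returns - old_values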
|
|
|
|
# handle variable lengths
|
|
|
|
max_time = latents.shape[1]
|
|
is_var_len = exists(experience.lens)
|
|
|
|
if is_var_len:
|
|
learnable_lens = experience.lens - experience.is_truncated.long() # if is truncated, remove the last one, as it is bootstrapped value
|
|
mask = lens_to_mask(learnable_lens, max_time)
|
|
|
|
# determine whether to finetune entire transformer or just learn the heads
|
|
|
|
world_model_forward_context = torch.no_grad if only_learn_policy_value_heads else nullcontext
|
|
|
|
# maybe keep track returns statistics and normalize returns and values before calculating advantage, as done in dreamer v3
|
|
|
|
if self.keep_reward_ema_stats:
|
|
ema_returns_mean, ema_returns_var = self.ema_returns_mean, self.ema_returns_var
|
|
|
|
decay = 1. - self.reward_ema_decay
|
|
|
|
# quantile filter
|
|
|
|
lo, hi = torch.quantile(returns, self.reward_quantile_filter).tolist()
|
|
returns_for_stats = returns.clamp(lo, hi)
|
|
|
|
# mean, var - todo - handle distributed
|
|
|
|
returns_mean, returns_var = returns_for_stats.mean(), returns_for_stats.var()
|
|
|
|
# ema
|
|
|
|
ema_returns_mean.lerp_(returns_mean, decay)
|
|
ema_returns_var.lerp_(returns_var, decay)
|
|
|
|
# normalize
|
|
|
|
ema_returns_std = ema_returns_var.clamp(min = 1e-5).sqrt()
|
|
|
|
normed_returns = (returns - ema_returns_mean) / ema_returns_std
|
|
normed_old_values = (old_values - ema_returns_mean) / ema_returns_std
|
|
|
|
advantage = normed_returns - normed_old_values
|
|
else:
|
|
advantage = returns - old_values
|
|
|
|
# apparently they just use the sign of the advantage
|
|
# https://arxiv.org/abs/2410.04166v1
|
|
|
|
if use_signed_advantage:
|
|
advantage = advantage.sign()
|
|
else:
|
|
advantage = F.layer_norm(advantage, advantage.shape, eps = eps)
|
|
|
|
# replay for the action logits and values
|
|
|
|
discrete_actions, continuous_actions = actions
|
|
|
|
with world_model_forward_context():
|
|
_, (agent_embed, _) = self.forward(
|
|
latents = latents,
|
|
signal_levels = self.max_steps - 1,
|
|
step_sizes = step_size,
|
|
rewards = rewards,
|
|
discrete_actions = discrete_actions,
|
|
continuous_actions = continuous_actions,
|
|
latent_is_noised = True,
|
|
return_pred_only = True,
|
|
return_intermediates = True
|
|
)
|
|
|
|
agent_embed = agent_embed[..., agent_index, :]
|
|
|
|
# maybe detach agent embed
|
|
|
|
if only_learn_policy_value_heads:
|
|
agent_embed = agent_embed.detach()
|
|
|
|
# ppo
|
|
|
|
policy_embed = self.policy_head(agent_embed)
|
|
|
|
log_probs, entropies = self.action_embedder.log_probs(policy_embed, pred_head_index = 0, discrete_targets = discrete_actions, continuous_targets = continuous_actions, return_entropies = True)
|
|
|
|
# concat discrete and continuous actions into one for optimizing
|
|
|
|
old_log_probs = safe_cat(old_log_probs, dim = -1)
|
|
log_probs = safe_cat(log_probs, dim = -1)
|
|
entropies = safe_cat(entropies, dim = -1)
|
|
|
|
ratio = (log_probs - old_log_probs).exp()
|
|
|
|
clipped_ratio = ratio.clamp(1. - self.ppo_eps_clip, 1. + self.ppo_eps_clip)
|
|
|
|
advantage = rearrange(advantage, '... -> ... 1') # broadcast across all actions
|
|
|
|
# clipped surrogate loss
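# ppo clipped surrogate: loss = -min(r_t * A_t, clip(r_t, 1 - eps, 1 + eps) * A_t),
# where r_t = exp(log_prob - old_log_prob) is the importance ratio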
|
|
|
|
policy_loss = -torch.min(ratio * advantage, clipped_ratio * advantage)
|
|
policy_loss = reduce(policy_loss, 'b t na -> b t', 'sum')
|
|
|
|
# the per-position policy loss is reduced together with the entropy bonus after the variable length masking below
|
|
|
|
# handle entropy loss for naive exploration bonus
|
|
|
|
entropy_loss = - reduce(entropies, 'b t na -> b t', 'sum')
|
|
|
|
total_policy_loss = (
|
|
policy_loss +
|
|
entropy_loss * self.policy_entropy_weight
|
|
)
|
|
|
|
# maybe handle variable lengths
|
|
|
|
if is_var_len:
|
|
total_policy_loss = total_policy_loss[mask].mean()
|
|
else:
|
|
total_policy_loss = total_policy_loss.mean()
|
|
|
|
# maybe take policy optimizer step
|
|
|
|
if exists(policy_optim):
|
|
total_policy_loss.backward()
|
|
|
|
policy_optim.step()
|
|
policy_optim.zero_grad()
|
|
|
|
# value loss
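# the value head is trained as a classification over the same symexp two-hot bins used for rewards,
# with ppo-style value clipping - the clipped path keeps the new value within `value_clip` of the old
# value, and the pessimistic maximum of the two cross entropy losses is taken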
|
|
|
|
value_bins = self.value_head(agent_embed)
|
|
values = self.reward_encoder.bins_to_scalar_value(value_bins)
|
|
|
|
clipped_values = old_values + (values - old_values).clamp(-self.value_clip, self.value_clip)
|
|
clipped_value_bins = self.reward_encoder(clipped_values)
|
|
|
|
return_bins = self.reward_encoder(returns)
|
|
|
|
value_bins, return_bins, clipped_value_bins = tuple(rearrange(t, 'b t l -> b l t') for t in (value_bins, return_bins, clipped_value_bins))
|
|
|
|
value_loss_1 = F.cross_entropy(value_bins, return_bins, reduction = 'none')
|
|
value_loss_2 = F.cross_entropy(clipped_value_bins, return_bins, reduction = 'none')
|
|
|
|
value_loss = torch.maximum(value_loss_1, value_loss_2)
|
|
|
|
# maybe variable length
|
|
|
|
if is_var_len:
|
|
value_loss = value_loss[mask].mean()
|
|
else:
|
|
value_loss = value_loss.mean()
|
|
|
|
# maybe take value optimizer step
|
|
|
|
if exists(value_optim):
|
|
value_loss.backward()
|
|
|
|
value_optim.step()
|
|
value_optim.zero_grad()
|
|
|
|
return total_policy_loss, value_loss
|
|
|
|
@torch.no_grad()
|
|
def generate(
|
|
self,
|
|
time_steps,
|
|
num_steps = 4,
|
|
batch_size = 1,
|
|
agent_index = 0,
|
|
tasks: int | Tensor | None = None,
|
|
image_height = None,
|
|
image_width = None,
|
|
return_decoded_video = None,
|
|
context_signal_noise = 0.1, # they do a noising of the past, this was from an old diffusion world modeling paper from EPFL iirc
|
|
time_kv_cache: Tensor | None = None,
|
|
use_time_kv_cache = True,
|
|
return_rewards_per_frame = False,
|
|
return_agent_actions = False,
|
|
return_log_probs_and_values = False,
|
|
return_time_kv_cache = False,
|
|
store_agent_embed = False
|
|
|
|
): # (b t n d) | (b c t h w)
|
|
|
|
has_proprio = self.has_proprio
|
|
was_training = self.training
|
|
self.eval()
|
|
|
|
# validation
|
|
|
|
assert log2(num_steps).is_integer(), f'number of steps {num_steps} must be a power of 2'
|
|
assert 0 < num_steps <= self.max_steps, f'number of steps {num_steps} must be between 0 and {self.max_steps}'
|
|
|
|
if isinstance(tasks, int):
|
|
tasks = full((batch_size,), tasks, device = self.device)
|
|
|
|
assert not exists(tasks) or tasks.shape[0] == batch_size
|
|
|
|
# get state latent shape
|
|
|
|
latent_shape = self.latent_shape
|
|
|
|
# derive step size
|
|
|
|
step_size = self.max_steps // num_steps
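# e.g. max_steps = 64 with num_steps = 4 gives a step size of 16 signal levels per euler step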
|
|
|
|
# denoising
|
|
# teacher forcing to start with
|
|
|
|
latents = empty((batch_size, 0, self.num_video_views, *latent_shape), device = self.device)
|
|
|
|
past_latents_context_noise = latents.clone()
|
|
|
|
# maybe internal state
|
|
|
|
if has_proprio:
|
|
proprio = empty((batch_size, 0, self.dim_proprio), device = self.device)
|
|
|
|
past_proprio_context_noise = proprio.clone()
|
|
|
|
# maybe return actions
|
|
|
|
return_agent_actions |= return_log_probs_and_values
|
|
|
|
decoded_discrete_actions = None
|
|
decoded_continuous_actions = None
|
|
|
|
# policy optimization related
|
|
|
|
decoded_discrete_log_probs = None
|
|
decoded_continuous_log_probs = None
|
|
decoded_values = None
|
|
|
|
# maybe store agent embed
|
|
|
|
acc_agent_embed = None
|
|
|
|
# maybe return rewards
|
|
|
|
decoded_rewards = None
|
|
if return_rewards_per_frame:
|
|
decoded_rewards = empty((batch_size, 0), device = self.device, dtype = torch.float32)
|
|
|
|
# while all the frames of the video (per latent) have not been generated
|
|
|
|
while latents.shape[1] < time_steps:
|
|
|
|
curr_time_steps = latents.shape[1]
|
|
|
|
noised_latent = randn((batch_size, 1, self.num_video_views, *latent_shape), device = self.device)
|
|
|
|
noised_proprio = None
|
|
|
|
if has_proprio:
|
|
noised_proprio = randn((batch_size, 1, self.dim_proprio), device = self.device)
|
|
|
|
for step in range(num_steps):
|
|
is_last_step = (step + 1) == num_steps
|
|
|
|
signal_levels = full((batch_size, 1), step * step_size, dtype = torch.long, device = self.device)
|
|
|
|
# noising past latent context
|
|
|
|
noised_context = latents.lerp(past_latents_context_noise, context_signal_noise) # the paragraph after eq (8)
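# past frames are conditioned on at the highest signal level but blended with a small amount of fixed
# noise (`context_signal_noise`) - the same per-frame noise is reused across denoising steps
# (`past_latents_context_noise`) so the context stays consistent while denoising the newest frame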
|
|
|
|
noised_latent_with_context, pack_context_shape = pack((noised_context, noised_latent), 'b * v n d')
|
|
|
|
# handle proprio
|
|
|
|
noised_proprio_with_context = None
|
|
|
|
if has_proprio:
|
|
noised_proprio_context = proprio.lerp(past_proprio_context_noise, context_signal_noise)
|
|
noised_proprio_with_context, _ = pack((noised_proprio_context, noised_proprio), 'b * d')
|
|
|
|
# proper signal levels
|
|
|
|
signal_levels_with_context = F.pad(signal_levels, (curr_time_steps, 0), value = self.max_steps - 1)
|
|
|
|
pred, (agent_embed, next_time_kv_cache) = self.forward(
|
|
latents = noised_latent_with_context,
|
|
signal_levels = signal_levels_with_context,
|
|
step_sizes = step_size,
|
|
rewards = decoded_rewards,
|
|
tasks = tasks,
|
|
discrete_actions = decoded_discrete_actions,
|
|
continuous_actions = decoded_continuous_actions,
|
|
proprio = noised_proprio_with_context,
|
|
time_kv_cache = time_kv_cache,
|
|
latent_is_noised = True,
|
|
latent_has_view_dim = True,
|
|
return_pred_only = True,
|
|
return_intermediates = True,
|
|
)
|
|
|
|
if use_time_kv_cache and is_last_step:
|
|
time_kv_cache = next_time_kv_cache
|
|
|
|
# maybe proprio
|
|
|
|
if has_proprio:
|
|
pred, pred_proprio = pred
|
|
|
|
# unpack pred
|
|
|
|
_, pred = unpack(pred, pack_context_shape, 'b * v n d')
|
|
|
|
if has_proprio:
|
|
_, pred_proprio = unpack(pred_proprio, pack_context_shape, 'b * d')
|
|
|
|
# derive flow, based on whether in x-space or not
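# for x-space prediction the network outputs x0, which is converted to a velocity via
# v = (x0_pred - x_t) / (1 - t) - a single euler update then moves by step_size / max_steps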
|
|
|
|
def denoise_step(pred, noised, signal_levels):
|
|
if self.pred_orig_latent:
|
|
times = self.get_times_from_signal_level(signal_levels)
|
|
aligned_times = align_dims_left(times, noised)
|
|
|
|
flow = (pred - noised) / (1. - aligned_times)
|
|
else:
|
|
flow = pred
|
|
|
|
return flow * (step_size / self.max_steps)
|
|
|
|
# denoise
|
|
|
|
noised_latent += denoise_step(pred, noised_latent, signal_levels)
|
|
|
|
if has_proprio:
|
|
noised_proprio += denoise_step(pred_proprio, noised_proprio, signal_levels)
|
|
|
|
denoised_latent = noised_latent # it is now denoised
|
|
|
|
if has_proprio:
|
|
denoised_proprio = noised_proprio
|
|
|
|
# take care of the rewards by predicting on the agent token embedding on the last denoising step
|
|
|
|
if return_rewards_per_frame:
|
|
one_agent_embed = agent_embed[:, -1:, agent_index]
|
|
|
|
reward_logits = self.to_reward_pred.forward_one(one_agent_embed, id = 0)
|
|
pred_reward = self.reward_encoder.bins_to_scalar_value(reward_logits, normalize = True)
|
|
|
|
decoded_rewards = cat((decoded_rewards, pred_reward), dim = 1)
|
|
|
|
# maybe store agent embed
|
|
|
|
acc_agent_embed = safe_cat((acc_agent_embed, agent_embed), dim = 1)
|
|
|
|
# decode the agent actions if needed
|
|
|
|
if return_agent_actions:
|
|
assert self.action_embedder.has_actions
|
|
|
|
one_agent_embed = agent_embed[:, -1:, agent_index]
|
|
|
|
policy_embed = self.policy_head(one_agent_embed)
|
|
|
|
sampled_discrete_actions, sampled_continuous_actions = self.action_embedder.sample(policy_embed, pred_head_index = 0, squeeze = True)
|
|
|
|
decoded_discrete_actions = safe_cat((decoded_discrete_actions, sampled_discrete_actions), dim = 1)
|
|
decoded_continuous_actions = safe_cat((decoded_continuous_actions, sampled_continuous_actions), dim = 1)
|
|
|
|
if return_log_probs_and_values:
|
|
discrete_log_probs, continuous_log_probs = self.action_embedder.log_probs(
|
|
policy_embed,
|
|
pred_head_index = 0,
|
|
discrete_targets = sampled_discrete_actions,
|
|
continuous_targets = sampled_continuous_actions,
|
|
)
|
|
|
|
decoded_discrete_log_probs = safe_cat((decoded_discrete_log_probs, discrete_log_probs), dim = 1)
|
|
decoded_continuous_log_probs = safe_cat((decoded_continuous_log_probs, continuous_log_probs), dim = 1)
|
|
|
|
value_bins = self.value_head(one_agent_embed)
|
|
values = self.reward_encoder.bins_to_scalar_value(value_bins)
|
|
|
|
decoded_values = safe_cat((decoded_values, values), dim = 1)
|
|
|
|
# concat the denoised latent
|
|
|
|
latents = cat((latents, denoised_latent), dim = 1)
|
|
|
|
# add new fixed context noise for the temporal consistency
|
|
|
|
past_latents_context_noise = cat((past_latents_context_noise, randn_like(denoised_latent)), dim = 1)
|
|
|
|
# handle proprio
|
|
|
|
if has_proprio:
|
|
proprio = cat((proprio, denoised_proprio), dim = 1)
|
|
|
|
past_proprio_context_noise = cat((past_proprio_context_noise, randn_like(denoised_proprio)), dim = 1)
|
|
|
|
# restore state
|
|
|
|
self.train(was_training)
|
|
|
|
# returning video
|
|
|
|
has_tokenizer = exists(self.video_tokenizer)
|
|
return_decoded_video = default(return_decoded_video, has_tokenizer)
|
|
|
|
video = None
|
|
|
|
if return_decoded_video:
|
|
|
|
latents_for_video = rearrange(latents, 'b t v n d -> b v t n d')
|
|
latents_for_video, unpack_view = pack_one(latents_for_video, '* t n d')
|
|
|
|
video = self.video_tokenizer.decode(
|
|
latents_for_video,
|
|
height = image_height,
|
|
width = image_width
|
|
)
|
|
|
|
video = unpack_view(video, '* c t vh vw')
|
|
|
|
# remove the lone view dimension
|
|
|
|
if not self.video_has_multi_view:
|
|
latents = rearrange(latents, 'b t 1 ... -> b t ...')
|
|
|
|
if exists(video):
|
|
video = rearrange(video, 'b 1 ... -> b ...')
|
|
|
|
# only return video or latent if not requesting anything else, for first stage training
|
|
|
|
if not has_at_least_one(return_rewards_per_frame, return_agent_actions, has_proprio):
|
|
out = video if return_decoded_video else latents
|
|
|
|
if not return_time_kv_cache:
|
|
return out
|
|
|
|
return out, time_kv_cache
|
|
|
|
# returning agent actions, rewards, and log probs + values for policy optimization
|
|
|
|
batch, device = latents.shape[0], latents.device
|
|
experience_lens = full((batch,), time_steps, device = device)
|
|
|
|
gen = Experience(
|
|
latents = latents,
|
|
video = video,
|
|
proprio = proprio if has_proprio else None,
|
|
agent_embed = acc_agent_embed if store_agent_embed else None,
|
|
step_size = step_size,
|
|
agent_index = agent_index,
|
|
lens = experience_lens,
|
|
is_from_world_model = True
|
|
)
|
|
|
|
if return_rewards_per_frame:
|
|
gen.rewards = decoded_rewards
|
|
|
|
if return_agent_actions:
|
|
gen.actions = (decoded_discrete_actions, decoded_continuous_actions)
|
|
|
|
if return_log_probs_and_values:
|
|
gen.log_probs = (decoded_discrete_log_probs, decoded_continuous_log_probs)
|
|
|
|
gen.values = decoded_values
|
|
|
|
if not return_time_kv_cache:
|
|
return gen
|
|
|
|
return gen, time_kv_cache
|
|
|
|
def forward(
|
|
self,
|
|
*,
|
|
video = None, # (b v? c t vh vw)
|
|
latents = None, # (b t v? n d) | (b t v? d)
|
|
lens = None, # (b)
|
|
signal_levels = None, # () | (b) | (b t)
|
|
step_sizes = None, # () | (b)
|
|
step_sizes_log2 = None, # () | (b)
|
|
latent_gene_ids = None, # (b)
|
|
tasks = None, # (b)
|
|
rewards = None, # (b t)
|
|
discrete_actions = None, # (b t na) | (b t-1 na)
|
|
continuous_actions = None, # (b t na) | (b t-1 na)
|
|
discrete_action_types = None, # (na)
|
|
continuous_action_types = None, # (na)
|
|
proprio = None, # (b t dp)
|
|
time_kv_cache = None,
|
|
return_pred_only = False,
|
|
latent_is_noised = False,
|
|
return_all_losses = False,
|
|
return_intermediates = False,
|
|
add_autoregressive_action_loss = False,
|
|
update_loss_ema = None,
|
|
latent_has_view_dim = False
|
|
):
|
|
# handle video or latents
|
|
|
|
assert exists(video) ^ exists(latents)
|
|
|
|
# standardize view dimension
|
|
|
|
if not self.video_has_multi_view:
|
|
if exists(video):
|
|
video = rearrange(video, 'b ... -> b 1 ...')
|
|
|
|
if exists(latents) and not latent_has_view_dim:
|
|
latents = rearrange(latents, 'b t ... -> b t 1 ...')
|
|
|
|
# if raw video passed in, tokenize
|
|
|
|
if exists(video):
|
|
assert video.ndim == 6
|
|
|
|
video, unpack_views = pack_one(video, '* c t vh vw')
|
|
assert exists(self.video_tokenizer), 'video_tokenizer must be passed in if training from raw video on dynamics model'
|
|
|
|
latents = self.video_tokenizer.tokenize(video)
|
|
latents = unpack_views(latents, '* t n d')
|
|
latents = rearrange(latents, 'b v t n d -> b t v n d')
|
|
|
|
if latents.ndim == 4:
|
|
latents = rearrange(latents, 'b t v d -> b t v 1 d') # 1 latent edge case
|
|
|
|
assert latents.shape[-2:] == self.latent_shape
|
|
assert latents.shape[2] == self.num_video_views
|
|
|
|
# variables
|
|
|
|
batch, time, device = *latents.shape[:2], latents.device
|
|
|
|
# signal and step size related input conforming
|
|
|
|
if exists(signal_levels):
|
|
if isinstance(signal_levels, int):
|
|
signal_levels = tensor(signal_levels, device = self.device)
|
|
|
|
if signal_levels.ndim == 0:
|
|
signal_levels = repeat(signal_levels, '-> b', b = batch)
|
|
|
|
if signal_levels.ndim == 1:
|
|
signal_levels = repeat(signal_levels, 'b -> b t', t = time)
|
|
|
|
if exists(step_sizes):
|
|
if isinstance(step_sizes, int):
|
|
step_sizes = tensor(step_sizes, device = self.device)
|
|
|
|
if step_sizes.ndim == 0:
|
|
step_sizes = repeat(step_sizes, '-> b', b = batch)
|
|
|
|
if exists(step_sizes_log2):
|
|
if isinstance(step_sizes_log2, int):
|
|
step_sizes_log2 = tensor(step_sizes_log2, device = self.device)
|
|
|
|
if step_sizes_log2.ndim == 0:
|
|
step_sizes_log2 = repeat(step_sizes_log2, '-> b', b = batch)
|
|
|
|
# handle step sizes -> step size log2
|
|
|
|
assert not (exists(step_sizes) and exists(step_sizes_log2))
|
|
|
|
if exists(step_sizes):
|
|
step_sizes_log2_maybe_float = torch.log2(step_sizes)
|
|
step_sizes_log2 = step_sizes_log2_maybe_float.long()
|
|
|
|
assert (step_sizes_log2 == step_sizes_log2_maybe_float).all(), '`step_sizes` must be powers of 2'
|
|
|
|
# flow related
|
|
|
|
assert not (exists(signal_levels) ^ exists(step_sizes_log2))
|
|
|
|
is_inference = exists(signal_levels)
|
|
no_shortcut_train = not is_inference
|
|
|
|
return_pred_only = return_pred_only or latent_is_noised
|
|
|
|
# if neither signal levels or step sizes passed in, assume training
|
|
# generate them randomly for training
|
|
|
|
if not is_inference:
|
|
|
|
no_shortcut_train = sample_prob(self.prob_no_shortcut_train)
|
|
|
|
if no_shortcut_train:
|
|
# if no shortcut training, step sizes are just 1 and noising is all steps, where each step is 1 / d_min
|
|
# in original shortcut paper, they actually set d = 0 for some reason, look into that later, as there is no mention in the dreamer paper of doing this
|
|
|
|
step_sizes_log2 = zeros((batch,), device = device).long() # zero because zero is equivalent to step size of 1
|
|
signal_levels = randint(0, self.max_steps, (batch, time), device = device)
|
|
else:
|
|
|
|
# now we follow eq (4)
|
|
|
|
step_sizes_log2 = randint(1, self.num_step_sizes_log2, (batch,), device = device)
|
|
num_step_sizes = 2 ** step_sizes_log2
|
|
|
|
signal_levels = randint(0, self.max_steps, (batch, time), device = device) // num_step_sizes[:, None] * num_step_sizes[:, None] # times are discretized to step sizes
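# diffusion forcing style - every frame gets its own signal level, discretized to multiples of the
# sampled step size. e.g. max_steps = 64 with a step size of 8 restricts levels to {0, 8, ..., 56}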
|
|
|
|
# times is from 0 to 1
|
|
|
|
times = self.get_times_from_signal_level(signal_levels)
|
|
|
|
if not latent_is_noised:
|
|
# get the noise
|
|
|
|
noise = randn_like(latents)
|
|
aligned_times = align_dims_left(times, latents)
|
|
|
|
# noise from 0 as noise to 1 as data
|
|
|
|
noised_latents = noise.lerp(latents, aligned_times)
|
|
|
|
else:
|
|
noised_latents = latents
|
|
|
|
# reinforcement learning related
|
|
|
|
agent_tokens = repeat(self.agent_learned_embed, '... d -> b ... d', b = batch)
|
|
|
|
if exists(tasks):
|
|
assert self.num_tasks > 0
|
|
|
|
task_embeds = self.task_embed(tasks)
|
|
agent_tokens = add('b ... d, b d', agent_tokens, task_embeds)
|
|
|
|
# maybe evolution
|
|
|
|
if exists(latent_gene_ids):
|
|
assert exists(self.latent_genes)
|
|
latent_genes = self.latent_genes[latent_gene_ids]
|
|
|
|
agent_tokens = add('b ... d, b d', agent_tokens, latent_genes)
|
|
|
|
# handle agent tokens w/ actions and task embeds
|
|
|
|
agent_tokens = repeat(agent_tokens, 'b ... d -> b t ... d', t = time)
|
|
|
|
# maybe reward tokens
|
|
|
|
reward_tokens = agent_tokens[:, :, 0:0]
|
|
|
|
if exists(rewards):
|
|
two_hot_encoding = self.reward_encoder(rewards)
|
|
|
|
if (
|
|
self.add_reward_embed_to_agent_token and
|
|
(not self.training or not sample_prob(self.add_reward_embed_dropout)) # a bit of noise goes a long way
|
|
):
|
|
assert self.num_agents == 1
|
|
|
|
reward_tokens = self.reward_encoder.embed(two_hot_encoding)
|
|
|
|
pop_last_reward = int(reward_tokens.shape[1] == agent_tokens.shape[1]) # the last reward is popped off during training, during inference, it is not known yet, so need to handle this edge case
|
|
|
|
reward_tokens = pad_at_dim(reward_tokens, (1, -pop_last_reward), dim = -2, value = 0.) # shift as each agent token predicts the next reward
|
|
|
|
reward_tokens = add('1 d, b t d', self.reward_learned_embed, reward_tokens)
|
|
|
|
# maybe proprioception
|
|
|
|
assert xnor(self.has_proprio, exists(proprio)), 'proprio must be passed in if `dim_proprio` is set and vice versa'
|
|
|
|
noised_proprio = None
|
|
|
|
if self.has_proprio:
|
|
|
|
if not latent_is_noised:
|
|
# get the noise
|
|
|
|
proprio_noise = randn_like(proprio)
|
|
aligned_times = align_dims_left(times, proprio)
|
|
|
|
# noise from 0 as noise to 1 as data
|
|
|
|
noised_proprio = proprio_noise.lerp(proprio, aligned_times)
|
|
|
|
else:
|
|
noised_proprio = proprio
|
|
|
|
# maybe create the action tokens
|
|
|
|
if exists(discrete_actions) or exists(continuous_actions):
|
|
assert self.action_embedder.has_actions
|
|
assert self.num_agents == 1, 'only one agent allowed for now'
|
|
|
|
action_tokens = self.action_embedder(
|
|
discrete_actions = discrete_actions,
|
|
discrete_action_types = discrete_action_types,
|
|
continuous_actions = continuous_actions,
|
|
continuous_action_types = continuous_action_types
|
|
)
|
|
|
|
# handle first timestep not having an associated past action
|
|
|
|
if action_tokens.shape[1] == (time - 1):
|
|
action_tokens = pad_at_dim(action_tokens, (1, 0), value = 0. , dim = 1)
|
|
|
|
action_tokens = add('1 d, b t d', self.action_learned_embed, action_tokens)
|
|
|
|
elif self.action_embedder.has_actions:
|
|
action_tokens = torch.zeros_like(agent_tokens[:, :, 0:1])
|
|
|
|
else:
|
|
action_tokens = agent_tokens[:, :, 0:0] # else empty off agent tokens
|
|
|
|
# main function, needs to be defined as such for shortcut training - additional calls for consistency loss
|
|
|
|
def get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = False, return_time_kv_cache = False):
|
|
|
|
# latents to spatial tokens
|
|
|
|
space_tokens = self.latents_to_spatial_tokens(noised_latents)
|
|
|
|
# maybe add view embedding
|
|
|
|
if self.video_has_multi_view:
|
|
space_tokens = add('b t v ... d, v d', space_tokens, self.view_emb)
|
|
|
|
# merge spatial tokens
|
|
|
|
space_tokens, inverse_pack_space_per_latent = pack_one(space_tokens, 'b t * d')
|
|
|
|
num_spatial_tokens = space_tokens.shape[-2]
|
|
|
|
# action tokens
|
|
|
|
num_action_tokens = 1 if not is_empty(action_tokens) else 0
|
|
|
|
# reward tokens
|
|
|
|
num_reward_tokens = 1 if not is_empty(reward_tokens) else 0
|
|
|
|
# pack to tokens
|
|
# [signal + step size embed] [latent space tokens] [register] [actions / agent]
|
|
|
|
registers = repeat(self.register_tokens, 's d -> b t s d', b = batch, t = time)
|
|
|
|
# maybe proprio
|
|
|
|
if exists(noised_proprio):
|
|
proprio_token = self.to_proprio_token(noised_proprio)
|
|
else:
|
|
proprio_token = registers[:, :, 0:0]
|
|
|
|
# determine signal + step size embed for their diffusion forcing + shortcut
|
|
|
|
signal_embed = self.signal_levels_embed(signal_levels)
|
|
|
|
step_size_embed = self.step_size_embed(step_sizes_log2)
|
|
step_size_embed = repeat(step_size_embed, 'b ... -> b t ...', t = time)
|
|
|
|
flow_token = cat((signal_embed, step_size_embed), dim = -1)
|
|
flow_token = rearrange(flow_token, 'b t d -> b t d')
|
|
|
|
# pack to tokens for attending
|
|
|
|
tokens, packed_tokens_shape = pack([flow_token, space_tokens, proprio_token, registers, action_tokens, reward_tokens, agent_tokens], 'b t * d')
|
|
|
|
# attention
|
|
|
|
tokens, next_time_kv_cache = self.transformer(tokens, kv_cache = time_kv_cache, return_kv_cache = True)
|
|
|
|
# unpack
|
|
|
|
flow_token, space_tokens, proprio_token, register_tokens, action_tokens, reward_tokens, agent_tokens = unpack(tokens, packed_tokens_shape, 'b t * d')
|
|
|
|
# pooling
|
|
|
|
space_tokens = inverse_pack_space_per_latent(space_tokens)
|
|
|
|
pred = self.to_latent_pred(space_tokens)
|
|
|
|
# maybe proprio
|
|
|
|
if self.has_proprio:
|
|
pred_proprio = self.to_proprio_pred(proprio_token)
|
|
|
|
pred = (pred, pred_proprio)
|
|
|
|
# returning
|
|
|
|
if not return_agent_tokens:
|
|
return pred
|
|
|
|
if not return_time_kv_cache:
|
|
return pred, agent_tokens
|
|
|
|
return pred, (agent_tokens, next_time_kv_cache)
|
|
|
|
# curry into get_prediction what does not change during first call as well as the shortcut ones
|
|
|
|
_get_prediction = partial(get_prediction, action_tokens = action_tokens, reward_tokens = reward_tokens, agent_tokens = agent_tokens)
|
|
|
|
# forward the network
|
|
|
|
pred, (encoded_agent_tokens, next_time_kv_cache) = _get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, return_agent_tokens = True, return_time_kv_cache = True)
|
|
|
|
if return_pred_only:
|
|
if not return_intermediates:
|
|
return pred
|
|
|
|
return pred, (encoded_agent_tokens, next_time_kv_cache)
|
|
|
|
# pack the predictions to calculate flow for different modalities all at once
|
|
|
|
if self.has_proprio:
|
|
pred, for_flow_loss_packed_shape = pack(pred, 'b t *')
|
|
|
|
noised, _ = pack((noised_latents, noised_proprio), 'b t *')
|
|
data, _ = pack((latents, proprio), 'b t *')
|
|
noise, _ = pack((noise, proprio_noise), 'b t *')
|
|
else:
|
|
noised = noised_latents
|
|
data = latents
|
|
|
|
# wrapper function for maybe unpacking and packing modalities for doing flow math in unison
|
|
|
|
def maybe_pack_unpack(fn):
|
|
@wraps(fn)
|
|
@torch.no_grad()
|
|
def inner(noised, *args, **kwargs):
|
|
|
|
noised_proprio = None
|
|
|
|
if self.has_proprio:
|
|
noised, noised_proprio = unpack(noised, for_flow_loss_packed_shape, 'b t *')
|
|
|
|
pred = fn(noised, noised_proprio, *args, **kwargs)
|
|
|
|
if self.has_proprio:
|
|
pred, _ = pack(pred, 'b t *')
|
|
|
|
return pred
|
|
return inner
|
|
|
|
wrapped_get_prediction = maybe_pack_unpack(_get_prediction)
|
|
|
|
# determine the target for the loss
|
|
|
|
pred_target = None
|
|
|
|
is_x_space = self.pred_orig_latent
|
|
is_v_space_pred = not self.pred_orig_latent
|
|
|
|
maybe_shortcut_loss_weight = 1.
|
|
|
|
if no_shortcut_train:
|
|
|
|
# allow for original velocity pred
|
|
# x-space as in paper is in else clause
|
|
|
|
if is_v_space_pred:
|
|
pred_target = flow = data - noise
|
|
else:
|
|
pred_target = data
|
|
else:
|
|
# shortcut training - Frans et al. https://arxiv.org/abs/2410.12557
|
|
|
|
# basically a consistency loss where you ensure quantity of two half steps equals one step
|
|
# dreamer then makes it work for x-space with some math
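# concretely: predict flow b' at (t, d/2), take a half euler step with it, predict flow b'' at
# (t + d/2, d/2), then regress the full d-sized prediction onto the stop-gradient average (b' + b'') / 2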
|
|
|
|
step_sizes_log2_minus_one = step_sizes_log2 - 1 # which equals d / 2
|
|
half_step_size = 2 ** step_sizes_log2_minus_one
|
|
|
|
first_step_pred = wrapped_get_prediction(noised, signal_levels, step_sizes_log2_minus_one)
|
|
|
|
# first derive b'
|
|
|
|
if is_v_space_pred:
|
|
first_step_pred_flow = first_step_pred
|
|
else:
|
|
first_times = self.get_times_from_signal_level(signal_levels, noised)
|
|
|
|
first_step_pred_flow = (first_step_pred - noised) / (1. - first_times)
|
|
|
|
# take a half step
|
|
|
|
half_step_size_align_left = align_dims_left(half_step_size, noised)
|
|
|
|
denoised = noised + first_step_pred_flow * (half_step_size_align_left / self.max_steps)
|
|
|
|
# get second prediction for b''
|
|
|
|
signal_levels_plus_half_step = signal_levels + half_step_size[:, None]
|
|
second_step_pred = wrapped_get_prediction(denoised, signal_levels_plus_half_step, step_sizes_log2_minus_one)
|
|
|
|
if is_v_space_pred:
|
|
second_step_pred_flow = second_step_pred
|
|
else:
|
|
second_times = self.get_times_from_signal_level(signal_levels_plus_half_step, denoised)
|
|
second_step_pred_flow = (second_step_pred - denoised) / (1. - second_times)
|
|
|
|
# pred target is sg(b' + b'') / 2
|
|
|
|
pred_target = (first_step_pred_flow + second_step_pred_flow).detach() / 2
|
|
|
|
# need to convert x-space to v-space
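# v = (x0_pred - x_t) / (1 - t), with the (1 - t)^2 weight below cancelling the 1 / (1 - t)^2
# amplification of the mse so the x-space objective matches eq (7)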
|
|
|
|
if is_x_space:
|
|
pred = (pred - noised) / (1. - first_times)
|
|
maybe_shortcut_loss_weight = (1. - first_times) ** 2
|
|
|
|
# mse loss
|
|
|
|
flow_losses = F.mse_loss(pred, pred_target, reduction = 'none')
|
|
|
|
flow_losses = flow_losses * maybe_shortcut_loss_weight # handle the (1-t)^2 in eq(7)
|
|
|
|
# loss weighting with their ramp function
|
|
|
|
if exists(self.loss_weight_fn):
|
|
loss_weight = self.loss_weight_fn(times)
|
|
loss_weight = align_dims_left(loss_weight, flow_losses)
|
|
|
|
flow_losses = flow_losses * loss_weight
|
|
|
|
# handle variable lengths if needed
|
|
|
|
is_var_len = exists(lens)
|
|
|
|
if is_var_len:
|
|
|
|
loss_mask = lens_to_mask(lens, time)
|
|
loss_mask_without_last = loss_mask[:, :-1]
|
|
|
|
flow_loss = flow_losses[loss_mask].mean()
|
|
|
|
else:
|
|
flow_loss = flow_losses.mean()
|
|
|
|
# now take care of the agent token losses
|
|
|
|
reward_loss = self.zero
|
|
|
|
if exists(rewards):
|
|
|
|
if rewards.ndim == 2: # (b t)
|
|
encoded_agent_tokens = reduce(encoded_agent_tokens, 'b t g d -> b t d', 'mean')
|
|
|
|
reward_pred = self.to_reward_pred(encoded_agent_tokens[:, :-1])
|
|
|
|
reward_pred = rearrange(reward_pred, 'mtp b t l -> b l t mtp')
|
|
|
|
reward_targets, reward_loss_mask = create_multi_token_prediction_targets(two_hot_encoding[:, :-1], self.multi_token_pred_len)
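# `create_multi_token_prediction_targets` is assumed to return, per timestep, the next
# `multi_token_pred_len` targets stacked along a new mtp dimension, plus a boolean mask for which of
# those future targets actually exist (positions near the end of the sequence run out of targets)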
|
|
|
|
reward_targets = rearrange(reward_targets, 'b t mtp l -> b l t mtp')
|
|
|
|
reward_losses = F.cross_entropy(reward_pred, reward_targets, reduction = 'none')
|
|
|
|
reward_losses = reward_losses.masked_fill(~reward_loss_mask, 0.)
|
|
|
|
if is_var_len:
|
|
reward_loss = reward_losses[loss_mask_without_last].mean(dim = 0)
|
|
else:
|
|
reward_loss = reduce(reward_losses, '... mtp -> mtp', 'mean') # they sum across the prediction steps (mtp dimension) - eq(9)
|
|
|
|
# maybe autoregressive action loss
|
|
|
|
discrete_action_loss = self.zero
|
|
continuous_action_loss = self.zero
|
|
|
|
if (
|
|
self.num_agents == 1 and
|
|
add_autoregressive_action_loss and
|
|
time > 1 and
|
|
(exists(discrete_actions) or exists(continuous_actions))
|
|
):
|
|
assert self.action_embedder.has_actions
|
|
|
|
# handle actions having time vs time - 1 length
|
|
# remove the first action if it is equal to time (as it would come from some agent token in the past)
|
|
|
|
if exists(discrete_actions) and discrete_actions.shape[1] == time:
|
|
discrete_actions = discrete_actions[:, 1:]
|
|
|
|
if exists(continuous_actions) and continuous_actions.shape[1] == time:
|
|
continuous_actions = continuous_actions[:, 1:]
|
|
|
|
# only for 1 agent
|
|
|
|
agent_tokens = rearrange(agent_tokens, 'b t 1 d -> b t d')
|
|
policy_embed = self.policy_head(agent_tokens[:, :-1])
|
|
|
|
# constitute multi token prediction targets
|
|
|
|
discrete_action_targets = continuous_action_targets = None
|
|
|
|
if exists(discrete_actions):
|
|
discrete_action_targets, discrete_mask = create_multi_token_prediction_targets(discrete_actions, self.multi_token_pred_len)
|
|
discrete_action_targets = rearrange(discrete_action_targets, 'b t mtp ... -> mtp b t ...')
|
|
discrete_mask = rearrange(discrete_mask, 'b t mtp -> mtp b t')
|
|
|
|
if exists(continuous_actions):
|
|
continuous_action_targets, continuous_mask = create_multi_token_prediction_targets(continuous_actions, self.multi_token_pred_len)
|
|
continuous_action_targets = rearrange(continuous_action_targets, 'b t mtp ... -> mtp b t ...')
|
|
continuous_mask = rearrange(continuous_mask, 'b t mtp -> mtp b t')
|
|
|
|
discrete_log_probs, continuous_log_probs = self.action_embedder.log_probs(
|
|
policy_embed,
|
|
discrete_targets = discrete_action_targets if exists(discrete_actions) else None,
|
|
continuous_targets = continuous_action_targets if exists(continuous_actions) else None
|
|
)
|
|
|
|
if exists(discrete_log_probs):
|
|
discrete_log_probs = discrete_log_probs.masked_fill(~discrete_mask[..., None], 0.)
|
|
|
|
if is_var_len:
|
|
discrete_action_losses = rearrange(-discrete_log_probs, 'mtp b t na -> b t na mtp')
|
|
discrete_action_loss = reduce(discrete_action_losses[loss_mask_without_last], '... mtp -> mtp', 'mean')
|
|
else:
|
|
discrete_action_loss = reduce(-discrete_log_probs, 'mtp b t na -> mtp', 'mean')
|
|
|
|
if exists(continuous_log_probs):
|
|
continuous_log_probs = continuous_log_probs.masked_fill(~continuous_mask[..., None], 0.)
|
|
|
|
if is_var_len:
|
|
continuous_action_losses = rearrange(-continuous_log_probs, 'mtp b t na -> b t na mtp')
|
|
continuous_action_loss = reduce(continuous_action_losses[loss_mask_without_last], '... mtp -> mtp', 'mean')
|
|
else:
|
|
continuous_action_loss = reduce(-continuous_log_probs, 'mtp b t na -> mtp', 'mean')
|
|
|
|
# handle loss normalization
|
|
|
|
losses = WorldModelLosses(flow_loss, reward_loss, discrete_action_loss, continuous_action_loss)
|
|
|
|
if exists(self.flow_loss_normalizer):
|
|
flow_loss = self.flow_loss_normalizer(flow_loss, update_ema = update_loss_ema)
|
|
|
|
if exists(rewards) and exists(self.reward_loss_normalizer):
|
|
reward_loss = self.reward_loss_normalizer(reward_loss, update_ema = update_loss_ema)
|
|
|
|
if exists(discrete_actions) and exists(self.discrete_actions_loss_normalizer):
|
|
discrete_action_loss = self.discrete_actions_loss_normalizer(discrete_action_loss, update_ema = update_loss_ema)
|
|
|
|
if exists(continuous_actions) and exists(self.continuous_actions_loss_normalizer):
|
|
continuous_action_loss = self.continuous_actions_loss_normalizer(continuous_action_loss, update_ema = update_loss_ema)
|
|
|
|
# gather losses - they sum across the multi token prediction steps for rewards and actions - eq (9)
|
|
|
|
total_loss = (
|
|
flow_loss * self.latent_flow_loss_weight +
|
|
(reward_loss * self.reward_loss_weight).sum() +
|
|
(discrete_action_loss * self.discrete_action_loss_weight).sum() +
|
|
(continuous_action_loss * self.continuous_action_loss_weight).sum()
|
|
)
|
|
|
|
if not return_all_losses:
|
|
return total_loss
|
|
|
|
return total_loss, losses
|