from __future__ import annotations
|
|
|
|
import math
|
|
from math import ceil, log2
|
|
from random import random
|
|
from collections import namedtuple
|
|
from functools import partial
from typing import Callable
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity
|
|
from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace
|
|
|
|
import torchvision
|
|
from torchvision.models import VGG16_Weights
|
|
|
|
from x_mlps_pytorch.normed_mlp import create_mlp
|
|
from x_mlps_pytorch.ensemble import Ensemble
|
|
|
|
from assoc_scan import AssocScan
|
|
|
|
# ein related
|
|
|
|
# b - batch
|
|
# n - sequence
|
|
# h - attention heads
|
|
# d - feature dimension
|
|
# f - frequencies (rotary)
|
|
# l - logit / predicted bins
|
|
# p - positions (3 for spacetime in this work)
|
|
# t - time
|
|
# na - action dimension (number of discrete and continuous actions)
|
|
# g - groups of query heads to key heads (gqa)
|
|
# vc - video channels
|
|
# vh, vw - video height and width
|
|
|
|
import einx
|
|
from einops import einsum, rearrange, repeat, reduce, pack, unpack
|
|
from einops.layers.torch import Rearrange, Reduce
|
|
|
|
# flex attention - but will make sure it works if it is not available
|
|
# may also end up crafting own custom flash attention kernel for this work
|
|
|
|
flex_attention = None
|
|
|
|
try:
|
|
from torch.nn.attention.flex_attention import flex_attention, create_block_mask
|
|
if torch.cuda.is_available():
|
|
flex_attention = torch.compile(flex_attention)
|
|
except ImportError:
|
|
pass
|
|
|
|
# constants
|
|
|
|
LinearNoBias = partial(Linear, bias = False)
|
|
|
|
TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips'))
|
|
|
|
# helpers
|
|
|
|
def exists(v):
|
|
return v is not None
|
|
|
|
def default(v, d):
|
|
return v if exists(v) else d
|
|
|
|
def first(arr):
|
|
return arr[0]
|
|
|
|
def ensure_tuple(t):
|
|
return (t,) if not isinstance(t, tuple) else t
|
|
|
|
def divisible_by(num, den):
|
|
return (num % den) == 0
|
|
|
|
def sample_prob(prob):
|
|
return random() < prob
|
|
|
|
def is_power_two(num):
|
|
return log2(num).is_integer()
|
|
|
|
# tensor helpers
|
|
|
|
def pack_one(t, pattern):
|
|
packed, packed_shape = pack([t], pattern)
|
|
|
|
def inverse(out, inv_pattern = None):
|
|
inv_pattern = default(inv_pattern, pattern)
|
|
return first(unpack(out, packed_shape, inv_pattern))
|
|
|
|
return packed, inverse
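
# a minimal usage sketch for `pack_one` and its returned inverse (shapes below are hypothetical):
#
#   t = randn(2, 3, 4, 8)
#   packed, inverse = pack_one(t, 'b * d')    # (2, 12, 8) - the middle dims are folded into one axis
#   restored = inverse(packed)                # (2, 3, 4, 8) - original shape recovered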
|
|
|
|
def pad_at_dim(
|
|
t,
|
|
pad: tuple[int, int],
|
|
dim = -1,
|
|
value = 0.
|
|
):
|
|
if pad == (0, 0):
|
|
return t
|
|
|
|
dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
|
|
zeros = ((0, 0) * dims_from_right)
|
|
return F.pad(t, (*zeros, *pad), value = value)
|
|
|
|
def align_dims_left(t, aligned_to):
|
|
shape = t.shape
|
|
num_right_dims = aligned_to.ndim - t.ndim
|
|
|
|
if num_right_dims < 0:
|
|
return
|
|
|
|
return t.reshape(*shape, *((1,) * num_right_dims))
|
|
|
|
def l2norm(t):
|
|
return F.normalize(t, dim = -1, p = 2)
|
|
|
|
def softclamp(t, value = 50.):
|
|
return (t / value).tanh() * value
|
|
|
|
# loss related
|
|
|
|
class LPIPSLoss(Module):
|
|
def __init__(
|
|
self,
|
|
vgg: Module | None = None,
|
|
vgg_weights: VGG16_Weights = VGG16_Weights.DEFAULT,
|
|
sampled_frames = 1
|
|
):
|
|
super().__init__()
|
|
|
|
if not exists(vgg):
|
|
vgg = torchvision.models.vgg16(weights = vgg_weights)
|
|
vgg.classifier = Sequential(*vgg.classifier[:-2])
|
|
|
|
self.vgg = [vgg]
|
|
self.sampled_frames = sampled_frames
|
|
|
|
def forward(
|
|
self,
|
|
pred,
|
|
data,
|
|
):
|
|
batch, device, is_video = pred.shape[0], pred.device, pred.ndim == 5
|
|
|
|
vgg, = self.vgg
|
|
vgg = vgg.to(data.device)
|
|
|
|
# take care of sampling random frames of the video
|
|
|
|
if is_video:
|
|
pred, data = tuple(rearrange(t, 'b c t ... -> b t c ...') for t in (pred, data))
|
|
|
|
# batch randperm
|
|
|
|
batch_randperm = randn(pred.shape[:2], device = pred.device).argsort(dim = -1)
|
|
rand_frames = batch_randperm[..., :self.sampled_frames]
|
|
|
|
batch_arange = arange(batch, device = device)
|
|
batch_arange = rearrange(batch_arange, '... -> ... 1')
|
|
|
|
pred, data = tuple(t[batch_arange, rand_frames] for t in (pred, data))
|
|
|
|
# fold sampled frames into batch
|
|
|
|
pred, data = tuple(rearrange(t, 'b t c ... -> (b t) c ...') for t in (pred, data))
|
|
|
|
pred_embed, embed = tuple(vgg(t) for t in (pred, data))
|
|
|
|
return F.mse_loss(embed, pred_embed)
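
# usage sketch - assumes the torchvision VGG16 weights can be downloaded; accepts images (b c h w) or videos (b c t h w):
#
#   lpips = LPIPSLoss(sampled_frames = 1)
#   pred = torch.rand(2, 3, 4, 64, 64)    # predicted video
#   real = torch.rand(2, 3, 4, 64, 64)
#   loss = lpips(pred, real)              # scalar perceptual loss on randomly sampled frames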
|
|
|
|
def ramp_weight(times, slope = 0.9, intercept = 0.1):
|
|
# equation (8) paper, their "ramp" loss weighting
|
|
return slope * times + intercept
|
|
|
|
# reinforcement learning related
|
|
|
|
# rewards
|
|
|
|
class SymExpTwoHot(Module):
|
|
def __init__(
|
|
self,
|
|
reward_range = (-20., 20.),
|
|
num_bins = 255,
|
|
learned_embedding = False,
|
|
dim_embed = None,
|
|
):
|
|
super().__init__()
|
|
|
|
min_value, max_value = reward_range
|
|
values = linspace(min_value, max_value, num_bins)
|
|
values = values.sign() * (torch.exp(values.abs()) - 1.)
|
|
|
|
self.reward_range = reward_range
|
|
self.num_bins = num_bins
|
|
self.register_buffer('bin_values', values)
|
|
|
|
# take care of a reward embedding
|
|
# for an improvisation where agent tokens can also see the past rewards - it makes sense that this information should not be thrown out, a la Decision Transformer
|
|
|
|
self.learned_embedding = learned_embedding
|
|
|
|
if learned_embedding:
|
|
assert exists(dim_embed)
|
|
self.bin_embeds = nn.Embedding(num_bins, dim_embed)
|
|
|
|
@property
|
|
def device(self):
|
|
return self.bin_values.device
|
|
|
|
def embed(
|
|
self,
|
|
two_hot_encoding,
|
|
):
|
|
assert self.learned_embedding, f'can only embed if `learned_embedding` is True'
|
|
|
|
weights, bin_indices = two_hot_encoding.topk(k = 2, dim = -1)
|
|
|
|
two_embeds = self.bin_embeds(bin_indices)
|
|
|
|
return einsum(two_embeds, weights, '... two d, ... two -> ... d')
|
|
|
|
def bins_to_scalar_value(
|
|
self,
|
|
logits, # (... l)
|
|
normalize = False
|
|
):
|
|
two_hot_encoding = logits.softmax(dim = -1) if normalize else logits
|
|
return einsum(two_hot_encoding, self.bin_values, '... l, l -> ...')
|
|
|
|
def forward(
|
|
self,
|
|
values
|
|
):
|
|
bin_values = self.bin_values
|
|
min_bin_value, max_bin_value = self.bin_values[0], self.bin_values[-1]
|
|
|
|
values, inverse_pack = pack_one(values, '*')
|
|
num_values = values.shape[0]
|
|
|
|
values = values.clamp(min = min_bin_value, max = max_bin_value)
|
|
|
|
indices = torch.searchsorted(self.bin_values, values)
|
|
|
|
# fetch the closest two indices (two-hot encoding)
|
|
|
|
left_indices = (indices - 1).clamp(min = 0)
|
|
right_indices = left_indices + 1
|
|
|
|
left_indices, right_indices = tuple(rearrange(t, '... -> ... 1') for t in (left_indices, right_indices))
|
|
|
|
# fetch the left and right values for the consecutive indices
|
|
|
|
left_values = self.bin_values[left_indices]
|
|
right_values = self.bin_values[right_indices]
|
|
|
|
# calculate the left and right values by the distance to the left and right
|
|
|
|
values = rearrange(values, '... -> ... 1')
|
|
total_distance = right_values - left_values
|
|
|
|
left_logit_value = (right_values - values) / total_distance
|
|
right_logit_value = 1. - left_logit_value
|
|
|
|
# set the left and right values (two-hot)
|
|
|
|
encoded = torch.zeros((num_values, self.num_bins), device = self.device)
|
|
|
|
encoded.scatter_(-1, left_indices, left_logit_value)
|
|
encoded.scatter_(-1, right_indices, right_logit_value)
|
|
|
|
return inverse_pack(encoded, '* l')
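
# usage sketch for the symexp two-hot reward encoding (values below are hypothetical):
#
#   reward_encoder = SymExpTwoHot(reward_range = (-20., 20.), num_bins = 255)
#   rewards = tensor([0.5, -3., 7.])
#   two_hot = reward_encoder(rewards)                          # (3, 255) - two adjacent bins hold the interpolation weights
#   recovered = reward_encoder.bins_to_scalar_value(two_hot)   # approximately the original rewards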
|
|
|
|
# action related
|
|
|
|
ActionEmbeds = namedtuple('ActionEmbeds', ('discrete', 'continuous'))
|
|
|
|
class ActionEmbedder(Module):
|
|
def __init__(
|
|
self,
|
|
dim,
|
|
*,
|
|
num_discrete_actions: int | tuple[int, ...] = 0,
|
|
num_continuous_actions = 0,
|
|
continuous_norm_stats: tuple[tuple[float, float], ...] | None = None,
|
|
can_unembed = False,
|
|
unembed_dim = None
|
|
):
|
|
super().__init__()
|
|
|
|
# handle discrete actions
|
|
|
|
num_discrete_actions = tensor(ensure_tuple(num_discrete_actions))
|
|
total_discrete_actions = num_discrete_actions.sum().item()
|
|
|
|
self.num_discrete_action_types = len(num_discrete_actions)
|
|
self.discrete_action_embed = Embedding(total_discrete_actions, dim)
|
|
|
|
self.register_buffer('num_discrete_actions', num_discrete_actions, persistent = False)
|
|
|
|
# continuous actions
|
|
|
|
self.num_continuous_action_types = num_continuous_actions
|
|
self.continuous_action_embed = Embedding(num_continuous_actions, dim)
|
|
|
|
self.continuous_need_norm = exists(continuous_norm_stats)
|
|
|
|
if self.continuous_need_norm:
|
|
self.register_buffer('continuous_norm_stats', tensor(continuous_norm_stats))
|
|
|
|
# defaults
|
|
|
|
self.register_buffer('default_discrete_action_types', arange(self.num_discrete_action_types), persistent = False)
|
|
self.register_buffer('default_continuous_action_types', arange(self.num_continuous_action_types), persistent = False)
|
|
|
|
# calculate offsets
|
|
|
|
offsets = F.pad(num_discrete_actions.cumsum(dim = -1), (1, -1), value = 0)
|
|
self.register_buffer('discrete_action_offsets', offsets, persistent = False)
|
|
|
|
# unembedding
|
|
|
|
self.can_unembed = can_unembed
|
|
|
|
if can_unembed:
|
|
unembed_dim = default(unembed_dim, dim)
|
|
self.discrete_action_unembed = Parameter(torch.randn(total_discrete_actions, unembed_dim) * 1e-2)
|
|
self.continuous_action_unembed = Parameter(torch.randn(num_continuous_actions, unembed_dim, 2) * 1e-2)
|
|
|
|
@property
|
|
def device(self):
|
|
return self.discrete_action_offsets.device
|
|
|
|
@property
|
|
def has_actions(self):
|
|
return self.num_discrete_action_types > 0 or self.num_continuous_action_types > 0
|
|
|
|
def unembed(
|
|
self,
|
|
embeds, # (... d)
|
|
discrete_action_types = None, # (na)
|
|
continuous_action_types = None, # (na)
|
|
|
|
): # (... discrete_na), (... continuous_na 2)
|
|
|
|
assert self.can_unembed, 'can only unembed for predicted discrete and continuous actions if `can_unembed = True` is set on init'
|
|
|
|
assert not exists(discrete_action_types), 'selecting subset of discrete action types to unembed not completed yet'
|
|
|
|
# discrete actions
|
|
|
|
discrete_action_logits = None
|
|
|
|
if self.num_discrete_action_types > 0:
|
|
discrete_action_logits = einsum(embeds, self.discrete_action_unembed, '... d, na d -> ... na')
|
|
|
|
# continuous actions
|
|
|
|
continuous_action_mean_log_var = None
|
|
|
|
if self.num_continuous_action_types > 0:
|
|
|
|
continuous_action_unembed = self.continuous_action_unembed
|
|
|
|
if exists(continuous_action_types):
|
|
continuous_action_unembed = continuous_action_unembed[continuous_action_types]
|
|
|
|
continuous_action_mean_log_var = einsum(embeds, continuous_action_unembed, '... d, na d two -> ... na two')
|
|
|
|
return discrete_action_logits, continuous_action_mean_log_var
|
|
|
|
def forward(
|
|
self,
|
|
*,
|
|
discrete_actions = None, # (... na)
|
|
continuous_actions = None, # (... na)
|
|
discrete_action_types = None, # (na)
|
|
continuous_action_types = None, # (na)
|
|
return_sum_pooled_embeds = True
|
|
):
|
|
|
|
discrete_embeds = continuous_embeds = None
|
|
|
|
if exists(discrete_actions):
|
|
|
|
discrete_action_types = default(discrete_action_types, self.default_discrete_action_types)
|
|
|
|
if exists(discrete_action_types) and not is_tensor(discrete_action_types):
|
|
if isinstance(discrete_action_types, int):
|
|
discrete_action_types = (discrete_action_types,)
|
|
|
|
discrete_action_types = tensor(discrete_action_types, device = self.device)
|
|
|
|
offsets = self.discrete_action_offsets[discrete_action_types]
|
|
|
|
assert offsets.shape[-1] == discrete_actions.shape[-1], 'mismatched number of discrete actions'
|
|
|
|
# offset the discrete actions based on the action types passed in (by default all discrete actions) and the calculated offset
|
|
|
|
discrete_actions_offsetted = einx.add('... na, na', discrete_actions, offsets)
|
|
discrete_embeds = self.discrete_action_embed(discrete_actions_offsetted)
|
|
|
|
if exists(continuous_actions):
|
|
continuous_action_types = default(continuous_action_types, self.default_continuous_action_types)
|
|
|
|
if exists(continuous_action_types) and not is_tensor(continuous_action_types):
|
|
if isinstance(continuous_action_types, int):
|
|
continuous_action_types = (continuous_action_types,)
|
|
|
|
continuous_action_types = tensor(continuous_action_types, device = self.device)
|
|
|
|
assert continuous_action_types.shape[-1] == continuous_actions.shape[-1], 'mismatched number of continuous actions'
|
|
|
|
continuous_action_embed = self.continuous_action_embed(continuous_action_types)
|
|
|
|
# maybe normalization
|
|
|
|
if self.continuous_need_norm:
|
|
norm_mean, norm_std = self.continuous_norm_stats.unbind(dim = -1)
|
|
continuous_actions = (continuous_actions - norm_mean) / norm_std.clamp(min = 1e-6)
|
|
|
|
# continuous embed is just the selected continuous action type with the scale
|
|
|
|
continuous_embeds = einx.multiply('na d, ... na -> ... na d', continuous_action_embed, continuous_actions)
|
|
|
|
# return not pooled
|
|
|
|
if not return_sum_pooled_embeds:
|
|
return ActionEmbeds(discrete_embeds, continuous_embeds)
|
|
|
|
# handle sum pooling, which is what they did in the paper for all the actions
|
|
|
|
pooled = 0.
|
|
|
|
if exists(discrete_embeds):
|
|
pooled = pooled + reduce(discrete_embeds, '... na d -> ... d', 'sum')
|
|
|
|
if exists(continuous_embeds):
|
|
pooled = pooled + reduce(continuous_embeds, '... na d -> ... d', 'sum')
|
|
|
|
return pooled
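
# usage sketch - two discrete action types and two continuous actions, sum pooled into one embedding per time step (all numbers hypothetical):
#
#   embedder = ActionEmbedder(dim = 512, num_discrete_actions = (4, 3), num_continuous_actions = 2)
#   discrete = randint(0, 3, (1, 8, 2))       # (b t na) - an index per discrete action type
#   continuous = randn(1, 8, 2)               # (b t na) - raw continuous action values
#   action_embed = embedder(discrete_actions = discrete, continuous_actions = continuous)   # (1, 8, 512)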
|
|
|
|
# generalized advantage estimate
|
|
|
|
@torch.no_grad()
|
|
def calc_gae(
|
|
rewards,
|
|
values,
|
|
masks,
|
|
gamma = 0.99,
|
|
lam = 0.95,
|
|
use_accelerated = None
|
|
):
|
|
assert values.shape[-1] == rewards.shape[-1]
|
|
use_accelerated = default(use_accelerated, rewards.is_cuda)
|
|
|
|
values = F.pad(values, (0, 1), value = 0.)
|
|
values, values_next = values[..., :-1], values[..., 1:]
|
|
|
|
delta = rewards + gamma * values_next * masks - values
|
|
gates = gamma * lam * masks
|
|
|
|
scan = AssocScan(reverse = True, use_accelerated = use_accelerated)
|
|
|
|
gae = scan(gates, delta)
|
|
|
|
returns = gae + values
|
|
|
|
return returns
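
# usage sketch - lambda-returns for a short rollout, assuming the `assoc_scan` package is installed (shapes hypothetical):
#
#   rewards = randn(1, 16)        # (b t)
#   values = randn(1, 16)         # critic value estimates per step
#   masks = ones(1, 16)           # 1. where the episode has not terminated
#   returns = calc_gae(rewards, values, masks)   # (1, 16) bootstrapped return targets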
|
|
|
|
# golden gate rotary - Jerry Xiong, PhD student at UIUC
|
|
# https://jerryxio.ng/posts/nd-rope/
|
|
|
|
def _phi(m):
|
|
x = 2.
|
|
for _ in range(10):
|
|
x = (1. + x) ** (1. / (m + 1.))
|
|
return x
|
|
|
|
def make_directions(n, d):
|
|
g = _phi(d)
|
|
alpha = (1.0 / g) ** arange(1, d + 1, dtype = torch.float64)
|
|
i = arange(1, n + 1, dtype = torch.float64).unsqueeze(1)
|
|
z = torch.fmod(i * alpha, 1.0)
|
|
directions = torch.erfinv(2.0 * z - 1.0)
|
|
directions = l2norm(directions)
|
|
return directions.float()
|
|
|
|
class GoldenGateRoPENd(Module):
|
|
def __init__(
|
|
self,
|
|
dim_pos,
|
|
heads,
|
|
dim_head,
|
|
rope_min_freq = 1.,
|
|
rope_max_freq = 10000.,
|
|
rope_p_zero_freqs = 0., # proportion of frequencies set to 0
|
|
):
|
|
super().__init__()
|
|
assert divisible_by(dim_head, 2)
|
|
|
|
n_freqs = dim_head // 2
|
|
n_zero_freqs = round(rope_p_zero_freqs * n_freqs)
|
|
|
|
omega = cat((
|
|
torch.zeros(n_zero_freqs),
|
|
rope_min_freq * (rope_max_freq / rope_min_freq) ** torch.linspace(0, 1, n_freqs - n_zero_freqs),
|
|
))
|
|
|
|
directions = make_directions(heads * n_freqs, dim_pos)
|
|
directions = rearrange(directions, '(h f) p -> h f p', h = heads)
|
|
|
|
omega_expanded = rearrange(omega, 'f -> f 1')
|
|
self.register_buffer('freqs', directions * omega_expanded) # shape: (h, f, p)
|
|
|
|
def forward(
|
|
self,
|
|
        pos # (n p)
|
|
):
|
|
|
|
freqs = rearrange(self.freqs, 'h f p -> h 1 f p')
|
|
positions = rearrange(pos.float(), 'n p -> 1 n 1 p')
|
|
|
|
# thetas for freqs and positions (batch, head, seq, freq)
|
|
|
|
theta = reduce(freqs * positions, 'h n f p -> h n f', 'sum')
|
|
|
|
return cat((theta, theta), dim = -1)
|
|
|
|
class Rotary1D(Module):
|
|
def __init__(
|
|
self,
|
|
dim_head,
|
|
theta = 10000.
|
|
):
|
|
super().__init__()
|
|
inv_freq = 1.0 / (theta ** (arange(0, dim_head, 2).float() / dim_head))
|
|
self.register_buffer('inv_freq', inv_freq)
|
|
|
|
def forward(
|
|
self,
|
|
seq_len
|
|
):
|
|
device, dtype = self.inv_freq.device, self.inv_freq.dtype
|
|
|
|
t = torch.arange(seq_len, device = device).type(dtype)
|
|
freqs = einsum(t, self.inv_freq, 'i, j -> i j')
|
|
|
|
return cat((freqs, freqs), dim = -1)
|
|
|
|
|
|
def apply_rotations(
|
|
rotations, # (h n d) | (n d)
|
|
t # (b h n d)
|
|
):
|
|
heads, dtype = t.shape[1], t.dtype
|
|
t = t.float()
|
|
|
|
# handle gqa for rotary
|
|
|
|
if rotations.ndim == 3 and rotations.shape[0] < heads:
|
|
rotary_heads = rotations.shape[0]
|
|
|
|
assert divisible_by(heads, rotary_heads)
|
|
groups = heads // rotary_heads
|
|
rotations = repeat(rotations, 'h ... -> (h g) ...', g = groups)
|
|
|
|
x1, x2 = t.chunk(2, dim = -1)
|
|
rotated_half_t = cat((-x2, x1), dim = -1)
|
|
|
|
# rotate in the positions
|
|
|
|
rotated = t * rotations.cos() + rotated_half_t * rotations.sin()
|
|
return rotated.type(dtype)
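
# usage sketch - 3d (spacetime) rotary embeddings applied to queries and keys (dimensions hypothetical):
#
#   rope = GoldenGateRoPENd(dim_pos = 3, heads = 8, dim_head = 64)
#   coords = stack(torch.meshgrid(arange(2), arange(4), arange(4), indexing = 'ij'), dim = -1).reshape(-1, 3)   # (32, 3) spacetime positions
#   rotations = rope(coords)             # (8, 32, 64) per-head rotations
#   q = randn(1, 8, 32, 64)
#   k = randn(1, 8, 32, 64)
#   q, k = apply_rotations(rotations, q), apply_rotations(rotations, k)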
|
|
|
|
# multi-head rmsnorm
|
|
|
|
class MultiHeadRMSNorm(Module):
|
|
def __init__(
|
|
self,
|
|
dim_head,
|
|
heads = 8
|
|
):
|
|
super().__init__()
|
|
self.scale = dim_head ** 0.5
|
|
self.gamma = Parameter(torch.zeros(heads, dim_head)) # weight decay friendly
|
|
|
|
def forward(
|
|
self,
|
|
x # (b h n d)
|
|
):
|
|
normed = l2norm(x)
|
|
scale = (self.gamma + 1.) * self.scale
|
|
return einx.multiply('... h n d, h d', normed, scale)
|
|
|
|
# naive attend
|
|
|
|
def naive_attend(
|
|
q, k, v,
|
|
softclamp_value = None,
|
|
scale = None,
|
|
causal = False,
|
|
causal_block_size = 1,
|
|
mask = None
|
|
):
|
|
|
|
if not exists(scale):
|
|
scale = q.shape[-1] ** -0.5
|
|
|
|
# grouped query attention
|
|
|
|
groups = q.shape[1] // k.shape[1]
|
|
|
|
q = rearrange(q, 'b (h g) ... -> b h g ...', g = groups)
|
|
|
|
# similarity
|
|
|
|
sim = einsum(q, k, 'b h g i d, b h j d -> b h g i j')
|
|
|
|
# scale and attention
|
|
|
|
sim = sim * scale
|
|
|
|
# softclamping a la gemma 3
|
|
|
|
if exists(softclamp_value):
|
|
sim = softclamp(sim, softclamp_value)
|
|
|
|
# masking
|
|
|
|
mask_value = -torch.finfo(sim.dtype).max
|
|
|
|
if exists(mask):
|
|
sim = sim.masked_fill(~mask, mask_value)
|
|
|
|
if causal:
|
|
is_blocked_causal = causal_block_size > 1
|
|
i, j = sim.shape[-2:]
|
|
|
|
if is_blocked_causal:
|
|
i = ceil(i / causal_block_size)
|
|
j = ceil(j / causal_block_size)
|
|
|
|
causal_mask = torch.ones((i, j), dtype = torch.bool, device = sim.device).triu(j - i + 1)
|
|
|
|
if causal_block_size > 1:
|
|
causal_mask = repeat(causal_mask, 'i j -> (i b1) (j b2)', b1 = causal_block_size, b2 = causal_block_size)
|
|
causal_mask = causal_mask[:sim.shape[-2], :sim.shape[-1]]
|
|
|
|
sim = sim.masked_fill(causal_mask, mask_value)
|
|
|
|
# attend
|
|
|
|
attn = sim.softmax(dim = -1)
|
|
|
|
# aggregate
|
|
|
|
out = einsum(attn, v, 'b h g i j, b h j d -> b h g i d')
|
|
|
|
# merge the groups
|
|
|
|
return rearrange(out, 'b h g i d -> b (h g) i d')
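
# usage sketch - grouped query attention with 8 query heads sharing 2 key / value heads (shapes hypothetical):
#
#   q = randn(1, 8, 16, 64)    # (b h n d)
#   k = randn(1, 2, 16, 64)
#   v = randn(1, 2, 16, 64)
#   out = naive_attend(q, k, v, causal = True)   # (1, 8, 16, 64)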
|
|
|
|
# flex attention related and factory function for attend depending on whether on cuda + flex attention available
|
|
|
|
def block_mask_causal(block_size):
|
|
|
|
def inner(b, h, q, k):
|
|
bq = q // block_size
|
|
bk = k // block_size
|
|
return bq >= bk
|
|
|
|
return inner
|
|
|
|
def special_token_mask(q, k, seq_len, num_tokens, special_attend_only_itself = False):
|
|
bq = q % seq_len
|
|
bk = k % seq_len
|
|
|
|
is_special_start_index = seq_len - num_tokens
|
|
|
|
    q_is_special = bq >= is_special_start_index
    k_is_special = bk >= is_special_start_index
|
|
|
|
if special_attend_only_itself:
|
|
out = ~(q_is_special & ~k_is_special) # modality attends to everything, but latent can only attend to itself (proposed attention pattern for encoder of video tokenizer)
|
|
else:
|
|
out = ~(~q_is_special & k_is_special) # modality cannot attend to agent tokens
|
|
|
|
return out
|
|
|
|
def block_mask_special_tokens_right(
    seq_len,
    num_tokens,
    special_attend_only_itself = False
):
    def inner(b, h, q, k):
        return special_token_mask(q, k, seq_len, num_tokens, special_attend_only_itself)

    return inner
|
|
|
|
def compose_mask(mask1, mask2):
|
|
def inner(b, h, q, k):
|
|
return mask1(b, h, q, k) & mask2(b, h, q, k)
|
|
|
|
return inner
|
|
|
|
def block_mask_noop(b, h, q, k):
|
|
return b >= 0
|
|
|
|
def score_mod_softclamp(value):
|
|
def inner(sim, b, h, q, k):
|
|
if not exists(value):
|
|
return sim
|
|
|
|
sim = sim / value
|
|
sim = torch.tanh(sim)
|
|
sim = sim * value
|
|
return sim
|
|
|
|
return inner
|
|
|
|
# factory for attend function
|
|
|
|
def get_attend_fn(
|
|
use_flex,
|
|
seq_len,
|
|
k_seq_len,
|
|
causal = False,
|
|
causal_block_size = 1,
|
|
softclamp_value = 50.,
|
|
num_special_tokens = 0, # special tokens are latents / agents
|
|
block_size_per_special = None, # defaults to k_seq_len
|
|
special_attend_only_itself = False, # by default, modality only attends to itself while special sees everything, but if turned True, will be the inverse - special can only attend to itself but modality can attend everything
|
|
device = None
|
|
):
|
|
block_size_per_special = default(block_size_per_special, k_seq_len)
|
|
|
|
if use_flex:
|
|
# flex pathway
|
|
|
|
block_mask_fn = block_mask_causal(causal_block_size) if causal else block_mask_noop
|
|
|
|
if num_special_tokens > 0:
|
|
special_block_mask = block_mask_special_tokens_right(block_size_per_special, num_special_tokens, special_attend_only_itself)
|
|
block_mask_fn = compose_mask(block_mask_fn, special_block_mask)
|
|
|
|
block_mask = create_block_mask(block_mask_fn, B = None, H = None, Q_LEN = seq_len, KV_LEN = k_seq_len)
|
|
|
|
score_mod = score_mod_softclamp(softclamp_value)
|
|
attend_fn = partial(flex_attention, block_mask = block_mask, score_mod = score_mod, enable_gqa = True)
|
|
else:
|
|
# naive pathway
|
|
|
|
mask = None
|
|
if num_special_tokens > 0:
|
|
q_seq = torch.arange(seq_len, device = device)[:, None]
|
|
k_seq = torch.arange(k_seq_len, device = device)[None, :]
|
|
|
|
mask = special_token_mask(q_seq, k_seq, block_size_per_special, num_special_tokens, special_attend_only_itself)
|
|
|
|
attend_fn = partial(naive_attend, causal = causal, causal_block_size = causal_block_size, mask = mask, softclamp_value = softclamp_value)
|
|
|
|
return attend_fn
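
# usage sketch - block causal attention over 4 blocks of 5 tokens, one special (latent / agent) token at the end of each block,
# shown on the naive pathway (pass `use_flex = True` on cuda when flex attention is available):
#
#   attend = get_attend_fn(
#       use_flex = False,
#       seq_len = 20, k_seq_len = 20,
#       causal = True, causal_block_size = 5,
#       num_special_tokens = 1, block_size_per_special = 5,
#       device = torch.device('cpu')
#   )
#
#   q = k = v = randn(1, 8, 20, 64)
#   out = attend(q, k, v)   # (1, 8, 20, 64)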
|
|
|
|
# attention
|
|
|
|
class Attention(Module):
|
|
def __init__(
|
|
self,
|
|
dim,
|
|
dim_head = 64,
|
|
query_heads = None,
|
|
heads = 8,
|
|
pre_rmsnorm = True,
|
|
):
|
|
super().__init__()
|
|
self.norm = RMSNorm(dim) if pre_rmsnorm else Identity()
|
|
|
|
# setup grouped query attention
|
|
|
|
query_heads = default(query_heads, heads)
|
|
assert query_heads >= heads and divisible_by(query_heads, heads)
|
|
|
|
# scaling, splitting and merging of heads
|
|
|
|
self.split_heads = Rearrange('b n (h d) -> b h n d', d = dim_head)
|
|
self.merge_heads = Rearrange('b h n d -> b n (h d)')
|
|
|
|
dim_q_inner = dim_head * query_heads
|
|
dim_kv_inner = dim_head * heads
|
|
|
|
self.to_q = LinearNoBias(dim, dim_q_inner)
|
|
self.to_kv = LinearNoBias(dim, dim_kv_inner * 2)
|
|
self.to_out = LinearNoBias(dim_q_inner, dim)
|
|
|
|
# stability related
|
|
|
|
self.q_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = query_heads)
|
|
self.k_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = heads)
|
|
|
|
def forward(
|
|
self,
|
|
tokens, # (b n d)
|
|
kv_cache = None,
|
|
return_kv_cache = False,
|
|
rotary_pos_emb = None,
|
|
attend_fn: Callable | None = None
|
|
):
|
|
tokens, inverse_packed_batch = pack_one(tokens, '* n d')
|
|
|
|
tokens = self.norm(tokens)
|
|
|
|
q, k, v = (self.to_q(tokens), *self.to_kv(tokens).chunk(2, dim = -1))
|
|
|
|
# split heads
|
|
|
|
q, k, v = map(self.split_heads, (q, k, v))
|
|
|
|
# qk rmsnorm
|
|
|
|
q = self.q_heads_rmsnorm(q)
|
|
k = self.k_heads_rmsnorm(k)
|
|
|
|
# caching
|
|
|
|
if exists(kv_cache):
|
|
ck, cv = kv_cache
|
|
k = cat((ck, k), dim = -2)
|
|
v = cat((cv, v), dim = -2)
|
|
|
|
# rotary
|
|
|
|
if exists(rotary_pos_emb):
|
|
q = apply_rotations(rotary_pos_emb, q)
|
|
k = apply_rotations(rotary_pos_emb, k)
|
|
|
|
# attention
|
|
|
|
attend_fn = default(attend_fn, naive_attend)
|
|
|
|
out = attend_fn(q, k, v)
|
|
|
|
# merge heads
|
|
|
|
out = self.merge_heads(out)
|
|
|
|
# combine heads
|
|
|
|
out = self.to_out(out)
|
|
|
|
out = inverse_packed_batch(out)
|
|
|
|
if not return_kv_cache:
|
|
return out
|
|
|
|
return out, stack((k, v))
|
|
|
|
# feedforward
|
|
|
|
class SwiGLUFeedforward(Module):
|
|
def __init__(
|
|
self,
|
|
dim,
|
|
expansion_factor = 4,
|
|
pre_rmsnorm = True
|
|
):
|
|
super().__init__()
|
|
self.norm = RMSNorm(dim) if pre_rmsnorm else Identity()
|
|
|
|
dim_inner = int(dim * expansion_factor * 2 / 3)
|
|
|
|
self.proj_in = Linear(dim, dim_inner * 2)
|
|
self.proj_out = Linear(dim_inner, dim)
|
|
|
|
def forward(self, x):
|
|
x = self.norm(x)
|
|
|
|
x, gates = self.proj_in(x).chunk(2, dim = -1)
|
|
x = x * F.gelu(gates)
|
|
|
|
return self.proj_out(x)
|
|
|
|
# video tokenizer
|
|
|
|
class VideoTokenizer(Module):
|
|
def __init__(
|
|
self,
|
|
dim,
|
|
dim_latent,
|
|
patch_size,
|
|
image_height = None,
|
|
image_width = None,
|
|
num_latent_tokens = 4,
|
|
encoder_depth = 4,
|
|
decoder_depth = 4,
|
|
attn_kwargs: dict = dict(),
|
|
attn_dim_head = 64,
|
|
attn_heads = 8,
|
|
attn_softclamp_value = 50.,
|
|
ff_kwargs: dict = dict(),
|
|
decoder_pos_mlp_depth = 2,
|
|
channels = 3,
|
|
        per_image_patch_mask_prob = (0., 0.9), # the patch masking probability appears to be drawn uniformly per image between 0. and 0.9 - if you are a phd student and think i'm mistaken, please open an issue
|
|
lpips_loss_network: Module | None = None,
|
|
lpips_loss_weight = 0.2,
|
|
nd_rotary_kwargs: dict = dict(
|
|
rope_min_freq = 1.,
|
|
rope_max_freq = 10000.,
|
|
rope_p_zero_freqs = 0.
|
|
)
|
|
):
|
|
super().__init__()
|
|
|
|
self.patch_size = patch_size
|
|
|
|
# special tokens
|
|
|
|
assert num_latent_tokens >= 1
|
|
self.num_latent_tokens = num_latent_tokens
|
|
self.latent_tokens = Parameter(randn(num_latent_tokens, dim) * 1e-2)
|
|
|
|
# mae masking - Kaiming He paper from long ago
|
|
|
|
self.per_image_patch_mask_prob = per_image_patch_mask_prob
|
|
self.mask_token = Parameter(randn(dim) * 1e-2)
|
|
|
|
# patch and unpatch
|
|
|
|
dim_patch = channels * patch_size ** 2
|
|
|
|
self.patch_to_tokens = Sequential(
|
|
Rearrange('b c t (h p1) (w p2) -> b t h w (p1 p2 c)', p1 = patch_size, p2 = patch_size),
|
|
Linear(dim_patch, dim)
|
|
)
|
|
|
|
self.tokens_to_patch = Sequential(
|
|
Linear(dim, dim_patch),
|
|
Rearrange('b t h w (p1 p2 c) -> b c t (h p1) (w p2)', p1 = patch_size, p2 = patch_size),
|
|
)
|
|
|
|
# 3d rotations
|
|
|
|
self.spacetime_rotary = GoldenGateRoPENd(
|
|
dim_pos = 3,
|
|
heads = attn_heads,
|
|
dim_head = attn_dim_head,
|
|
**nd_rotary_kwargs
|
|
)
|
|
|
|
# attention related
|
|
|
|
self.attn_softclamp_value = attn_softclamp_value
|
|
|
|
# encoder
|
|
|
|
encoder_layers = []
|
|
|
|
for _ in range(encoder_depth):
|
|
encoder_layers.append(ModuleList([
|
|
Attention(dim = dim, heads = attn_heads, dim_head = attn_dim_head, **attn_kwargs),
|
|
SwiGLUFeedforward(dim = dim, **ff_kwargs)
|
|
]))
|
|
|
|
self.encoder_layers = ModuleList(encoder_layers)
|
|
self.encoder_norm = RMSNorm(dim)
|
|
|
|
# latents
|
|
|
|
self.encoded_to_latents = Sequential(
|
|
LinearNoBias(dim, dim_latent),
|
|
nn.Tanh(),
|
|
)
|
|
|
|
self.latents_to_decoder = LinearNoBias(dim_latent, dim)
|
|
|
|
# decoder
|
|
|
|
self.image_height = image_height
|
|
self.image_width = image_width
|
|
|
|
# parameterize the decoder positional embeddings for MAE style training so it can be resolution agnostic
|
|
|
|
self.to_decoder_pos_emb = create_mlp(
|
|
dim_in = 2,
|
|
dim = dim * 2,
|
|
dim_out = dim,
|
|
depth = decoder_pos_mlp_depth,
|
|
)
|
|
|
|
decoder_layers = []
|
|
|
|
for _ in range(decoder_depth):
|
|
decoder_layers.append(ModuleList([
|
|
Attention(dim = dim, heads = attn_heads, dim_head = attn_dim_head, **attn_kwargs),
|
|
SwiGLUFeedforward(dim = dim, **ff_kwargs)
|
|
]))
|
|
|
|
self.decoder_layers = ModuleList(decoder_layers)
|
|
self.decoder_norm = RMSNorm(dim)
|
|
|
|
# loss related
|
|
|
|
self.register_buffer('zero', tensor(0.), persistent = False)
|
|
|
|
self.has_lpips_loss = lpips_loss_weight > 0.
|
|
self.lpips_loss_weight = lpips_loss_weight
|
|
|
|
if self.has_lpips_loss:
|
|
self.lpips = LPIPSLoss(lpips_loss_network)
|
|
|
|
@property
|
|
def device(self):
|
|
return self.zero.device
|
|
|
|
@torch.no_grad()
|
|
def tokenize(
|
|
self,
|
|
video
|
|
):
|
|
self.eval()
|
|
return self.forward(video, return_latents = True)
|
|
|
|
def get_rotary_pos_emb(
|
|
self,
|
|
time,
|
|
num_patch_height,
|
|
num_patch_width
|
|
):
|
|
device = self.device
|
|
|
|
positions = stack(torch.meshgrid(
|
|
arange(time, device = device),
|
|
arange(num_patch_height, device = device),
|
|
arange(num_patch_width, device = device)
|
|
), dim = -1)
|
|
|
|
positions = rearrange(positions, 't h w p -> t (h w) p')
|
|
|
|
# give the latents an out of bounds position and assume the network will figure it out
|
|
|
|
positions = pad_at_dim(positions, (0, self.num_latent_tokens), dim = -2, value = -1) # todo - make this value configurable, and ultimately craft own flash attention function where certain positions can be unrotated
|
|
|
|
positions = rearrange(positions, 't hw p -> (t hw) p')
|
|
|
|
return self.spacetime_rotary(positions)
|
|
|
|
def decode(
|
|
self,
|
|
latents, # (b t n d)
|
|
height = None,
|
|
width = None,
|
|
rotary_pos_emb = None
|
|
): # (b c t h w)
|
|
|
|
height = default(height, self.image_height)
|
|
width = default(width, self.image_width)
|
|
|
|
assert exists(height) and exists(width), f'image height and width need to be passed in when decoding latents'
|
|
|
|
batch, time, device = *latents.shape[:2], latents.device
|
|
|
|
use_flex = latents.is_cuda and exists(flex_attention)
|
|
|
|
num_patch_height = height // self.patch_size
|
|
num_patch_width = width // self.patch_size
|
|
|
|
if not exists(rotary_pos_emb):
|
|
rotary_pos_emb = self.get_rotary_pos_emb(time, num_patch_height, num_patch_width)
|
|
|
|
# latents to tokens
|
|
|
|
latent_tokens = self.latents_to_decoder(latents)
|
|
|
|
# generate decoder positional embedding and concat the latent token
|
|
|
|
spatial_pos_height = torch.linspace(-1., 1., num_patch_height, device = device)
|
|
spatial_pos_width = torch.linspace(-1., 1., num_patch_width, device = device)
|
|
|
|
space_height_width_coor = stack(torch.meshgrid(spatial_pos_height, spatial_pos_width, indexing = 'ij'), dim = -1)
|
|
|
|
decoder_pos_emb = self.to_decoder_pos_emb(space_height_width_coor)
|
|
decoder_pos_emb = repeat(decoder_pos_emb, '... -> b t ...', b = batch, t = time)
|
|
|
|
tokens, packed_latent_shape = pack((decoder_pos_emb, latent_tokens), 'b t * d')
|
|
|
|
# pack time
|
|
|
|
        frame_seq_len = tokens.shape[-2]

        tokens, inverse_pack_time = pack_one(tokens, 'b * d')

        seq_len = tokens.shape[-2]

        # decoder attend - block causal over frames, with the latents at the end of each frame only able to attend to themselves

        decoder_attend_fn = get_attend_fn(
            use_flex, seq_len, seq_len,
            causal = True,
            causal_block_size = frame_seq_len,
            softclamp_value = self.attn_softclamp_value,
            num_special_tokens = self.num_latent_tokens,
            block_size_per_special = frame_seq_len,
            special_attend_only_itself = True,
            device = device
        )
|
|
|
|
# decoder attention
|
|
|
|
for attn, ff in self.decoder_layers:
|
|
tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = decoder_attend_fn) + tokens
|
|
|
|
tokens = ff(tokens) + tokens
|
|
|
|
tokens = self.decoder_norm(tokens)
|
|
|
|
# unpack time
|
|
|
|
tokens = inverse_pack_time(tokens)
|
|
|
|
# unpack latents
|
|
|
|
tokens, latent_tokens = unpack(tokens, packed_latent_shape, 'b t * d')
|
|
|
|
# project back to patches
|
|
|
|
recon_video = self.tokens_to_patch(tokens)
|
|
|
|
return recon_video
|
|
|
|
def forward(
|
|
self,
|
|
video, # (b c t h w)
|
|
return_latents = False,
|
|
mask_patches = None,
|
|
return_all_losses = False
|
|
):
|
|
batch, _, time, height, width = video.shape
|
|
patch_size, device = self.patch_size, video.device
|
|
|
|
assert divisible_by(height, patch_size) and divisible_by(width, patch_size)
|
|
|
|
# to tokens
|
|
|
|
tokens = self.patch_to_tokens(video)
|
|
|
|
# get some dimensions
|
|
|
|
num_patch_height, num_patch_width, _ = tokens.shape[-3:]
|
|
|
|
# rotary positions
|
|
|
|
rotary_pos_emb = self.get_rotary_pos_emb(time, num_patch_height, num_patch_width)
|
|
|
|
# masking
|
|
|
|
mask_patches = default(mask_patches, self.training)
|
|
|
|
if mask_patches:
|
|
min_mask_prob, max_mask_prob = self.per_image_patch_mask_prob
|
|
|
|
mask_prob = torch.empty(tokens.shape[:2], device = tokens.device).uniform_(min_mask_prob, max_mask_prob) # (b t)
|
|
|
|
mask_prob = repeat(mask_prob, 'b t -> b t vh vw', vh = tokens.shape[2], vw = tokens.shape[3])
|
|
mask_patch = torch.bernoulli(mask_prob) == 1.
|
|
|
|
tokens = einx.where('..., d, ... d', mask_patch, self.mask_token, tokens)
|
|
|
|
# pack space
|
|
|
|
tokens, inverse_pack_space = pack_one(tokens, 'b t * d')
|
|
|
|
# add the latent
|
|
|
|
latents = repeat(self.latent_tokens, 'n d -> b t n d', b = tokens.shape[0], t = tokens.shape[1])
|
|
|
|
tokens, packed_latent_shape = pack((tokens, latents), 'b t * d')
|
|
|
|
space_seq_len = tokens.shape[-2]
|
|
|
|
# pack time
|
|
|
|
tokens, inverse_pack_time = pack_one(tokens, 'b * d')
|
|
|
|
seq_len = tokens.shape[1]
|
|
|
|
# attend hyper parameters
|
|
|
|
        attend_kwargs = dict(
            causal = True,
            causal_block_size = space_seq_len,
            softclamp_value = self.attn_softclamp_value,
            block_size_per_special = space_seq_len,
            num_special_tokens = self.num_latent_tokens,
            device = tokens.device
        )
|
|
|
|
use_flex = tokens.is_cuda and exists(flex_attention)
|
|
|
|
# encoder attend
|
|
|
|
# modality can only attend to itself while latents can attend to everything
|
|
# similar to agent token in dynamics model
|
|
|
|
        encoder_attend_fn = get_attend_fn(use_flex, seq_len, seq_len, special_attend_only_itself = False, **attend_kwargs)
|
|
|
|
# encoder
|
|
|
|
for attn, ff in self.encoder_layers:
|
|
tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = encoder_attend_fn) + tokens
|
|
tokens = ff(tokens) + tokens
|
|
|
|
tokens = self.encoder_norm(tokens)
|
|
|
|
# latent bottleneck
|
|
|
|
tokens = inverse_pack_time(tokens)
|
|
|
|
tokens, latents = unpack(tokens, packed_latent_shape, 'b t * d')
|
|
|
|
latents = self.encoded_to_latents(latents)
|
|
|
|
if return_latents:
|
|
return latents
|
|
|
|
recon_video = self.decode(latents, height = height, width = width, rotary_pos_emb = rotary_pos_emb)
|
|
|
|
# losses
|
|
|
|
recon_loss = F.mse_loss(video, recon_video)
|
|
|
|
lpips_loss = self.zero
|
|
|
|
if self.has_lpips_loss:
|
|
lpips_loss = self.lpips(video, recon_video)
|
|
|
|
# losses
|
|
|
|
total_loss = (
|
|
recon_loss +
|
|
lpips_loss * self.lpips_loss_weight
|
|
)
|
|
|
|
if not return_all_losses:
|
|
return total_loss
|
|
|
|
losses = (recon_loss, lpips_loss)
|
|
|
|
        return total_loss, TokenizerLosses(*losses)
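
# usage sketch - a training step, tokenizing, and decoding (hyperparameters are hypothetical; lpips disabled here to avoid the VGG download):
#
#   tokenizer = VideoTokenizer(dim = 512, dim_latent = 32, patch_size = 16, num_latent_tokens = 4, lpips_loss_weight = 0.)
#   video = randn(1, 3, 4, 64, 64)           # (b c t h w)
#   loss = tokenizer(video)                  # masked reconstruction loss
#   loss.backward()
#
#   latents = tokenizer.tokenize(video)      # (1, 4, 4, 32) latents for the dynamics model
#   recon = tokenizer.decode(latents, height = 64, width = 64)   # (1, 3, 4, 64, 64)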
|
|
|
|
# dynamics model, axial space-time transformer
|
|
|
|
class DynamicsWorldModel(Module):
|
|
def __init__(
|
|
self,
|
|
dim,
|
|
dim_latent,
|
|
video_tokenizer: VideoTokenizer | None = None,
|
|
max_steps = 64, # K_max in paper
|
|
num_register_tokens = 8, # they claim register tokens led to better temporal consistency
|
|
num_spatial_tokens = 2, # latents projected to greater number of spatial tokens
|
|
num_latent_tokens = None,
|
|
num_agents = 1,
|
|
num_tasks = 0,
|
|
reward_encoder_kwargs: dict = dict(),
|
|
depth = 4,
|
|
        pred_orig_latent = True, # directly predicting the original x0 data yields better results, rather than velocity (x-space vs v-space)
|
|
time_block_every = 4, # every 4th block is time
|
|
attn_kwargs: dict = dict(
|
|
heads = 8,
|
|
),
|
|
attn_dim_head = 64,
|
|
attn_softclamp_value = 50.,
|
|
ff_kwargs: dict = dict(),
|
|
loss_weight_fn: Callable = ramp_weight,
|
|
num_future_predictions = 8, # they do multi-token prediction of 8 steps forward
|
|
prob_no_shortcut_train = None, # probability of no shortcut training, defaults to 1 / num_step_sizes
|
|
add_reward_embed_to_agent_token = False,
|
|
add_reward_embed_dropout = 0.1,
|
|
num_discrete_actions: int | tuple[int, ...] = 0,
|
|
num_continuous_actions = 0,
|
|
continuous_norm_stats = None,
|
|
reward_loss_weight = 0.1,
|
|
value_head_mlp_depth = 3,
|
|
policy_head_mlp_depth = 3,
|
|
num_latent_genes = 0 # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037
|
|
):
|
|
super().__init__()
|
|
|
|
# can accept raw video if tokenizer is passed in
|
|
|
|
self.video_tokenizer = video_tokenizer
|
|
|
|
if exists(video_tokenizer):
|
|
num_latent_tokens = default(num_latent_tokens, video_tokenizer.num_latent_tokens)
|
|
assert video_tokenizer.num_latent_tokens == num_latent_tokens, f'`num_latent_tokens` must be the same for the tokenizer and dynamics model'
|
|
|
|
assert exists(num_latent_tokens), '`num_latent_tokens` must be set'
|
|
|
|
# spatial
|
|
|
|
self.num_latent_tokens = num_latent_tokens
|
|
self.dim_latent = dim_latent
|
|
self.latent_shape = (num_latent_tokens, dim_latent)
|
|
|
|
if num_spatial_tokens >= num_latent_tokens:
|
|
assert divisible_by(num_spatial_tokens, num_latent_tokens)
|
|
|
|
expand_factor = num_spatial_tokens // num_latent_tokens
|
|
|
|
self.latents_to_spatial_tokens = Sequential(
|
|
Linear(dim_latent, dim * expand_factor),
|
|
Rearrange('... (s d) -> ... s d', s = expand_factor)
|
|
)
|
|
|
|
self.to_latent_pred = Sequential(
|
|
Reduce('b t n s d -> b t n d', 'mean'),
|
|
RMSNorm(dim),
|
|
LinearNoBias(dim, dim_latent)
|
|
)
|
|
|
|
else:
|
|
assert divisible_by(num_latent_tokens, num_spatial_tokens)
|
|
latent_tokens_to_space = num_latent_tokens // num_spatial_tokens
|
|
|
|
self.latents_to_spatial_tokens = Sequential(
|
|
Rearrange('b t n d -> b t (n d)'),
|
|
Linear(num_latent_tokens * dim_latent, dim * num_spatial_tokens),
|
|
Rearrange('b t (s d) -> b t s d', s = num_spatial_tokens)
|
|
)
|
|
|
|
self.to_latent_pred = Sequential(
|
|
RMSNorm(dim),
|
|
LinearNoBias(dim, dim_latent * latent_tokens_to_space),
|
|
Rearrange('b t s (n d) -> b t (s n) d', n = latent_tokens_to_space)
|
|
)
|
|
|
|
# register tokens
|
|
|
|
self.num_register_tokens = num_register_tokens
|
|
self.register_tokens = Parameter(torch.randn(num_register_tokens, dim) * 1e-2)
|
|
|
|
# signal and step sizes
|
|
|
|
assert divisible_by(dim, 2)
|
|
dim_half = dim // 2
|
|
|
|
assert is_power_two(max_steps), '`max_steps` must be a power of 2'
|
|
self.max_steps = max_steps
|
|
self.num_step_sizes_log2 = int(log2(max_steps))
|
|
|
|
self.signal_levels_embed = nn.Embedding(max_steps, dim_half)
|
|
self.step_size_embed = nn.Embedding(self.num_step_sizes_log2, dim_half) # power of 2, so 1/1, 1/2, 1/4, 1/8 ... 1/Kmax
|
|
|
|
self.prob_no_shortcut_train = default(prob_no_shortcut_train, self.num_step_sizes_log2 ** -1.)
|
|
|
|
# loss related
|
|
|
|
self.pred_orig_latent = pred_orig_latent # x-space or v-space
|
|
self.loss_weight_fn = loss_weight_fn
|
|
|
|
# reinforcement related
|
|
|
|
# they sum all the actions into a single token
|
|
|
|
self.num_agents = num_agents
|
|
self.action_learned_embed = Parameter(randn(self.num_agents, dim) * 1e-2)
|
|
|
|
self.num_tasks = num_tasks
|
|
self.task_embed = nn.Embedding(num_tasks, dim)
|
|
|
|
# learned set of latent genes
|
|
|
|
self.agent_has_genes = num_latent_genes > 0
|
|
self.latent_genes = Parameter(randn(num_latent_genes, dim) * 1e-2)
|
|
|
|
# action embedder
|
|
|
|
self.action_embedder = ActionEmbedder(
|
|
dim = dim,
|
|
num_discrete_actions = num_discrete_actions,
|
|
num_continuous_actions = num_continuous_actions,
|
|
continuous_norm_stats = continuous_norm_stats
|
|
)
|
|
|
|
# each agent token will have the reward embedding of the previous time step - but could eventually just give reward its own token
|
|
|
|
self.add_reward_embed_to_agent_token = add_reward_embed_to_agent_token
|
|
self.add_reward_embed_dropout = add_reward_embed_dropout
|
|
|
|
self.reward_encoder = SymExpTwoHot(
|
|
**reward_encoder_kwargs,
|
|
dim_embed = dim,
|
|
learned_embedding = add_reward_embed_to_agent_token
|
|
)
|
|
|
|
self.to_reward_pred = Sequential(
|
|
RMSNorm(dim),
|
|
LinearNoBias(dim, self.reward_encoder.num_bins)
|
|
)
|
|
|
|
self.reward_loss_weight = reward_loss_weight
|
|
|
|
# policy head
|
|
|
|
self.policy_head = create_mlp(
|
|
dim_in = dim,
|
|
dim = dim * 4,
|
|
dim_out = dim,
|
|
depth = policy_head_mlp_depth
|
|
)
|
|
|
|
# value head
|
|
|
|
self.value_head = create_mlp(
|
|
dim_in = dim,
|
|
dim = dim * 4,
|
|
dim_out = self.reward_encoder.num_bins,
|
|
depth = value_head_mlp_depth,
|
|
)
|
|
|
|
# attention
|
|
|
|
self.attn_softclamp_value = attn_softclamp_value
|
|
|
|
# time rotary embedding
|
|
|
|
self.time_rotary = Rotary1D(attn_dim_head)
|
|
|
|
# transformer
|
|
|
|
layers = []
|
|
is_time = []
|
|
|
|
for i in range(depth):
|
|
layer_index = i + 1
|
|
|
|
is_time_block = divisible_by(layer_index, time_block_every)
|
|
is_time.append(is_time_block)
|
|
|
|
rearrange_to_attend = Rearrange('b t s d -> b s t d') if is_time_block else Identity()
|
|
rearrange_from_attend = Rearrange('b s t d -> b t s d') if is_time_block else Identity()
|
|
|
|
layers.append(ModuleList([
|
|
rearrange_to_attend,
|
|
rearrange_from_attend,
|
|
Attention(dim = dim, dim_head = attn_dim_head, **attn_kwargs),
|
|
SwiGLUFeedforward(dim = dim, **ff_kwargs)
|
|
]))
|
|
|
|
self.layers = ModuleList(layers)
|
|
self.is_time = is_time
|
|
|
|
# zero
|
|
|
|
self.register_buffer('zero', tensor(0.), persistent = False)
|
|
|
|
@property
|
|
def device(self):
|
|
return self.zero.device
|
|
|
|
def get_times_from_signal_level(
|
|
self,
|
|
signal_levels,
|
|
align_dims_left_to = None
|
|
):
|
|
times = signal_levels.float() / self.max_steps
|
|
|
|
if not exists(align_dims_left_to):
|
|
return times
|
|
|
|
return align_dims_left(times, align_dims_left_to)
|
|
|
|
    def parameters(self):
|
|
params = super().parameters()
|
|
|
|
if not exists(self.video_tokenizer):
|
|
return params
|
|
|
|
return list(set(params) - set(self.video_tokenizer.parameters()))
|
|
|
|
@torch.no_grad()
|
|
def generate(
|
|
self,
|
|
time_steps,
|
|
num_steps = 4,
|
|
batch_size = 1,
|
|
agent_index = 0,
|
|
image_height = None,
|
|
image_width = None,
|
|
return_decoded_video = None,
|
|
context_signal_noise = 0.1, # they do a noising of the past, this was from an old diffusion world modeling paper from EPFL iirc
|
|
return_rewards_per_frame = False
|
|
|
|
): # (b t n d) | (b c t h w)
|
|
|
|
was_training = self.training
|
|
self.eval()
|
|
|
|
assert log2(num_steps).is_integer(), f'number of steps {num_steps} must be a power of 2'
|
|
assert 0 < num_steps <= self.max_steps, f'number of steps {num_steps} must be between 0 and {self.max_steps}'
|
|
|
|
latent_shape = self.latent_shape
|
|
|
|
# derive step size
|
|
|
|
step_size = self.max_steps // num_steps
|
|
|
|
# denoising
|
|
# teacher forcing to start with
|
|
|
|
latents = empty((batch_size, 0, *latent_shape), device = self.device)
|
|
|
|
past_context_noise = latents.clone()
|
|
|
|
# maybe return rewards
|
|
|
|
if return_rewards_per_frame:
|
|
decoded_rewards = empty((batch_size, 0), device = self.device, dtype = torch.float32)
|
|
|
|
# while all the frames of the video (per latent) is not generated
|
|
|
|
while latents.shape[1] < time_steps:
|
|
|
|
curr_time_steps = latents.shape[1]
|
|
noised_latent = randn((batch_size, 1, *latent_shape), device = self.device)
|
|
|
|
for step in range(num_steps):
|
|
signal_levels = full((batch_size, 1), step * step_size, dtype = torch.long, device = self.device)
|
|
|
|
noised_context = latents.lerp(past_context_noise, context_signal_noise) # the paragraph after eq (8)
|
|
|
|
noised_latent_with_context, pack_context_shape = pack((noised_context, noised_latent), 'b * n d')
|
|
|
|
signal_levels_with_context = F.pad(signal_levels, (curr_time_steps, 0), value = self.max_steps - 1)
|
|
|
|
pred, agent_embed = self.forward(
|
|
latents = noised_latent_with_context,
|
|
signal_levels = signal_levels_with_context,
|
|
step_sizes = step_size,
|
|
                    rewards = decoded_rewards if return_rewards_per_frame else None,
|
|
latent_is_noised = True,
|
|
return_pred_only = True,
|
|
return_agent_tokens = True
|
|
)
|
|
|
|
_, pred = unpack(pred, pack_context_shape, 'b * n d')
|
|
|
|
# derive flow, based on whether in x-space or not
|
|
|
|
if self.pred_orig_latent:
|
|
times = self.get_times_from_signal_level(signal_levels, noised_latent)
|
|
flow = (pred - noised_latent) / (1. - times)
|
|
else:
|
|
flow = pred
|
|
|
|
# denoise
|
|
|
|
noised_latent += flow * (step_size / self.max_steps)
|
|
|
|
denoised_latent = noised_latent # it is now denoised
|
|
|
|
# take care of the rewards by predicting on the agent token embedding on the last denoising step
|
|
|
|
if return_rewards_per_frame:
|
|
one_agent_embed = agent_embed[:, -1:, agent_index]
|
|
|
|
reward_logits = self.to_reward_pred(one_agent_embed)
|
|
pred_reward = self.reward_encoder.bins_to_scalar_value(reward_logits, normalize = True)
|
|
|
|
decoded_rewards = cat((decoded_rewards, pred_reward), dim = 1)
|
|
|
|
# concat the denoised latent
|
|
|
|
latents = cat((latents, denoised_latent), dim = 1)
|
|
|
|
# add new fixed context noise for the temporal consistency
|
|
|
|
past_context_noise = cat((past_context_noise, randn_like(denoised_latent)), dim = 1)
|
|
|
|
# restore state
|
|
|
|
self.train(was_training)
|
|
|
|
# returning video
|
|
|
|
has_tokenizer = exists(self.video_tokenizer)
|
|
return_decoded_video = default(return_decoded_video, has_tokenizer)
|
|
|
|
        if not return_decoded_video:
            if not return_rewards_per_frame:
                return latents

            return latents, decoded_rewards
|
|
|
|
generated_video = self.video_tokenizer.decode(
|
|
latents,
|
|
height = image_height,
|
|
width = image_width
|
|
)
|
|
|
|
if not return_rewards_per_frame:
|
|
return generated_video
|
|
|
|
return generated_video, decoded_rewards
|
|
|
|
def forward(
|
|
self,
|
|
*,
|
|
video = None, # (b c t vh vw)
|
|
latents = None, # (b t n d) | (b t d)
|
|
signal_levels = None, # () | (b) | (b t)
|
|
step_sizes = None, # () | (b)
|
|
step_sizes_log2 = None, # () | (b)
|
|
latent_gene_ids = None, # (b)
|
|
tasks = None, # (b)
|
|
rewards = None, # (b t)
|
|
discrete_actions = None, # (b t na)
|
|
continuous_actions = None, # (b t na)
|
|
discrete_action_types = None, # (na)
|
|
continuous_action_types = None, # (na)
|
|
return_pred_only = False,
|
|
latent_is_noised = False,
|
|
return_all_losses = False,
|
|
return_agent_tokens = False
|
|
):
|
|
# handle video or latents
|
|
|
|
assert exists(video) ^ exists(latents)
|
|
|
|
if exists(video):
|
|
assert exists(self.video_tokenizer), 'video_tokenizer must be passed in if training from raw video on dynamics model'
|
|
|
|
latents = self.video_tokenizer.tokenize(video)
|
|
|
|
if latents.ndim == 3:
|
|
latents = rearrange(latents, 'b t d -> b t 1 d') # 1 latent edge case
|
|
|
|
assert latents.shape[-2:] == self.latent_shape
|
|
|
|
# variables
|
|
|
|
batch, time, device = *latents.shape[:2], latents.device
|
|
|
|
# signal and step size related input conforming
|
|
|
|
if exists(signal_levels):
|
|
if isinstance(signal_levels, int):
|
|
signal_levels = tensor(signal_levels, device = self.device)
|
|
|
|
if signal_levels.ndim == 0:
|
|
signal_levels = repeat(signal_levels, '-> b', b = batch)
|
|
|
|
if signal_levels.ndim == 1:
|
|
signal_levels = repeat(signal_levels, 'b -> b t', t = time)
|
|
|
|
if exists(step_sizes):
|
|
if isinstance(step_sizes, int):
|
|
step_sizes = tensor(step_sizes, device = self.device)
|
|
|
|
if step_sizes.ndim == 0:
|
|
step_sizes = repeat(step_sizes, '-> b', b = batch)
|
|
|
|
if exists(step_sizes_log2):
|
|
if isinstance(step_sizes_log2, int):
|
|
step_sizes_log2 = tensor(step_sizes_log2, device = self.device)
|
|
|
|
if step_sizes_log2.ndim == 0:
|
|
step_sizes_log2 = repeat(step_sizes_log2, '-> b', b = batch)
|
|
|
|
# handle step sizes -> step size log2
|
|
|
|
assert not (exists(step_sizes) and exists(step_sizes_log2))
|
|
|
|
if exists(step_sizes):
|
|
step_sizes_log2_maybe_float = torch.log2(step_sizes)
|
|
step_sizes_log2 = step_sizes_log2_maybe_float.long()
|
|
|
|
assert (step_sizes_log2 == step_sizes_log2_maybe_float).all(), f'`step_sizes` must be powers of 2'
|
|
|
|
# flow related
|
|
|
|
assert not (exists(signal_levels) ^ exists(step_sizes_log2))
|
|
|
|
is_inference = exists(signal_levels)
|
|
no_shortcut_train = not is_inference
|
|
|
|
return_pred_only = return_pred_only or latent_is_noised
|
|
|
|
# if neither signal levels or step sizes passed in, assume training
|
|
# generate them randomly for training
|
|
|
|
if not is_inference:
|
|
|
|
no_shortcut_train = sample_prob(self.prob_no_shortcut_train)
|
|
|
|
if no_shortcut_train:
|
|
# if no shortcut training, step sizes are just 1 and noising is all steps, where each step is 1 / d_min
|
|
# in original shortcut paper, they actually set d = 0 for some reason, look into that later, as there is no mention in the dreamer paper of doing this
|
|
|
|
step_sizes_log2 = zeros((batch,), device = device).long() # zero because zero is equivalent to step size of 1
|
|
signal_levels = randint(0, self.max_steps, (batch, time), device = device)
|
|
else:
|
|
|
|
# now we follow eq (4)
|
|
|
|
step_sizes_log2 = randint(1, self.num_step_sizes_log2, (batch,), device = device)
|
|
num_step_sizes = 2 ** step_sizes_log2
|
|
|
|
                signal_levels = randint(0, self.max_steps, (batch, time), device = device) // num_step_sizes[:, None] * num_step_sizes[:, None] # times are discretized to step sizes
|
|
|
|
# times is from 0 to 1
|
|
|
|
times = self.get_times_from_signal_level(signal_levels, latents)
|
|
|
|
if not latent_is_noised:
|
|
# get the noise
|
|
|
|
noise = randn_like(latents)
|
|
|
|
# noise from 0 as noise to 1 as data
|
|
|
|
noised_latents = noise.lerp(latents, times)
|
|
|
|
else:
|
|
noised_latents = latents
|
|
|
|
# reinforcement learning related
|
|
|
|
agent_tokens = repeat(self.action_learned_embed, '... d -> b ... d', b = batch)
|
|
|
|
if exists(tasks):
|
|
assert self.num_tasks > 0
|
|
|
|
task_embeds = self.task_embed(tasks)
|
|
agent_tokens = einx.add('b ... d, b d', agent_tokens, task_embeds)
|
|
|
|
# maybe evolution
|
|
|
|
if exists(latent_gene_ids):
|
|
            assert self.agent_has_genes
|
|
latent_genes = self.latent_genes[latent_gene_ids]
|
|
|
|
agent_tokens = einx.add('b ... d, b d', agent_tokens, latent_genes)
|
|
|
|
# handle agent tokens w/ actions and task embeds
|
|
|
|
agent_tokens = repeat(agent_tokens, 'b ... d -> b t ... d', t = time)
|
|
|
|
# maybe add the action embed to the agent tokens per time step
|
|
|
|
if exists(discrete_actions) or exists(continuous_actions):
|
|
assert self.action_embedder.has_actions
|
|
|
|
action_embed = self.action_embedder(
|
|
discrete_actions = discrete_actions,
|
|
discrete_action_types = discrete_action_types,
|
|
continuous_actions = continuous_actions,
|
|
continuous_action_types = continuous_action_types
|
|
)
|
|
|
|
agent_tokens = einx.add('b t ... d, b t d', agent_tokens, action_embed)
|
|
|
|
# maybe add a reward embedding to agent tokens
|
|
|
|
if exists(rewards):
|
|
two_hot_encoding = self.reward_encoder(rewards)
|
|
|
|
if (
|
|
self.add_reward_embed_to_agent_token and
|
|
(not self.training or not sample_prob(self.add_reward_embed_dropout)) # a bit of noise goes a long way
|
|
):
|
|
reward_embeds = self.reward_encoder.embed(two_hot_encoding)
|
|
|
|
pop_last_reward = int(reward_embeds.shape[1] == agent_tokens.shape[1]) # the last reward is popped off during training, during inference, it is not known yet, so need to handle this edge case
|
|
|
|
reward_embeds = pad_at_dim(reward_embeds, (1, -pop_last_reward), dim = -2, value = 0.) # shift as each agent token predicts the next reward
|
|
|
|
agent_tokens = einx.add('b t ... d, b t d', agent_tokens, reward_embeds)
|
|
|
|
# main function, needs to be defined as such for shortcut training - additional calls for consistency loss
|
|
|
|
def get_prediction(noised_latents, signal_levels, step_sizes_log2, agent_tokens, return_agent_tokens = False):
|
|
# latents to spatial tokens
|
|
|
|
space_tokens = self.latents_to_spatial_tokens(noised_latents)
|
|
|
|
space_tokens, inverse_pack_space_per_latent = pack_one(space_tokens, 'b t * d')
|
|
|
|
num_spatial_tokens = space_tokens.shape[-2]
|
|
|
|
# pack to tokens
|
|
# [signal + step size embed] [latent space tokens] [register] [actions / agent]
|
|
|
|
registers = repeat(self.register_tokens, 's d -> b t s d', b = batch, t = time)
|
|
|
|
# determine signal + step size embed for their diffusion forcing + shortcut
|
|
|
|
signal_embed = self.signal_levels_embed(signal_levels)
|
|
|
|
step_size_embed = self.step_size_embed(step_sizes_log2)
|
|
step_size_embed = repeat(step_size_embed, 'b ... -> b t ...', t = time)
|
|
|
|
flow_token = cat((signal_embed, step_size_embed), dim = -1)
|
|
flow_token = rearrange(flow_token, 'b t d -> b t d')
|
|
|
|
# pack to tokens for attending
|
|
|
|
tokens, packed_tokens_shape = pack([flow_token, space_tokens, registers, agent_tokens], 'b t * d')
|
|
|
|
# attend functions for space and time
|
|
|
|
seq_len = tokens.shape[1]
|
|
|
|
use_flex = exists(flex_attention) and tokens.is_cuda
|
|
|
|
attend_kwargs = dict(use_flex = use_flex, softclamp_value = self.attn_softclamp_value, device = device)
|
|
|
|
space_seq_len = (
|
|
+ 1 # signal + step
|
|
+ self.num_agents # action / agent tokens
|
|
+ self.num_register_tokens
|
|
+ num_spatial_tokens
|
|
)
|
|
|
|
space_attend = get_attend_fn(causal = False, seq_len = space_seq_len, k_seq_len = space_seq_len, num_special_tokens = self.num_agents, **attend_kwargs) # space has an agent token on the right-hand side for reinforcement learning - cannot be attended to by modality
|
|
|
|
time_attend = get_attend_fn(causal = True, seq_len = time, k_seq_len = time, **attend_kwargs)
|
|
|
|
# rotary
|
|
|
|
rotary_pos_emb = self.time_rotary(time)
|
|
|
|
# attention
|
|
|
|
for (pre_attn_rearrange, post_attn_rearrange, attn, ff), layer_is_time in zip(self.layers, self.is_time):
|
|
|
|
tokens = pre_attn_rearrange(tokens)
|
|
|
|
                # when it is an axial time attention block, attention should be causal
|
|
|
|
attend_fn = time_attend if layer_is_time else space_attend
|
|
|
|
layer_rotary_pos_emb = rotary_pos_emb if layer_is_time else None
|
|
|
|
# attention layer
|
|
|
|
tokens = attn(tokens, rotary_pos_emb = layer_rotary_pos_emb, attend_fn = attend_fn) + tokens
|
|
|
|
tokens = post_attn_rearrange(tokens)
|
|
|
|
# feedforward layer
|
|
|
|
tokens = ff(tokens) + tokens
|
|
|
|
# unpack
|
|
|
|
flow_token, space_tokens, register_tokens, agent_tokens = unpack(tokens, packed_tokens_shape, 'b t * d')
|
|
|
|
# pooling
|
|
|
|
space_tokens = inverse_pack_space_per_latent(space_tokens)
|
|
|
|
pred = self.to_latent_pred(space_tokens)
|
|
|
|
if not return_agent_tokens:
|
|
return pred
|
|
|
|
return pred, agent_tokens
|
|
|
|
# forward the network
|
|
|
|
pred, encoded_agent_tokens = get_prediction(noised_latents, signal_levels, step_sizes_log2, agent_tokens, return_agent_tokens = True)
|
|
|
|
if return_pred_only:
|
|
if not return_agent_tokens:
|
|
return pred
|
|
|
|
return pred, encoded_agent_tokens
|
|
|
|
# determine the target for the loss
|
|
|
|
pred_target = None
|
|
|
|
is_x_space = self.pred_orig_latent
|
|
is_v_space_pred = not self.pred_orig_latent
|
|
|
|
maybe_shortcut_loss_weight = 1.
|
|
|
|
if no_shortcut_train:
|
|
|
|
# allow for original velocity pred
|
|
# x-space as in paper is in else clause
|
|
|
|
if is_v_space_pred:
|
|
pred_target = flow = latents - noise
|
|
else:
|
|
pred_target = latents
|
|
else:
|
|
# shortcut training - Frans et al. https://arxiv.org/abs/2410.12557
|
|
|
|
# basically a consistency loss where you ensure quantity of two half steps equals one step
|
|
# dreamer then makes it works for x-space with some math
|
|
|
|
get_prediction_no_grad = torch.no_grad()(get_prediction)
|
|
|
|
step_sizes_log2_minus_one = step_sizes_log2 - 1 # which equals d / 2
|
|
half_step_size = 2 ** step_sizes_log2_minus_one
|
|
|
|
first_step_pred = get_prediction_no_grad(noised_latents, signal_levels, step_sizes_log2_minus_one, agent_tokens)
|
|
|
|
# first derive b'
|
|
|
|
if is_v_space_pred:
|
|
first_step_pred_flow = first_step_pred
|
|
else:
|
|
first_times = self.get_times_from_signal_level(signal_levels, noised_latents)
|
|
first_step_pred_flow = (first_step_pred - noised_latents) / (1. - first_times)
|
|
|
|
# take a half step
|
|
|
|
half_step_size_align_left = align_dims_left(half_step_size, noised_latents)
|
|
|
|
denoised_latent = noised_latents + first_step_pred_flow * (half_step_size_align_left / self.max_steps)
|
|
|
|
# get second prediction for b''
|
|
|
|
signal_levels_plus_half_step = signal_levels + half_step_size[:, None]
|
|
second_step_pred = get_prediction_no_grad(denoised_latent, signal_levels_plus_half_step, step_sizes_log2_minus_one, agent_tokens)
|
|
|
|
if is_v_space_pred:
|
|
second_step_pred_flow = second_step_pred
|
|
else:
|
|
second_times = self.get_times_from_signal_level(signal_levels_plus_half_step, denoised_latent)
|
|
second_step_pred_flow = (second_step_pred - denoised_latent) / (1. - second_times)
|
|
|
|
# pred target is sg(b' + b'') / 2
|
|
|
|
pred_target = (first_step_pred_flow + second_step_pred_flow).detach() / 2
|
|
|
|
# need to convert x-space to v-space
|
|
|
|
if is_x_space:
|
|
pred = (pred - noised_latents) / (1. - first_times)
|
|
maybe_shortcut_loss_weight = (1. - first_times) ** 2
|
|
|
|
# mse loss
|
|
|
|
flow_losses = F.mse_loss(pred, pred_target, reduction = 'none')
|
|
|
|
flow_losses = flow_losses * maybe_shortcut_loss_weight # handle the (1-t)^2 in eq(7)
|
|
|
|
# loss weighting with their ramp function
|
|
|
|
if exists(self.loss_weight_fn):
|
|
loss_weight = self.loss_weight_fn(times)
|
|
flow_losses = flow_losses * loss_weight
|
|
|
|
flow_loss = flow_losses.mean()
|
|
|
|
# now take care of the agent token losses
|
|
|
|
reward_loss = self.zero
|
|
|
|
if exists(rewards):
|
|
|
|
if rewards.ndim == 2: # (b t)
|
|
encoded_agent_tokens = reduce(encoded_agent_tokens, 'b t g d -> b t d', 'mean')
|
|
|
|
reward_pred = self.to_reward_pred(encoded_agent_tokens)
|
|
            reward_loss = F.cross_entropy(rearrange(reward_pred, 'b t l -> (b t) l'), rearrange(two_hot_encoding, 'b t l -> (b t) l'))
|
|
|
|
# gather losses
|
|
|
|
total_loss = (
|
|
flow_loss +
|
|
reward_loss * self.reward_loss_weight
|
|
)
|
|
|
|
if not return_all_losses:
|
|
return total_loss
|
|
|
|
return total_loss, (flow_loss, reward_loss)
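
# usage sketch - training the dynamics model directly on tokenizer latents, then dreaming a short rollout (hyperparameters hypothetical):
#
#   world_model = DynamicsWorldModel(dim = 512, dim_latent = 32, num_latent_tokens = 4, num_discrete_actions = 4)
#
#   latents = randn(1, 8, 4, 32)              # (b t n d) from VideoTokenizer.tokenize
#   actions = randint(0, 4, (1, 8, 1))        # (b t na)
#   rewards = randn(1, 8)                     # (b t)
#
#   loss = world_model(latents = latents, discrete_actions = actions, rewards = rewards)
#   loss.backward()
#
#   dreamed_latents = world_model.generate(time_steps = 4, num_steps = 4, return_decoded_video = False)   # (1, 4, 4, 32)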
|
|
|
|
# dreamer
|
|
|
|
class Dreamer(Module):
|
|
def __init__(
|
|
self,
|
|
video_tokenizer: VideoTokenizer,
|
|
        dynamics_model: DynamicsWorldModel,
|
|
discount_factor = 0.997
|
|
):
|
|
super().__init__()
|