add a decorrelation loss for temporal attention in encoder of video tokenizer
This commit is contained in:
parent
256a81f658
commit
0c4224da18
@ -27,6 +27,8 @@ from x_mlps_pytorch.normed_mlp import create_mlp
|
||||
|
||||
from hyper_connections import get_init_and_expand_reduce_stream_functions
|
||||
|
||||
from vit_pytorch.vit_with_decorr import DecorrelationLoss
|
||||
|
||||
from assoc_scan import AssocScan
|
||||
|
||||
# ein related
|
||||
@ -68,10 +70,12 @@ except ImportError:
|
||||
|
||||
LinearNoBias = partial(Linear, bias = False)
|
||||
|
||||
TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips'))
|
||||
TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips', 'time_decorr'))
|
||||
|
||||
WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'rewards', 'discrete_actions', 'continuous_actions'))
|
||||
|
||||
AttentionIntermediates = namedtuple('AttentionIntermediates', ('next_kv_cache', 'normed_inputs'))
|
||||
|
||||
MaybeTensor = Tensor | None
|
||||
|
||||
@dataclass
|
||||
@ -1321,7 +1325,7 @@ class Attention(Module):
|
||||
self,
|
||||
tokens, # (b n d)
|
||||
kv_cache = None,
|
||||
return_kv_cache = False,
|
||||
return_intermediates = False,
|
||||
rotary_pos_emb = None,
|
||||
attend_fn: Callable | None = None
|
||||
):
|
||||
@ -1375,10 +1379,10 @@ class Attention(Module):
|
||||
|
||||
out = inverse_packed_batch(out)
|
||||
|
||||
if not return_kv_cache:
|
||||
if not return_intermediates:
|
||||
return out
|
||||
|
||||
return out, stack((k, v))
|
||||
return out, AttentionIntermediates(stack((k, v)), tokens)
|
||||
|
||||
# feedforward
|
||||
|
||||
@ -1492,7 +1496,7 @@ class AxialSpaceTimeTransformer(Module):
|
||||
self,
|
||||
tokens, # (b t s d)
|
||||
kv_cache: Tensor | None = None, # (y 2 b h t d)
|
||||
return_kv_cache = False
|
||||
return_intermediates = False
|
||||
|
||||
): # (b t s d) | (y 2 b h t d)
|
||||
|
||||
@ -1533,6 +1537,10 @@ class AxialSpaceTimeTransformer(Module):
|
||||
|
||||
rotary_pos_emb = self.time_rotary(rotary_seq_len, offset = rotary_pos_offset)
|
||||
|
||||
# normed attention inputs
|
||||
|
||||
normed_time_attn_inputs = []
|
||||
|
||||
# attention
|
||||
|
||||
tokens = self.expand_streams(tokens)
|
||||
@ -1553,12 +1561,12 @@ class AxialSpaceTimeTransformer(Module):
|
||||
|
||||
# attention layer
|
||||
|
||||
tokens, next_kv_cache = attn(
|
||||
tokens, attn_intermediates = attn(
|
||||
tokens,
|
||||
rotary_pos_emb = layer_rotary_pos_emb,
|
||||
attend_fn = attend_fn,
|
||||
kv_cache = maybe_kv_cache,
|
||||
return_kv_cache = True
|
||||
return_intermediates = True
|
||||
)
|
||||
|
||||
tokens = post_attn_rearrange(tokens)
|
||||
@ -1570,7 +1578,12 @@ class AxialSpaceTimeTransformer(Module):
|
||||
# save kv cache if is time layer
|
||||
|
||||
if layer_is_time:
|
||||
time_attn_kv_caches.append(next_kv_cache)
|
||||
time_attn_kv_caches.append(attn_intermediates.next_kv_cache)
|
||||
|
||||
# save time attention inputs for decorr
|
||||
|
||||
if layer_is_time:
|
||||
normed_time_attn_inputs.append(attn_intermediates.normed_inputs)
|
||||
|
||||
tokens = self.reduce_streams(tokens)
|
||||
|
||||
@ -1580,10 +1593,15 @@ class AxialSpaceTimeTransformer(Module):
|
||||
# just concat the past tokens back on for now, todo - clean up the logic
|
||||
out = cat((past_tokens, out), dim = 1)
|
||||
|
||||
if not return_kv_cache:
|
||||
if not return_intermediates:
|
||||
return out
|
||||
|
||||
return out, stack(time_attn_kv_caches)
|
||||
intermediates = (
|
||||
stack(time_attn_kv_caches),
|
||||
stack(normed_time_attn_inputs)
|
||||
)
|
||||
|
||||
return out, intermediates
|
||||
|
||||
# video tokenizer
|
||||
|
||||
@ -1609,12 +1627,15 @@ class VideoTokenizer(Module):
|
||||
per_image_patch_mask_prob = (0., 0.9), # probability of patch masking appears to be per image probabilities drawn uniformly between 0. and 0.9 - if you are a phd student and think i'm mistakened, please open an issue
|
||||
lpips_loss_network: Module | None = None,
|
||||
lpips_loss_weight = 0.2,
|
||||
encoder_add_decor_aux_loss = False,
|
||||
decor_auxx_loss_weight = 0.1,
|
||||
decorr_sample_frac = 0.25,
|
||||
nd_rotary_kwargs: dict = dict(
|
||||
rope_min_freq = 1.,
|
||||
rope_max_freq = 10000.,
|
||||
rope_p_zero_freqs = 0.
|
||||
),
|
||||
num_residual_streams = 1
|
||||
num_residual_streams = 1,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@ -1709,6 +1730,14 @@ class VideoTokenizer(Module):
|
||||
if self.has_lpips_loss:
|
||||
self.lpips = LPIPSLoss(lpips_loss_network)
|
||||
|
||||
# decorr aux loss
|
||||
# https://arxiv.org/abs/2510.14657
|
||||
|
||||
self.encoder_add_decor_aux_loss = encoder_add_decor_aux_loss
|
||||
self.decorr_aux_loss_weight = decor_auxx_loss_weight
|
||||
|
||||
self.decorr_loss = DecorrelationLoss(decorr_sample_frac) if encoder_add_decor_aux_loss else None
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return self.zero.device
|
||||
@ -1822,7 +1851,7 @@ class VideoTokenizer(Module):
|
||||
|
||||
# encoder attention
|
||||
|
||||
tokens = self.encoder_transformer(tokens)
|
||||
tokens, (_, time_attn_normed_inputs) = self.encoder_transformer(tokens, return_intermediates = True)
|
||||
|
||||
# latent bottleneck
|
||||
|
||||
@ -1844,17 +1873,21 @@ class VideoTokenizer(Module):
|
||||
if self.has_lpips_loss:
|
||||
lpips_loss = self.lpips(video, recon_video)
|
||||
|
||||
if self.encoder_add_decor_aux_loss:
|
||||
decorr_loss = self.decorr_loss(time_attn_normed_inputs)
|
||||
|
||||
# losses
|
||||
|
||||
total_loss = (
|
||||
recon_loss +
|
||||
lpips_loss * self.lpips_loss_weight
|
||||
lpips_loss * self.lpips_loss_weight +
|
||||
decorr_loss * self.decorr_aux_loss_weight
|
||||
)
|
||||
|
||||
if not return_all_losses:
|
||||
return total_loss
|
||||
|
||||
losses = (recon_loss, lpips_loss)
|
||||
losses = (recon_loss, lpips_loss, decorr_loss)
|
||||
|
||||
return total_loss, TokenizerLosses(*losses)
|
||||
|
||||
@ -3335,7 +3368,7 @@ class DynamicsWorldModel(Module):
|
||||
|
||||
# attention
|
||||
|
||||
tokens, next_time_kv_cache = self.transformer(tokens, kv_cache = time_kv_cache, return_kv_cache = True)
|
||||
tokens, (next_time_kv_cache, _) = self.transformer(tokens, kv_cache = time_kv_cache, return_intermediates = True)
|
||||
|
||||
# unpack
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "dreamer4"
|
||||
version = "0.1.5"
|
||||
version = "0.1.6"
|
||||
description = "Dreamer 4"
|
||||
authors = [
|
||||
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
||||
@ -36,7 +36,8 @@ dependencies = [
|
||||
"hyper-connections>=0.2.1",
|
||||
"torch>=2.4",
|
||||
"torchvision",
|
||||
"x-mlps-pytorch>=0.0.29"
|
||||
"x-mlps-pytorch>=0.0.29",
|
||||
"vit-pytorch>=1.15.2"
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
|
||||
@ -41,7 +41,9 @@ def test_e2e(
|
||||
patch_size = 32,
|
||||
attn_dim_head = 16,
|
||||
num_latent_tokens = 4,
|
||||
num_residual_streams = num_residual_streams
|
||||
num_residual_streams = num_residual_streams,
|
||||
encoder_add_decor_aux_loss = True,
|
||||
decorr_sample_frac = 1.
|
||||
)
|
||||
|
||||
video = torch.randn(2, 3, 4, 256, 256)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user