add a decorrelation loss for temporal attention in encoder of video tokenizer
parent 256a81f658
commit 0c4224da18
@@ -27,6 +27,8 @@ from x_mlps_pytorch.normed_mlp import create_mlp
 
 from hyper_connections import get_init_and_expand_reduce_stream_functions
 
+from vit_pytorch.vit_with_decorr import DecorrelationLoss
+
 from assoc_scan import AssocScan
 
 # ein related
@@ -68,10 +70,12 @@ except ImportError:
 
 LinearNoBias = partial(Linear, bias = False)
 
-TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips'))
+TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips', 'time_decorr'))
 
 WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'rewards', 'discrete_actions', 'continuous_actions'))
 
+AttentionIntermediates = namedtuple('AttentionIntermediates', ('next_kv_cache', 'normed_inputs'))
+
 MaybeTensor = Tensor | None
 
 @dataclass
@@ -1321,7 +1325,7 @@ class Attention(Module):
     self,
     tokens, # (b n d)
     kv_cache = None,
-    return_kv_cache = False,
+    return_intermediates = False,
     rotary_pos_emb = None,
     attend_fn: Callable | None = None
 ):
@@ -1375,10 +1379,10 @@ class Attention(Module):
 
 out = inverse_packed_batch(out)
 
-if not return_kv_cache:
+if not return_intermediates:
     return out
 
-return out, stack((k, v))
+return out, AttentionIntermediates(stack((k, v)), tokens)
 
 # feedforward
 
@@ -1492,7 +1496,7 @@ class AxialSpaceTimeTransformer(Module):
     self,
     tokens, # (b t s d)
     kv_cache: Tensor | None = None, # (y 2 b h t d)
-    return_kv_cache = False
+    return_intermediates = False
 
 ): # (b t s d) | (y 2 b h t d)
 
@@ -1533,6 +1537,10 @@ class AxialSpaceTimeTransformer(Module):
 
 rotary_pos_emb = self.time_rotary(rotary_seq_len, offset = rotary_pos_offset)
 
+# normed attention inputs
+
+normed_time_attn_inputs = []
+
 # attention
 
 tokens = self.expand_streams(tokens)
@@ -1553,12 +1561,12 @@ class AxialSpaceTimeTransformer(Module):
 
 # attention layer
 
-tokens, next_kv_cache = attn(
+tokens, attn_intermediates = attn(
     tokens,
     rotary_pos_emb = layer_rotary_pos_emb,
     attend_fn = attend_fn,
     kv_cache = maybe_kv_cache,
-    return_kv_cache = True
+    return_intermediates = True
 )
 
 tokens = post_attn_rearrange(tokens)
@@ -1570,7 +1578,12 @@ class AxialSpaceTimeTransformer(Module):
 # save kv cache if is time layer
 
 if layer_is_time:
-    time_attn_kv_caches.append(next_kv_cache)
+    time_attn_kv_caches.append(attn_intermediates.next_kv_cache)
+
+# save time attention inputs for decorr
+
+if layer_is_time:
+    normed_time_attn_inputs.append(attn_intermediates.normed_inputs)
 
 tokens = self.reduce_streams(tokens)
 
@@ -1580,10 +1593,15 @@ class AxialSpaceTimeTransformer(Module):
 # just concat the past tokens back on for now, todo - clean up the logic
 out = cat((past_tokens, out), dim = 1)
 
-if not return_kv_cache:
+if not return_intermediates:
     return out
 
-return out, stack(time_attn_kv_caches)
+intermediates = (
+    stack(time_attn_kv_caches),
+    stack(normed_time_attn_inputs)
+)
+
+return out, intermediates
 
 # video tokenizer
 
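Taken together, the hunks above change the return contract of `AxialSpaceTimeTransformer.forward`: `return_kv_cache` becomes `return_intermediates`, and the second return value is now a pair of stacked tensors rather than the kv caches alone. A minimal sketch of the new calling convention, matching the two call sites later in this diff (variable names here are illustrative, not from the source):

# sketch, assuming a constructed AxialSpaceTimeTransformer named `transformer`

out = transformer(tokens)  # unchanged path: returns only the output tokens

out, (time_kv_caches, normed_time_attn_inputs) = transformer(
    tokens,                       # (b t s d)
    return_intermediates = True
)

# the video tokenizer encoder keeps only normed_time_attn_inputs (for the decorrelation loss),
# while the dynamics world model keeps only time_kv_caches (for cached incremental decoding)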
@@ -1609,12 +1627,15 @@ class VideoTokenizer(Module):
     per_image_patch_mask_prob = (0., 0.9), # probability of patch masking appears to be per image probabilities drawn uniformly between 0. and 0.9 - if you are a phd student and think i'm mistakened, please open an issue
     lpips_loss_network: Module | None = None,
     lpips_loss_weight = 0.2,
+    encoder_add_decor_aux_loss = False,
+    decor_auxx_loss_weight = 0.1,
+    decorr_sample_frac = 0.25,
     nd_rotary_kwargs: dict = dict(
         rope_min_freq = 1.,
         rope_max_freq = 10000.,
         rope_p_zero_freqs = 0.
     ),
-    num_residual_streams = 1
+    num_residual_streams = 1,
 ):
     super().__init__()
 
@@ -1709,6 +1730,14 @@ class VideoTokenizer(Module):
 if self.has_lpips_loss:
     self.lpips = LPIPSLoss(lpips_loss_network)
 
+# decorr aux loss
+# https://arxiv.org/abs/2510.14657
+
+self.encoder_add_decor_aux_loss = encoder_add_decor_aux_loss
+self.decorr_aux_loss_weight = decor_auxx_loss_weight
+
+self.decorr_loss = DecorrelationLoss(decorr_sample_frac) if encoder_add_decor_aux_loss else None
+
 @property
 def device(self):
     return self.zero.device
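The actual implementation lives in the `DecorrelationLoss` module of vit-pytorch and the referenced paper (https://arxiv.org/abs/2510.14657); in this diff it is only constructed with a sampling fraction and later called on the stacked, normed inputs to the temporal attention layers. As a rough illustrative sketch (an assumption about the general technique, not the vit-pytorch code) of what a decorrelation penalty of this kind computes, namely pushing the feature covariance of those inputs toward a diagonal matrix:

import torch

def decorrelation_penalty(feats, sample_frac = 0.25):
    # feats: (..., n, d) token features entering temporal attention
    x = feats.reshape(-1, feats.shape[-1])        # flatten so every token is a row

    # subsample rows, mirroring the decorr_sample_frac hyperparameter
    if sample_frac < 1.:
        num = max(1, int(x.shape[0] * sample_frac))
        idx = torch.randperm(x.shape[0], device = x.device)[:num]
        x = x[idx]

    x = x - x.mean(dim = 0)
    cov = (x.t() @ x) / max(1, x.shape[0] - 1)    # (d, d) feature covariance

    off_diag = cov - torch.diag(torch.diagonal(cov))
    return off_diag.pow(2).mean()                 # penalize cross-feature correlation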
@@ -1822,7 +1851,7 @@ class VideoTokenizer(Module):
 
 # encoder attention
 
-tokens = self.encoder_transformer(tokens)
+tokens, (_, time_attn_normed_inputs) = self.encoder_transformer(tokens, return_intermediates = True)
 
 # latent bottleneck
 
@@ -1844,17 +1873,21 @@ class VideoTokenizer(Module):
 if self.has_lpips_loss:
     lpips_loss = self.lpips(video, recon_video)
 
+if self.encoder_add_decor_aux_loss:
+    decorr_loss = self.decorr_loss(time_attn_normed_inputs)
+
 # losses
 
 total_loss = (
     recon_loss +
-    lpips_loss * self.lpips_loss_weight
+    lpips_loss * self.lpips_loss_weight +
+    decorr_loss * self.decorr_aux_loss_weight
 )
 
 if not return_all_losses:
     return total_loss
 
-losses = (recon_loss, lpips_loss)
+losses = (recon_loss, lpips_loss, decorr_loss)
 
 return total_loss, TokenizerLosses(*losses)
 
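For training code, the visible effect is one extra weighted term in the total loss and one extra named field on the returned losses. A hedged usage sketch (the `video_tokenizer` instance and the rest of its forward signature are assumed; only the pieces shown in the hunks above are relied on):

total_loss = video_tokenizer(video)                               # scalar, now including the decorr term when enabled

total_loss, losses = video_tokenizer(video, return_all_losses = True)

recon, lpips, time_decorr = losses                                # TokenizerLosses unpacks positionally
print(losses.time_decorr)                                         # or access the new field by name

total_loss.backward()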
@@ -3335,7 +3368,7 @@ class DynamicsWorldModel(Module):
 
 # attention
 
-tokens, next_time_kv_cache = self.transformer(tokens, kv_cache = time_kv_cache, return_kv_cache = True)
+tokens, (next_time_kv_cache, _) = self.transformer(tokens, kv_cache = time_kv_cache, return_intermediates = True)
 
 # unpack
 
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.1.5"
+version = "0.1.6"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -36,7 +36,8 @@ dependencies = [
     "hyper-connections>=0.2.1",
     "torch>=2.4",
     "torchvision",
-    "x-mlps-pytorch>=0.0.29"
+    "x-mlps-pytorch>=0.0.29",
+    "vit-pytorch>=1.15.2"
 ]
 
 [project.urls]
@@ -41,7 +41,9 @@ def test_e2e(
     patch_size = 32,
     attn_dim_head = 16,
     num_latent_tokens = 4,
-    num_residual_streams = num_residual_streams
+    num_residual_streams = num_residual_streams,
+    encoder_add_decor_aux_loss = True,
+    decorr_sample_frac = 1.
 )
 
 video = torch.randn(2, 3, 4, 256, 256)
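The test tensor suggests the usual (batch, channels, frames, height, width) layout, and `decorr_sample_frac = 1.` presumably disables subsampling so the auxiliary loss covers every token of the tiny test video. A hypothetical sanity check one could add on top of this test (not part of the commit; the tokenizer variable name is assumed):

total_loss, losses = tokenizer(video, return_all_losses = True)

assert losses.time_decorr.ndim == 0         # scalar auxiliary loss
assert torch.isfinite(losses.time_decorr)

total_loss.backward()                       # gradients should flow through the decorr term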