last commit for the day

This commit is contained in:
lucidrains 2025-11-09 11:12:37 -08:00
parent e1c41f4371
commit 73029635fe
2 changed files with 16 additions and 5 deletions

View File

@ -223,6 +223,14 @@ def mean_log_var_to_distr(
std = (0.5 * log_var).exp()
return Normal(mean, std)
def safe_stack(tensors, dim = 0):
tensors = [*filter(exists, tensors)]
if len(tensors) == 0:
return None
return stack(tensors, dim = dim)
def safe_cat(tensors, dim):
tensors = [*filter(exists, tensors)]
@ -1602,8 +1610,8 @@ class AxialSpaceTimeTransformer(Module):
intermediates = TransformerIntermediates(
stack(time_attn_kv_caches),
stack(normed_time_attn_inputs),
stack(normed_space_attn_inputs)
safe_stack(normed_time_attn_inputs),
safe_stack(normed_space_attn_inputs)
)
return out, intermediates
@ -1881,8 +1889,11 @@ class VideoTokenizer(Module):
time_decorr_loss = space_decorr_loss = self.zero
if self.encoder_add_decor_aux_loss:
time_decorr_loss = self.decorr_loss(time_attn_normed_inputs)
space_decorr_loss = self.decorr_loss(space_attn_normed_inputs)
if exists(time_attn_normed_inputs):
time_decorr_loss = self.decorr_loss(time_attn_normed_inputs)
if exists(space_attn_normed_inputs):
space_decorr_loss = self.decorr_loss(space_attn_normed_inputs)
# losses

View File

@ -1,6 +1,6 @@
[project]
name = "dreamer4"
version = "0.1.10"
version = "0.1.12"
description = "Dreamer 4"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }