last commit for the day

This commit is contained in:
lucidrains 2025-11-09 11:03:03 -08:00
parent e1c41f4371
commit dfe15a0605
2 changed files with 16 additions and 5 deletions

View File

@ -223,6 +223,14 @@ def mean_log_var_to_distr(
std = (0.5 * log_var).exp() std = (0.5 * log_var).exp()
return Normal(mean, std) return Normal(mean, std)
def safe_stack(tensors, dim):
tensors = [*filter(exists, tensors)]
if len(tensors) == 0:
return None
return stack(tensors, dim = dim)
def safe_cat(tensors, dim): def safe_cat(tensors, dim):
tensors = [*filter(exists, tensors)] tensors = [*filter(exists, tensors)]
@ -1602,8 +1610,8 @@ class AxialSpaceTimeTransformer(Module):
intermediates = TransformerIntermediates( intermediates = TransformerIntermediates(
stack(time_attn_kv_caches), stack(time_attn_kv_caches),
stack(normed_time_attn_inputs), safe_stack(normed_time_attn_inputs),
stack(normed_space_attn_inputs) safe_stack(normed_space_attn_inputs)
) )
return out, intermediates return out, intermediates
@ -1881,8 +1889,11 @@ class VideoTokenizer(Module):
time_decorr_loss = space_decorr_loss = self.zero time_decorr_loss = space_decorr_loss = self.zero
if self.encoder_add_decor_aux_loss: if self.encoder_add_decor_aux_loss:
time_decorr_loss = self.decorr_loss(time_attn_normed_inputs) if exists(time_attn_normed_inputs):
space_decorr_loss = self.decorr_loss(space_attn_normed_inputs) time_decorr_loss = self.decorr_loss(time_attn_normed_inputs)
if exists(space_attn_normed_inputs):
space_decorr_loss = self.decorr_loss(space_attn_normed_inputs)
# losses # losses

View File

@ -1,6 +1,6 @@
[project] [project]
name = "dreamer4" name = "dreamer4"
version = "0.1.10" version = "0.1.11"
description = "Dreamer 4" description = "Dreamer 4"
authors = [ authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" } { name = "Phil Wang", email = "lucidrains@gmail.com" }