From 73029635fe5c083d043f809e34573d6795fad1ab Mon Sep 17 00:00:00 2001 From: lucidrains Date: Sun, 9 Nov 2025 11:12:37 -0800 Subject: [PATCH] last commit for the day --- dreamer4/dreamer4.py | 19 +++++++++++++++---- pyproject.toml | 2 +- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index 2fab2bf..ac73ecf 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -223,6 +223,14 @@ def mean_log_var_to_distr( std = (0.5 * log_var).exp() return Normal(mean, std) +def safe_stack(tensors, dim = 0): + tensors = [*filter(exists, tensors)] + + if len(tensors) == 0: + return None + + return stack(tensors, dim = dim) + def safe_cat(tensors, dim): tensors = [*filter(exists, tensors)] @@ -1602,8 +1610,8 @@ class AxialSpaceTimeTransformer(Module): intermediates = TransformerIntermediates( stack(time_attn_kv_caches), - stack(normed_time_attn_inputs), - stack(normed_space_attn_inputs) + safe_stack(normed_time_attn_inputs), + safe_stack(normed_space_attn_inputs) ) return out, intermediates @@ -1881,8 +1889,11 @@ class VideoTokenizer(Module): time_decorr_loss = space_decorr_loss = self.zero if self.encoder_add_decor_aux_loss: - time_decorr_loss = self.decorr_loss(time_attn_normed_inputs) - space_decorr_loss = self.decorr_loss(space_attn_normed_inputs) + if exists(time_attn_normed_inputs): + time_decorr_loss = self.decorr_loss(time_attn_normed_inputs) + + if exists(space_attn_normed_inputs): + space_decorr_loss = self.decorr_loss(space_attn_normed_inputs) # losses diff --git a/pyproject.toml b/pyproject.toml index 4e1beec..b4ed20b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dreamer4" -version = "0.1.10" +version = "0.1.12" description = "Dreamer 4" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" }