last commit for the day
This commit is contained in:
parent
e1c41f4371
commit
73029635fe
@ -223,6 +223,14 @@ def mean_log_var_to_distr(
|
||||
std = (0.5 * log_var).exp()
|
||||
return Normal(mean, std)
|
||||
|
||||
def safe_stack(tensors, dim = 0):
|
||||
tensors = [*filter(exists, tensors)]
|
||||
|
||||
if len(tensors) == 0:
|
||||
return None
|
||||
|
||||
return stack(tensors, dim = dim)
|
||||
|
||||
def safe_cat(tensors, dim):
|
||||
tensors = [*filter(exists, tensors)]
|
||||
|
||||
@ -1602,8 +1610,8 @@ class AxialSpaceTimeTransformer(Module):
|
||||
|
||||
intermediates = TransformerIntermediates(
|
||||
stack(time_attn_kv_caches),
|
||||
stack(normed_time_attn_inputs),
|
||||
stack(normed_space_attn_inputs)
|
||||
safe_stack(normed_time_attn_inputs),
|
||||
safe_stack(normed_space_attn_inputs)
|
||||
)
|
||||
|
||||
return out, intermediates
|
||||
@ -1881,8 +1889,11 @@ class VideoTokenizer(Module):
|
||||
time_decorr_loss = space_decorr_loss = self.zero
|
||||
|
||||
if self.encoder_add_decor_aux_loss:
|
||||
time_decorr_loss = self.decorr_loss(time_attn_normed_inputs)
|
||||
space_decorr_loss = self.decorr_loss(space_attn_normed_inputs)
|
||||
if exists(time_attn_normed_inputs):
|
||||
time_decorr_loss = self.decorr_loss(time_attn_normed_inputs)
|
||||
|
||||
if exists(space_attn_normed_inputs):
|
||||
space_decorr_loss = self.decorr_loss(space_attn_normed_inputs)
|
||||
|
||||
# losses
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "dreamer4"
|
||||
version = "0.1.10"
|
||||
version = "0.1.12"
|
||||
description = "Dreamer 4"
|
||||
authors = [
|
||||
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user