last commit for the day
This commit is contained in:
parent
e1c41f4371
commit
dfe15a0605
@ -223,6 +223,14 @@ def mean_log_var_to_distr(
|
|||||||
std = (0.5 * log_var).exp()
|
std = (0.5 * log_var).exp()
|
||||||
return Normal(mean, std)
|
return Normal(mean, std)
|
||||||
|
|
||||||
|
def safe_stack(tensors, dim):
|
||||||
|
tensors = [*filter(exists, tensors)]
|
||||||
|
|
||||||
|
if len(tensors) == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return stack(tensors, dim = dim)
|
||||||
|
|
||||||
def safe_cat(tensors, dim):
|
def safe_cat(tensors, dim):
|
||||||
tensors = [*filter(exists, tensors)]
|
tensors = [*filter(exists, tensors)]
|
||||||
|
|
||||||
@ -1602,8 +1610,8 @@ class AxialSpaceTimeTransformer(Module):
|
|||||||
|
|
||||||
intermediates = TransformerIntermediates(
|
intermediates = TransformerIntermediates(
|
||||||
stack(time_attn_kv_caches),
|
stack(time_attn_kv_caches),
|
||||||
stack(normed_time_attn_inputs),
|
safe_stack(normed_time_attn_inputs),
|
||||||
stack(normed_space_attn_inputs)
|
safe_stack(normed_space_attn_inputs)
|
||||||
)
|
)
|
||||||
|
|
||||||
return out, intermediates
|
return out, intermediates
|
||||||
@ -1881,8 +1889,11 @@ class VideoTokenizer(Module):
|
|||||||
time_decorr_loss = space_decorr_loss = self.zero
|
time_decorr_loss = space_decorr_loss = self.zero
|
||||||
|
|
||||||
if self.encoder_add_decor_aux_loss:
|
if self.encoder_add_decor_aux_loss:
|
||||||
time_decorr_loss = self.decorr_loss(time_attn_normed_inputs)
|
if exists(time_attn_normed_inputs):
|
||||||
space_decorr_loss = self.decorr_loss(space_attn_normed_inputs)
|
time_decorr_loss = self.decorr_loss(time_attn_normed_inputs)
|
||||||
|
|
||||||
|
if exists(space_attn_normed_inputs):
|
||||||
|
space_decorr_loss = self.decorr_loss(space_attn_normed_inputs)
|
||||||
|
|
||||||
# losses
|
# losses
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "dreamer4"
|
name = "dreamer4"
|
||||||
version = "0.1.10"
|
version = "0.1.11"
|
||||||
description = "Dreamer 4"
|
description = "Dreamer 4"
|
||||||
authors = [
|
authors = [
|
||||||
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user