diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index 9931c6e..467c091 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -2244,10 +2244,11 @@ class DynamicsWorldModel(Module): if not exists(experience.is_truncated): experience.is_truncated = full((batch,), True, device = latents.device) - mask_for_gae = lens_to_mask(experience.lens, time) + if exists(experience.lens): + mask_for_gae = lens_to_mask(experience.lens, time) - rewards = rewards.masked_fill(mask_for_gae, 0.) - old_values = old_values.masked_fill(mask_for_gae, 0.) + rewards = rewards.masked_fill(mask_for_gae, 0.) + old_values = old_values.masked_fill(mask_for_gae, 0.) # calculate returns diff --git a/pyproject.toml b/pyproject.toml index 3fcc1b0..2691c8b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dreamer4" -version = "0.0.79" +version = "0.0.80" description = "Dreamer 4" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" }