diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index 0748113..9931c6e 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -2206,7 +2206,7 @@ class DynamicsWorldModel(Module): values = values, step_size = step_size, agent_index = agent_index, - lens = full((batch,), max_timesteps, device = device), + lens = full((batch,), max_timesteps + 1, device = device), is_from_world_model = False ) @@ -2239,16 +2239,12 @@ class DynamicsWorldModel(Module): # calculate returns - # for truncated (true by default), we will calculate experience lens + 1 and mask out anything after that - # for terminated, will just mask out any after lens - - # if not supplied, assume truncated (which is the case for games like minecraft or capped timesteps) + # mask out anything after the `lens`, which may include a bootstrapped node at the very end if `is_truncated = True` if not exists(experience.is_truncated): experience.is_truncated = full((batch,), True, device = latents.device) - lens_for_gae_calc = torch.where(experience.is_truncated, experience.lens, experience.lens + 1) - mask_for_gae = lens_to_mask(lens_for_gae_calc, time) + mask_for_gae = lens_to_mask(experience.lens, time) rewards = rewards.masked_fill(mask_for_gae, 0.) old_values = old_values.masked_fill(mask_for_gae, 0.) @@ -2263,8 +2259,8 @@ class DynamicsWorldModel(Module): is_var_len = exists(experience.lens) if is_var_len: - lens = experience.lens - mask = lens_to_mask(lens, max_time) + learnable_lens = experience.lens - experience.is_truncated.long() # if is truncated, remove the last one, as it is bootstrapped value + mask = lens_to_mask(learnable_lens, max_time) # determine whether to finetune entire transformer or just learn the heads diff --git a/dreamer4/trainers.py b/dreamer4/trainers.py index 3ebbd95..7004bd4 100644 --- a/dreamer4/trainers.py +++ b/dreamer4/trainers.py @@ -287,7 +287,7 @@ class DreamTrainer(Module): for _ in range(self.num_train_steps): dreams = self.unwrapped_model.generate( - self.generate_timesteps, + self.generate_timesteps + 1, # plus one for bootstrap value batch_size = self.batch_size, return_rewards_per_frame = True, return_agent_actions = True, diff --git a/pyproject.toml b/pyproject.toml index 56ded6a..3fcc1b0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dreamer4" -version = "0.0.78" +version = "0.0.79" description = "Dreamer 4" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" }