Redo `lens` so that it is always the full episode length, including the bootstrap value timestep, and use `is_truncated` to mask out the bootstrap node so that it is not learned on

This commit is contained in:
lucidrains 2025-10-27 08:06:21 -07:00
parent 59c458aea3
commit 349a03acd7
3 changed files with 7 additions and 11 deletions

View File

@ -2206,7 +2206,7 @@ class DynamicsWorldModel(Module):
values = values,
step_size = step_size,
agent_index = agent_index,
lens = full((batch,), max_timesteps, device = device),
lens = full((batch,), max_timesteps + 1, device = device),
is_from_world_model = False
)
@ -2239,16 +2239,12 @@ class DynamicsWorldModel(Module):
# calculate returns
# for truncated (true by default), we will calculate experience lens + 1 and mask out anything after that
# for terminated, will just mask out any after lens
# if not supplied, assume truncated (which is the case for games like minecraft or capped timesteps)
# mask out anything after `lens` - note that `lens` itself includes the bootstrapped node at the very end when `is_truncated = True`
if not exists(experience.is_truncated):
experience.is_truncated = full((batch,), True, device = latents.device)
lens_for_gae_calc = torch.where(experience.is_truncated, experience.lens, experience.lens + 1)
mask_for_gae = lens_to_mask(lens_for_gae_calc, time)
mask_for_gae = lens_to_mask(experience.lens, time)
rewards = rewards.masked_fill(mask_for_gae, 0.)
old_values = old_values.masked_fill(mask_for_gae, 0.)
@ -2263,8 +2259,8 @@ class DynamicsWorldModel(Module):
is_var_len = exists(experience.lens)
if is_var_len:
lens = experience.lens
mask = lens_to_mask(lens, max_time)
learnable_lens = experience.lens - experience.is_truncated.long() # if is truncated, remove the last one, as it is bootstrapped value
mask = lens_to_mask(learnable_lens, max_time)
# determine whether to finetune entire transformer or just learn the heads

View File

@ -287,7 +287,7 @@ class DreamTrainer(Module):
for _ in range(self.num_train_steps):
dreams = self.unwrapped_model.generate(
self.generate_timesteps,
self.generate_timesteps + 1, # plus one for bootstrap value
batch_size = self.batch_size,
return_rewards_per_frame = True,
return_agent_actions = True,

View File

@ -1,6 +1,6 @@
[project]
name = "dreamer4"
version = "0.0.78"
version = "0.0.79"
description = "Dreamer 4"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }