Redo so that `lens` is always the episode length, including the bootstrap value timestep, and use `is_truncated` to mask out the bootstrap node so that it is not learned on
This commit is contained in:
parent
59c458aea3
commit
349a03acd7
@ -2206,7 +2206,7 @@ class DynamicsWorldModel(Module):
|
||||
values = values,
|
||||
step_size = step_size,
|
||||
agent_index = agent_index,
|
||||
lens = full((batch,), max_timesteps, device = device),
|
||||
lens = full((batch,), max_timesteps + 1, device = device),
|
||||
is_from_world_model = False
|
||||
)
|
||||
|
||||
@ -2239,16 +2239,12 @@ class DynamicsWorldModel(Module):
|
||||
|
||||
# calculate returns
|
||||
|
||||
# for truncated (true by default), we will calculate experience lens + 1 and mask out anything after that
|
||||
# for terminated, will just mask out any after lens
|
||||
|
||||
# if not supplied, assume truncated (which is the case for games like minecraft or capped timesteps)
|
||||
# mask out anything after `lens`; note that `lens` itself counts a bootstrapped node at the very end when `is_truncated = True`
|
||||
|
||||
if not exists(experience.is_truncated):
|
||||
experience.is_truncated = full((batch,), True, device = latents.device)
|
||||
|
||||
lens_for_gae_calc = torch.where(experience.is_truncated, experience.lens, experience.lens + 1)
|
||||
mask_for_gae = lens_to_mask(lens_for_gae_calc, time)
|
||||
mask_for_gae = lens_to_mask(experience.lens, time)
|
||||
|
||||
rewards = rewards.masked_fill(mask_for_gae, 0.)
|
||||
old_values = old_values.masked_fill(mask_for_gae, 0.)
|
||||
@ -2263,8 +2259,8 @@ class DynamicsWorldModel(Module):
|
||||
is_var_len = exists(experience.lens)
|
||||
|
||||
if is_var_len:
|
||||
lens = experience.lens
|
||||
mask = lens_to_mask(lens, max_time)
|
||||
learnable_lens = experience.lens - experience.is_truncated.long() # if is truncated, remove the last one, as it is bootstrapped value
|
||||
mask = lens_to_mask(learnable_lens, max_time)
|
||||
|
||||
# determine whether to finetune entire transformer or just learn the heads
|
||||
|
||||
|
||||
@ -287,7 +287,7 @@ class DreamTrainer(Module):
|
||||
for _ in range(self.num_train_steps):
|
||||
|
||||
dreams = self.unwrapped_model.generate(
|
||||
self.generate_timesteps,
|
||||
self.generate_timesteps + 1, # plus one for bootstrap value
|
||||
batch_size = self.batch_size,
|
||||
return_rewards_per_frame = True,
|
||||
return_agent_actions = True,
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "dreamer4"
|
||||
version = "0.0.78"
|
||||
version = "0.0.79"
|
||||
description = "Dreamer 4"
|
||||
authors = [
|
||||
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user