From 59c458aea3f40912e12d768d9ebf77d29cd818fe Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Mon, 27 Oct 2025 07:55:00 -0700
Subject: [PATCH] introduce an `is_truncated` field on Experience, and mask out rewards and values before calculating gae appropriately

---
 dreamer4/dreamer4.py | 51 +++++++++++++++++++++++++++++++++++++++++----------
 pyproject.toml       |  2 +-
 2 files changed, 42 insertions(+), 11 deletions(-)

diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index 5e0f61c..0748113 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -13,7 +13,7 @@ import torch.nn.functional as F
 from torch.nested import nested_tensor
 from torch.distributions import Normal
 from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity
-from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
+from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, full, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
 from torch.utils._pytree import tree_flatten, tree_unflatten
 
 import torchvision
@@ -83,6 +83,7 @@ class Experience:
     values: Tensor | None = None
     step_size: int | None = None
     lens: Tensor | None = None
+    is_truncated: Tensor | None = None
     agent_index: int = 0
 
     is_from_world_model: bool = True
@@ -99,7 +100,10 @@ def combine_experiences(
         batch, time, device = *latents.shape[:2], latents.device
 
         if not exists(exp.lens):
-            exp.lens = torch.full((batch,), time, device = device)
+            exp.lens = full((batch,), time, device = device)
+
+        if not exists(exp.is_truncated):
+            exp.is_truncated = full((batch,), True, device = device)
 
     # convert to dictionary
 
@@ -2115,7 +2119,7 @@ class DynamicsWorldModel(Module):
 
         time_kv_cache = None
 
-        for _ in range(max_timesteps):
+        for i in range(max_timesteps + 1):
 
             latents = self.video_tokenizer(video, return_latents = True)
 
@@ -2141,6 +2145,15 @@ class DynamicsWorldModel(Module):
 
             one_agent_embed = agent_embed[..., -1:, agent_index, :]
 
+            # values
+
+            value_bins = self.value_head(one_agent_embed)
+            value = self.reward_encoder.bins_to_scalar_value(value_bins)
+
+            values = safe_cat((values, value), dim = 1)
+
+            # policy embed
+
             policy_embed = self.policy_head(one_agent_embed)
 
             # sample actions
@@ -2162,11 +2175,6 @@ class DynamicsWorldModel(Module):
             discrete_log_probs = safe_cat((discrete_log_probs, one_discrete_log_probs), dim = 1)
             continuous_log_probs = safe_cat((continuous_log_probs, one_continuous_log_probs), dim = 1)
 
-            value_bins = self.value_head(one_agent_embed)
-            value = self.reward_encoder.bins_to_scalar_value(value_bins)
-
-            values = safe_cat((values, value), dim = 1)
-
             # pass the sampled action to the environment and get back next state and reward
 
             next_frame, reward = env.step((sampled_discrete_actions, sampled_continuous_actions))
@@ -2187,6 +2195,8 @@ class DynamicsWorldModel(Module):
 
         # package up one experience for learning
 
+        batch, device = latents.shape[0], latents.device
+
         one_experience = Experience(
             latents = latents,
             video = video[:, :, :-1],
@@ -2196,6 +2206,7 @@ class DynamicsWorldModel(Module):
             values = values,
             step_size = step_size,
             agent_index = agent_index,
+            lens = full((batch,), max_timesteps, device = device),
             is_from_world_model = False
         )
 
@@ -2224,6 +2235,26 @@ class DynamicsWorldModel(Module):
 
         assert all([*map(exists, (old_log_probs, actions, old_values, rewards, step_size))]), 'the generations need to contain the log probs, values, and rewards for policy optimization'
 
+        batch, time = latents.shape[0], latents.shape[1]
+
+        # mask out rewards and values past the effective episode length
+
+        # for truncated (true by default), mask out everything after experience lens + 1, keeping the last value for bootstrapping
+        # for terminated, mask out everything after lens
+
+        # if not supplied, assume truncated (which is the case for games like minecraft or capped timesteps)
+
+        if not exists(experience.is_truncated):
+            experience.is_truncated = full((batch,), True, device = latents.device)
+
+        lens_for_gae_calc = torch.where(experience.is_truncated, experience.lens + 1, experience.lens)
+        mask_for_gae = lens_to_mask(lens_for_gae_calc, time)
+
+        rewards = rewards.masked_fill(~mask_for_gae, 0.)
+        old_values = old_values.masked_fill(~mask_for_gae, 0.)
+
+        # calculate returns
+
         returns = calc_gae(rewards, old_values, gamma = self.gae_discount_factor, lam = self.gae_lambda, use_accelerated = self.gae_use_accelerated)
 
         # handle variable lengths
@@ -2387,7 +2418,7 @@ class DynamicsWorldModel(Module):
         assert 0 < num_steps <= self.max_steps, f'number of steps {num_steps} must be between 0 and {self.max_steps}'
 
         if isinstance(tasks, int):
-            tasks = torch.full((batch_size,), tasks, device = self.device)
+            tasks = full((batch_size,), tasks, device = self.device)
 
         assert not exists(tasks) or tasks.shape[0] == batch_size
 
@@ -2624,7 +2655,7 @@ class DynamicsWorldModel(Module):
         # returning agent actions, rewards, and log probs + values for policy optimization
 
         batch, device = latents.shape[0], latents.device
-        experience_lens = torch.full((batch,), time_steps, device = device)
+        experience_lens = full((batch,), time_steps, device = device)
 
         gen = Experience(
             latents = latents,
diff --git a/pyproject.toml b/pyproject.toml
index bb2ff1a..56ded6a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.77"
+version = "0.0.78"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
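
Note (not part of the patch): a minimal self-contained sketch of the truncated-vs-terminated masking the patch applies before GAE. For a truncated episode one extra value is kept past the episode length so the return calculation can bootstrap from it; for a terminated episode everything past the length is zeroed. The helpers `lens_to_mask` and `simple_gae` below are illustrative stand-ins with assumed semantics (mask is True for valid positions, values carry one bootstrap entry beyond the rewards), not the repository's actual `lens_to_mask` or `calc_gae`.

import torch

def lens_to_mask(lens, total_len):
    # (batch,) lengths -> (batch, total_len) bool mask, True where the position is within the length
    positions = torch.arange(total_len, device = lens.device)
    return positions[None, :] < lens[:, None]

def simple_gae(rewards, values, gamma = 0.997, lam = 0.95):
    # rewards: (batch, time), values: (batch, time + 1), last value entry is the bootstrap
    deltas = rewards + gamma * values[:, 1:] - values[:, :-1]

    advantages = torch.zeros_like(rewards)
    running = torch.zeros(rewards.shape[0], device = rewards.device)

    for t in reversed(range(rewards.shape[1])):
        running = deltas[:, t] + gamma * lam * running
        advantages[:, t] = running

    # returns are advantages plus the value baseline
    return advantages + values[:, :-1]

batch, time = 2, 5

rewards = torch.randn(batch, time)
values = torch.randn(batch, time + 1)

lens = torch.tensor([5, 3])                 # per-episode lengths
is_truncated = torch.tensor([True, False])  # first episode truncated, second terminated

# truncated episodes keep one extra value (the bootstrap), terminated ones do not

lens_for_values = torch.where(is_truncated, lens + 1, lens)

reward_mask = lens_to_mask(lens, time)
value_mask = lens_to_mask(lens_for_values, time + 1)

rewards = rewards.masked_fill(~reward_mask, 0.)
values = values.masked_fill(~value_mask, 0.)

returns = simple_gae(rewards, values)
print(returns.shape)  # torch.Size([2, 5])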