introduce an is_truncated field on Experience, and appropriately mask out rewards and values before calculating GAE

lucidrains 2025-10-27 07:55:00 -07:00
parent fbfd59e42f
commit 59c458aea3
2 changed files with 42 additions and 11 deletions
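In broad strokes, the change gives each Experience a per-sample boolean is_truncated tensor (defaulting to all True when not supplied) and uses it, together with lens, to decide how many timesteps of rewards and values survive before GAE is computed. A minimal sketch of the intended masking, assuming a stand-in lens_to_mask helper and toy tensors rather than the repo's actual Experience plumbing:

import torch

def lens_to_mask(lens, total_len):
    # True for valid positions (index < len), False for anything past the end of the episode
    return torch.arange(total_len, device = lens.device)[None, :] < lens[:, None]

# toy batch of 2 episodes over 6 timesteps, both of length 4
# the first is truncated (cut off by a step cap), the second actually terminated

lens = torch.tensor([4, 4])
is_truncated = torch.tensor([True, False])
rewards = torch.randn(2, 6)
values = torch.randn(2, 6)

# truncated episodes keep one extra step so the final value can still be bootstrapped from
# terminated episodes are cut exactly at lens

lens_for_gae = torch.where(is_truncated, lens + 1, lens)
mask = lens_to_mask(lens_for_gae, rewards.shape[-1])

rewards = rewards.masked_fill(~mask, 0.)
values = values.masked_fill(~mask, 0.)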

View File

@@ -13,7 +13,7 @@ import torch.nn.functional as F
from torch.nested import nested_tensor
from torch.distributions import Normal
from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity
from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, full, zeros, ones, randint, rand, randn, randn_like, empty, linspace
from torch.utils._pytree import tree_flatten, tree_unflatten
import torchvision
@@ -83,6 +83,7 @@ class Experience:
values: Tensor | None = None
step_size: int | None = None
lens: Tensor | None = None
is_truncated: Tensor | None = None
agent_index: int = 0
is_from_world_model: bool = True
@@ -99,7 +100,10 @@ def combine_experiences(
batch, time, device = *latents.shape[:2], latents.device
if not exists(exp.lens):
exp.lens = torch.full((batch,), time, device = device)
exp.lens = full((batch,), time, device = device)
if not exists(exp.is_truncated):
exp.is_truncated = full((batch,), True, device = device)
# convert to dictionary
@@ -2115,7 +2119,7 @@ class DynamicsWorldModel(Module):
time_kv_cache = None
for _ in range(max_timesteps):
for i in range(max_timesteps + 1):
latents = self.video_tokenizer(video, return_latents = True)
@@ -2141,6 +2145,15 @@ class DynamicsWorldModel(Module):
one_agent_embed = agent_embed[..., -1:, agent_index, :]
# values
value_bins = self.value_head(one_agent_embed)
value = self.reward_encoder.bins_to_scalar_value(value_bins)
values = safe_cat((values, value), dim = 1)
# policy embed
policy_embed = self.policy_head(one_agent_embed)
# sample actions
@@ -2162,11 +2175,6 @@ class DynamicsWorldModel(Module):
discrete_log_probs = safe_cat((discrete_log_probs, one_discrete_log_probs), dim = 1)
continuous_log_probs = safe_cat((continuous_log_probs, one_continuous_log_probs), dim = 1)
value_bins = self.value_head(one_agent_embed)
value = self.reward_encoder.bins_to_scalar_value(value_bins)
values = safe_cat((values, value), dim = 1)
# pass the sampled action to the environment and get back next state and reward
next_frame, reward = env.step((sampled_discrete_actions, sampled_continuous_actions))
@@ -2187,6 +2195,8 @@ class DynamicsWorldModel(Module):
# package up one experience for learning
batch, device = latents.shape[0], latents.device
one_experience = Experience(
latents = latents,
video = video[:, :, :-1],
@@ -2196,6 +2206,7 @@ class DynamicsWorldModel(Module):
values = values,
step_size = step_size,
agent_index = agent_index,
lens = full((batch,), max_timesteps, device = device),
is_from_world_model = False
)
@@ -2224,6 +2235,26 @@ class DynamicsWorldModel(Module):
assert all([*map(exists, (old_log_probs, actions, old_values, rewards, step_size))]), 'the generations need to contain the log probs, values, and rewards for policy optimization'
batch, time = latents.shape[0], latents.shape[1]
# calculate returns
# for truncated episodes (the default), use lens + 1 valid steps and mask out anything after, so the last value can still be bootstrapped from
# for terminated episodes, mask out anything after lens
# if is_truncated is not supplied, assume truncated (the case for games like Minecraft or rollouts capped at a max number of timesteps)
if not exists(experience.is_truncated):
experience.is_truncated = full((batch,), True, device = latents.device)
lens_for_gae_calc = torch.where(experience.is_truncated, experience.lens + 1, experience.lens)
mask_for_gae = lens_to_mask(lens_for_gae_calc, time)
rewards = rewards.masked_fill(~mask_for_gae, 0.)
old_values = old_values.masked_fill(~mask_for_gae, 0.)
# calculate returns
returns = calc_gae(rewards, old_values, gamma = self.gae_discount_factor, lam = self.gae_lambda, use_accelerated = self.gae_use_accelerated)
# handle variable lengths
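calc_gae itself is not part of this diff. As a rough reference only, a plain (non-accelerated) generalized advantage estimation pass over the already-masked rewards and values could look like the sketch below; calc_gae_reference is a hypothetical name, and gamma / lam simply mirror the keyword names used in the call above.

import torch
from torch import Tensor

def calc_gae_reference(rewards: Tensor, values: Tensor, gamma: float, lam: float) -> Tensor:
    # rewards, values: (batch, time), with everything past the episode boundary zeroed out
    # terminated episodes therefore bootstrap from V = 0, while truncated ones keep their final value

    time = rewards.shape[-1]

    returns = torch.zeros_like(rewards)
    next_advantage = torch.zeros_like(rewards[:, 0])
    next_value = torch.zeros_like(rewards[:, 0])

    # standard GAE recurrence, computed backwards in time

    for t in reversed(range(time)):
        delta = rewards[:, t] + gamma * next_value - values[:, t]
        next_advantage = delta + gamma * lam * next_advantage
        returns[:, t] = next_advantage + values[:, t]
        next_value = values[:, t]

    return returns

The use_accelerated flag in the call above presumably swaps a loop like this for a batched scan over the same recurrence.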
@@ -2387,7 +2418,7 @@ class DynamicsWorldModel(Module):
assert 0 < num_steps <= self.max_steps, f'number of steps {num_steps} must be between 0 and {self.max_steps}'
if isinstance(tasks, int):
tasks = torch.full((batch_size,), tasks, device = self.device)
tasks = full((batch_size,), tasks, device = self.device)
assert not exists(tasks) or tasks.shape[0] == batch_size
@@ -2624,7 +2655,7 @@ class DynamicsWorldModel(Module):
# returning agent actions, rewards, and log probs + values for policy optimization
batch, device = latents.shape[0], latents.device
experience_lens = torch.full((batch,), time_steps, device = device)
experience_lens = full((batch,), time_steps, device = device)
gen = Experience(
latents = latents,

View File

@@ -1,6 +1,6 @@
[project]
name = "dreamer4"
version = "0.0.77"
version = "0.0.78"
description = "Dreamer 4"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }