introduce an is_truncated field on Experience, and appropriately mask out rewards and values before calculating GAE

lucidrains 2025-10-27 07:55:00 -07:00
parent fbfd59e42f
commit 59c458aea3
2 changed files with 42 additions and 11 deletions

View File

@@ -13,7 +13,7 @@ import torch.nn.functional as F
 from torch.nested import nested_tensor
 from torch.distributions import Normal
 from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity
-from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
+from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, full, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
 from torch.utils._pytree import tree_flatten, tree_unflatten
 import torchvision
@@ -83,6 +83,7 @@ class Experience:
     values: Tensor | None = None
     step_size: int | None = None
     lens: Tensor | None = None
+    is_truncated: Tensor | None = None
     agent_index: int = 0
     is_from_world_model: bool = True
@@ -99,7 +100,10 @@ def combine_experiences(
     batch, time, device = *latents.shape[:2], latents.device
     if not exists(exp.lens):
-        exp.lens = torch.full((batch,), time, device = device)
+        exp.lens = full((batch,), time, device = device)
+    if not exists(exp.is_truncated):
+        exp.is_truncated = full((batch,), True, device = device)
     # convert to dictionary
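The new is_truncated field marks, per trajectory in the batch, whether the episode was cut off (time cap) rather than genuinely terminated, and combine_experiences now backfills it with True when absent, mirroring the existing lens default. A rough, self-contained sketch of that defaulting, assuming a dataclass-style Experience and an exists helper (the names and shapes below are illustrative, not the repo's exact definitions):

    from dataclasses import dataclass

    import torch
    from torch import Tensor, full

    def exists(v):
        return v is not None

    @dataclass
    class Experience:
        latents: Tensor
        lens: Tensor | None = None
        is_truncated: Tensor | None = None   # (batch,) bool, True = cut off rather than terminated

    def fill_defaults(exp: Experience) -> Experience:
        batch, time = exp.latents.shape[:2]
        device = exp.latents.device

        if not exists(exp.lens):
            exp.lens = full((batch,), time, device = device)

        if not exists(exp.is_truncated):
            # assume truncation when unspecified (fixed-horizon / capped-timestep rollouts)
            exp.is_truncated = full((batch,), True, device = device)

        return exp

    # toy usage
    exp = fill_defaults(Experience(latents = torch.randn(2, 6, 32)))
    assert exp.is_truncated.all() and (exp.lens == 6).all()

Defaulting to truncated is the conservative choice for fixed-horizon rollouts, where an episode ends because time ran out rather than because a terminal state was reached.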
@@ -2115,7 +2119,7 @@ class DynamicsWorldModel(Module):
         time_kv_cache = None
-        for _ in range(max_timesteps):
+        for i in range(max_timesteps + 1):
             latents = self.video_tokenizer(video, return_latents = True)
@@ -2141,6 +2145,15 @@ class DynamicsWorldModel(Module):
             one_agent_embed = agent_embed[..., -1:, agent_index, :]
+            # values
+            value_bins = self.value_head(one_agent_embed)
+            value = self.reward_encoder.bins_to_scalar_value(value_bins)
+            values = safe_cat((values, value), dim = 1)
             # policy embed
             policy_embed = self.policy_head(one_agent_embed)
             # sample actions
@@ -2162,11 +2175,6 @@ class DynamicsWorldModel(Module):
             discrete_log_probs = safe_cat((discrete_log_probs, one_discrete_log_probs), dim = 1)
             continuous_log_probs = safe_cat((continuous_log_probs, one_continuous_log_probs), dim = 1)
-            value_bins = self.value_head(one_agent_embed)
-            value = self.reward_encoder.bins_to_scalar_value(value_bins)
-            values = safe_cat((values, value), dim = 1)
             # pass the sampled action to the environment and get back next state and reward
             next_frame, reward = env.step((sampled_discrete_actions, sampled_continuous_actions))
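Inside the rollout, the value head now runs before the policy head and action sampling, and the loop bound becomes max_timesteps + 1, so a value estimate for the state after the last transition is also collected; the masking added further down decides whether that bootstrap value is kept. A stand-in sketch of that loop shape with a hypothetical env / policy / value_fn; the early break on the extra pass is an assumption here (the diff only shows the loop variable switching from _ to i):

    import torch

    class ToyEnv:
        def reset(self):
            return torch.zeros(3)

        def step(self, action):
            return torch.zeros(3), torch.tensor(1.)   # next state, reward

    def rollout(env, policy, value_fn, max_timesteps):
        values, rewards, actions = [], [], []
        state = env.reset()

        for i in range(max_timesteps + 1):
            # value of the current state, collected on every pass, including the extra one
            values.append(value_fn(state))

            if i == max_timesteps:
                break   # the last pass only contributes the bootstrap value

            action = policy(state)
            state, reward = env.step(action)

            actions.append(action)
            rewards.append(reward)

        return torch.stack(values), torch.stack(rewards), actions

    values, rewards, actions = rollout(ToyEnv(), policy = lambda s: torch.zeros(2), value_fn = lambda s: torch.tensor(0.5), max_timesteps = 4)
    assert values.shape[0] == rewards.shape[0] + 1   # one extra bootstrap value

One likely motivation for the reordering: on the extra pass only the value estimate is needed, so action sampling and the environment step can be skipped.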
@@ -2187,6 +2195,8 @@ class DynamicsWorldModel(Module):
         # package up one experience for learning
+        batch, device = latents.shape[0], latents.device
         one_experience = Experience(
             latents = latents,
             video = video[:, :, :-1],
@@ -2196,6 +2206,7 @@ class DynamicsWorldModel(Module):
             values = values,
             step_size = step_size,
             agent_index = agent_index,
+            lens = full((batch,), max_timesteps, device = device),
             is_from_world_model = False
         )
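The packaged experience now also records an explicit lens (every trajectory ran the full max_timesteps), while is_truncated is left unset and picked up by the default-to-True fill shown earlier. Continuing the illustrative Experience / fill_defaults sketch from the note above (the latent dimension and shapes are made up):

    import torch
    from torch import full

    batch, max_timesteps, latent_dim = 2, 8, 32
    latents = torch.randn(batch, max_timesteps, latent_dim)

    one_experience = Experience(              # the sketch dataclass from the earlier note
        latents = latents,
        lens = full((batch,), max_timesteps, device = latents.device),
    )

    one_experience = fill_defaults(one_experience)   # is_truncated defaults to all True
    assert one_experience.is_truncated.all()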
@@ -2224,6 +2235,26 @@ class DynamicsWorldModel(Module):
         assert all([*map(exists, (old_log_probs, actions, old_values, rewards, step_size))]), 'the generations need to contain the log probs, values, and rewards for policy optimization'
+        batch, time = latents.shape[0], latents.shape[1]
+        # calculate returns
+        # for truncated (true by default), we will calculate experience lens + 1 and mask out anything after that
+        # for terminated, will just mask out any after lens
+        # if not supplied, assume truncated (which is the case for games like minecraft or capped timesteps)
+        if not exists(experience.is_truncated):
+            experience.is_truncated = full((batch,), True, device = latents.device)
+        lens_for_gae_calc = torch.where(experience.is_truncated, experience.lens, experience.lens + 1)
+        mask_for_gae = lens_to_mask(lens_for_gae_calc, time)
+        rewards = rewards.masked_fill(mask_for_gae, 0.)
+        old_values = old_values.masked_fill(mask_for_gae, 0.)
+        # calculate returns
         returns = calc_gae(rewards, old_values, gamma = self.gae_discount_factor, lam = self.gae_lambda, use_accelerated = self.gae_use_accelerated)
         # handle variable lengths
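This is the core of the change: each trajectory gets an effective length (per the comments, a truncated episode keeps one extra step so its bootstrap value survives, a terminated one is cut exactly at lens), a boolean mask is built from it, and rewards and values beyond it are zeroed before GAE. A minimal, self-contained sketch of that idea; lens_to_mask, the keep-mask polarity, and the plain (gamma, lambda) GAE recursion below are assumptions, not the repo's actual calc_gae:

    import torch

    def lens_to_mask(lens, total_len):
        # (batch,) lengths -> (batch, total_len) boolean keep-mask
        return torch.arange(total_len, device = lens.device)[None, :] < lens[:, None]

    def masked_gae(rewards, values, lens, is_truncated, gamma = 0.99, lam = 0.95):
        batch, time = rewards.shape

        # truncated episodes keep one extra step so the bootstrap value survives,
        # terminated episodes are cut exactly at their recorded length
        effective_lens = torch.where(is_truncated, lens + 1, lens)
        keep = lens_to_mask(effective_lens, time)

        rewards = rewards.masked_fill(~keep, 0.)
        values = values.masked_fill(~keep, 0.)

        # standard GAE recursion, bootstrapping from the next (possibly zeroed-out) value
        next_values = torch.cat((values[:, 1:], torch.zeros_like(values[:, :1])), dim = 1)
        deltas = rewards + gamma * next_values - values

        advantages = torch.zeros_like(rewards)
        running = torch.zeros_like(rewards[:, 0])
        for t in reversed(range(time)):
            running = deltas[:, t] + gamma * lam * running
            advantages[:, t] = running

        return advantages + values   # lambda-returns

    # toy usage: horizon 6, first trajectory truncated at 4, second terminated at 5
    rewards, values = torch.randn(2, 6), torch.randn(2, 6)
    returns = masked_gae(rewards, values, lens = torch.tensor([4, 5]), is_truncated = torch.tensor([True, False]))

Zeroing the value past the effective length is what makes the terminated case bootstrap from zero, while the kept extra value in the truncated case stands in for the return beyond the cutoff.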
@@ -2387,7 +2418,7 @@ class DynamicsWorldModel(Module):
         assert 0 < num_steps <= self.max_steps, f'number of steps {num_steps} must be between 0 and {self.max_steps}'
         if isinstance(tasks, int):
-            tasks = torch.full((batch_size,), tasks, device = self.device)
+            tasks = full((batch_size,), tasks, device = self.device)
         assert not exists(tasks) or tasks.shape[0] == batch_size
@@ -2624,7 +2655,7 @@ class DynamicsWorldModel(Module):
         # returning agent actions, rewards, and log probs + values for policy optimization
         batch, device = latents.shape[0], latents.device
-        experience_lens = torch.full((batch,), time_steps, device = device)
+        experience_lens = full((batch,), time_steps, device = device)
         gen = Experience(
             latents = latents,

View File

@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.77"
+version = "0.0.78"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }