all set for generating the perceived rewards once the RL components fall into place

This commit is contained in:
lucidrains 2025-10-08 06:33:28 -07:00
parent f7bdaddbbb
commit 187edc1414
3 changed files with 54 additions and 25 deletions

View File

@ -7,10 +7,9 @@ from collections import namedtuple
from functools import partial from functools import partial
import torch import torch
from torch import nn
import torch.nn.functional as F import torch.nn.functional as F
from torch.nn import Module, ModuleList, Parameter, Sequential, Linear, RMSNorm, Identity from torch.nn import Module, ModuleList, Parameter, Sequential, Linear, RMSNorm, Identity
from torch import cat, stack, arange, tensor, Tensor, is_tensor from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
import torchvision import torchvision
from torchvision.models import VGG16_Weights from torchvision.models import VGG16_Weights
@ -147,7 +146,7 @@ class LPIPSLoss(Module):
# batch randperm # batch randperm
batch_randperm = torch.randn(pred.shape[:2], device = pred.device).argsort(dim = -1) batch_randperm = randn(pred.shape[:2], device = pred.device).argsort(dim = -1)
rand_frames = batch_randperm[..., :self.sampled_frames] rand_frames = batch_randperm[..., :self.sampled_frames]
batch_arange = arange(batch, device = device) batch_arange = arange(batch, device = device)
@ -182,7 +181,7 @@ class SymExpTwoHot(Module):
super().__init__() super().__init__()
min_value, max_value = range min_value, max_value = range
values = torch.linspace(min_value, max_value, num_bins) values = linspace(min_value, max_value, num_bins)
values = values.sign() * (torch.exp(values.abs()) - 1.) values = values.sign() * (torch.exp(values.abs()) - 1.)
self.num_bins = num_bins self.num_bins = num_bins
@ -739,12 +738,12 @@ class VideoTokenizer(Module):
assert num_latent_tokens >= 1 assert num_latent_tokens >= 1
self.num_latent_tokens = num_latent_tokens self.num_latent_tokens = num_latent_tokens
self.latent_tokens = Parameter(torch.randn(num_latent_tokens, dim) * 1e-2) self.latent_tokens = Parameter(randn(num_latent_tokens, dim) * 1e-2)
# mae masking - Kaiming He paper from long ago # mae masking - Kaiming He paper from long ago
self.per_image_patch_mask_prob = per_image_patch_mask_prob self.per_image_patch_mask_prob = per_image_patch_mask_prob
self.mask_token = Parameter(torch.randn(dim) * 1e-2) self.mask_token = Parameter(randn(dim) * 1e-2)
# patch and unpatch # patch and unpatch
@ -1255,7 +1254,9 @@ class DynamicsModel(Module):
image_height = None, image_height = None,
image_width = None, image_width = None,
return_decoded_video = None, return_decoded_video = None,
context_signal_noise = 0.1 # they do a noising of the past, this was from an old diffusion world modeling paper from EPFL iirc context_signal_noise = 0.1, # they do a noising of the past, this was from an old diffusion world modeling paper from EPFL iirc
return_rewards_per_frame = False
): # (b t n d) | (b c t h w) ): # (b t n d) | (b c t h w)
assert log2(num_steps).is_integer(), f'number of steps {num_steps} must be a power of 2' assert log2(num_steps).is_integer(), f'number of steps {num_steps} must be a power of 2'
@ -1270,30 +1271,36 @@ class DynamicsModel(Module):
# denoising # denoising
# teacher forcing to start with # teacher forcing to start with
latents = torch.empty((batch_size, 0, *latent_shape), device = self.device) latents = empty((batch_size, 0, *latent_shape), device = self.device)
# maybe return rewards
if return_rewards_per_frame:
decoded_rewards = empty((batch_size, 0), device = self.device, dtype = torch.float32)
# while all the frames of the video (per latent) is not generated # while all the frames of the video (per latent) is not generated
while latents.shape[1] < time_steps: while latents.shape[1] < time_steps:
curr_time_steps = latents.shape[1] curr_time_steps = latents.shape[1]
noised_latent = torch.randn((batch_size, 1, *latent_shape), device = self.device) noised_latent = randn((batch_size, 1, *latent_shape), device = self.device)
for step in range(num_steps): for step in range(num_steps):
signal_levels = torch.full((batch_size, 1), step * step_size, dtype = torch.long, device = self.device) signal_levels = full((batch_size, 1), step * step_size, dtype = torch.long, device = self.device)
noised_context = latents.lerp(torch.randn_like(latents), context_signal_noise) # the paragraph after eq (8) noised_context = latents.lerp(randn_like(latents), context_signal_noise) # the paragraph after eq (8)
noised_latent_with_context, pack_context_shape = pack((noised_context, noised_latent), 'b * n d') noised_latent_with_context, pack_context_shape = pack((noised_context, noised_latent), 'b * n d')
signal_levels_with_context = F.pad(signal_levels, (curr_time_steps, 0), value = self.max_steps - 1) signal_levels_with_context = F.pad(signal_levels, (curr_time_steps, 0), value = self.max_steps - 1)
pred = self.forward( pred, agent_embed = self.forward(
latents = noised_latent_with_context, latents = noised_latent_with_context,
signal_levels = signal_levels_with_context, signal_levels = signal_levels_with_context,
step_sizes = step_size, step_sizes = step_size,
latent_is_noised = True, latent_is_noised = True,
return_pred_only = True return_pred_only = True,
return_agent_tokens = True
) )
_, pred = unpack(pred, pack_context_shape, 'b * n d') _, pred = unpack(pred, pack_context_shape, 'b * n d')
@ -1310,6 +1317,16 @@ class DynamicsModel(Module):
noised_latent += flow * (step_size / self.max_steps) noised_latent += flow * (step_size / self.max_steps)
# take care of the rewards by predicting on the agent token embedding on the last denoising step
if return_rewards_per_frame:
reward_logits = self.to_reward_pred(agent_embed[:, -1:])
pred_reward = self.reward_encoder.bins_to_scalar_value(reward_logits, normalize = True)
decoded_rewards = cat((decoded_rewards, pred_reward), dim = 1)
# concat the denoised latent
latents = cat((latents, noised_latent), dim = 1) latents = cat((latents, noised_latent), dim = 1)
# returning video # returning video
@ -1318,7 +1335,10 @@ class DynamicsModel(Module):
return_decoded_video = default(return_decoded_video, has_tokenizer) return_decoded_video = default(return_decoded_video, has_tokenizer)
if not return_decoded_video: if not return_decoded_video:
return denoised_latents if not return_rewards_per_frame:
return denoised_latents
return denoised_latents, decoded_rewards
generated_video = self.video_tokenizer.decode( generated_video = self.video_tokenizer.decode(
latents, latents,
@ -1326,7 +1346,10 @@ class DynamicsModel(Module):
width = image_width width = image_width
) )
return generated_video if not return_rewards_per_frame:
return generated_video
return generated_video, decoded_rewards
def forward( def forward(
self, self,
@ -1340,7 +1363,8 @@ class DynamicsModel(Module):
rewards = None, # (b t) rewards = None, # (b t)
return_pred_only = False, return_pred_only = False,
latent_is_noised = False, latent_is_noised = False,
return_all_losses = False return_all_losses = False,
return_agent_tokens = False
): ):
# handle video or latents # handle video or latents
@ -1416,16 +1440,16 @@ class DynamicsModel(Module):
# if no shortcut training, step sizes are just 1 and noising is all steps, where each step is 1 / d_min # if no shortcut training, step sizes are just 1 and noising is all steps, where each step is 1 / d_min
# in original shortcut paper, they actually set d = 0 for some reason, look into that later, as there is no mention in the dreamer paper of doing this # in original shortcut paper, they actually set d = 0 for some reason, look into that later, as there is no mention in the dreamer paper of doing this
step_sizes_log2 = torch.zeros((batch,), device = device).long() # zero because zero is equivalent to step size of 1 step_sizes_log2 = zeros((batch,), device = device).long() # zero because zero is equivalent to step size of 1
signal_levels = torch.randint(0, self.max_steps, (batch, time), device = device) signal_levels = randint(0, self.max_steps, (batch, time), device = device)
else: else:
# now we follow eq (4) # now we follow eq (4)
step_sizes_log2 = torch.randint(1, self.num_step_sizes_log2, (batch,), device = device) step_sizes_log2 = randint(1, self.num_step_sizes_log2, (batch,), device = device)
num_step_sizes = 2 ** step_sizes_log2 num_step_sizes = 2 ** step_sizes_log2
signal_levels = torch.randint(0, self.max_steps, (batch, time)) // num_step_sizes[:, None] * num_step_sizes[:, None] # times are discretized to step sizes signal_levels = randint(0, self.max_steps, (batch, time)) // num_step_sizes[:, None] * num_step_sizes[:, None] # times are discretized to step sizes
# times is from 0 to 1 # times is from 0 to 1
@ -1434,7 +1458,7 @@ class DynamicsModel(Module):
if not latent_is_noised: if not latent_is_noised:
# get the noise # get the noise
noise = torch.randn_like(latents) noise = randn_like(latents)
# noise from 0 as noise to 1 as data # noise from 0 as noise to 1 as data
@ -1566,7 +1590,10 @@ class DynamicsModel(Module):
pred, encoded_agent_tokens = get_prediction(noised_latents, signal_levels, step_sizes_log2, agent_tokens, return_agent_tokens = True) pred, encoded_agent_tokens = get_prediction(noised_latents, signal_levels, step_sizes_log2, agent_tokens, return_agent_tokens = True)
if return_pred_only: if return_pred_only:
return pred if not return_agent_tokens:
return pred
return pred, encoded_agent_tokens
# determine the target for the loss # determine the target for the loss

View File

@ -1,6 +1,6 @@
[project] [project]
name = "dreamer4" name = "dreamer4"
version = "0.0.5" version = "0.0.6"
description = "Dreamer 4" description = "Dreamer 4"
authors = [ authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" } { name = "Phil Wang", email = "lucidrains@gmail.com" }

View File

@ -90,14 +90,16 @@ def test_e2e(
# generating # generating
generated_video = dynamics.generate( generated_video, generated_rewards = dynamics.generate(
time_steps = 10, time_steps = 10,
image_height = 128, image_height = 128,
image_width = 128, image_width = 128,
batch_size = 2 batch_size = 2,
return_rewards_per_frame = True
) )
assert generated_video.shape == (2, 3, 10, 128, 128) assert generated_video.shape == (2, 3, 10, 128, 128)
assert generated_rewards.shape == (2, 10)
# rl # rl