From 187edc1414da2aeae48d3de946fba3af1565986c Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Wed, 8 Oct 2025 06:33:28 -0700
Subject: [PATCH] all set for generating the perceived rewards once the RL components fall into place

---
 dreamer4/dreamer4.py  | 71 +++++++++++++++++++++++++++++--------------
 pyproject.toml        |  2 +-
 tests/test_dreamer.py |  6 ++--
 3 files changed, 54 insertions(+), 25 deletions(-)

diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index c174ad2..71dd3c6 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -7,10 +7,9 @@ from collections import namedtuple
 from functools import partial
 
 import torch
-from torch import nn
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Parameter, Sequential, Linear, RMSNorm, Identity
-from torch import cat, stack, arange, tensor, Tensor, is_tensor
+from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace
 
 import torchvision
 from torchvision.models import VGG16_Weights
@@ -147,7 +146,7 @@ class LPIPSLoss(Module):
 
         # batch randperm
 
-        batch_randperm = torch.randn(pred.shape[:2], device = pred.device).argsort(dim = -1)
+        batch_randperm = randn(pred.shape[:2], device = pred.device).argsort(dim = -1)
 
         rand_frames = batch_randperm[..., :self.sampled_frames]
         batch_arange = arange(batch, device = device)
@@ -182,7 +181,7 @@ class SymExpTwoHot(Module):
         super().__init__()
         min_value, max_value = range
 
-        values = torch.linspace(min_value, max_value, num_bins)
+        values = linspace(min_value, max_value, num_bins)
         values = values.sign() * (torch.exp(values.abs()) - 1.)
 
         self.num_bins = num_bins
@@ -739,12 +738,12 @@ class VideoTokenizer(Module):
 
         assert num_latent_tokens >= 1
         self.num_latent_tokens = num_latent_tokens
-        self.latent_tokens = Parameter(torch.randn(num_latent_tokens, dim) * 1e-2)
+        self.latent_tokens = Parameter(randn(num_latent_tokens, dim) * 1e-2)
 
         # mae masking - Kaiming He paper from long ago
 
         self.per_image_patch_mask_prob = per_image_patch_mask_prob
-        self.mask_token = Parameter(torch.randn(dim) * 1e-2)
+        self.mask_token = Parameter(randn(dim) * 1e-2)
 
         # patch and unpatch
@@ -1255,7 +1254,9 @@
         image_height = None,
         image_width = None,
         return_decoded_video = None,
-        context_signal_noise = 0.1 # they do a noising of the past, this was from an old diffusion world modeling paper from EPFL iirc
+        context_signal_noise = 0.1, # they do a noising of the past, this was from an old diffusion world modeling paper from EPFL iirc
+        return_rewards_per_frame = False
+
     ): # (b t n d) | (b c t h w)
 
         assert log2(num_steps).is_integer(), f'number of steps {num_steps} must be a power of 2'
@@ -1270,30 +1271,36 @@
 
         # denoising
         # teacher forcing to start with
 
-        latents = torch.empty((batch_size, 0, *latent_shape), device = self.device)
+        latents = empty((batch_size, 0, *latent_shape), device = self.device)
+
+        # maybe return rewards
+
+        if return_rewards_per_frame:
+            decoded_rewards = empty((batch_size, 0), device = self.device, dtype = torch.float32)
 
         # while all the frames of the video (per latent) are not generated
 
         while latents.shape[1] < time_steps:
            curr_time_steps = latents.shape[1]
 
-            noised_latent = torch.randn((batch_size, 1, *latent_shape), device = self.device)
+            noised_latent = randn((batch_size, 1, *latent_shape), device = self.device)
 
            for step in range(num_steps):
 
-                signal_levels = torch.full((batch_size, 1), step * step_size, dtype = torch.long, device = self.device)
+                signal_levels = full((batch_size, 1), step * step_size, dtype = torch.long, device = self.device)
 
-                noised_context = latents.lerp(torch.randn_like(latents), context_signal_noise) # the paragraph after eq (8)
+                noised_context = latents.lerp(randn_like(latents), context_signal_noise) # the paragraph after eq (8)
 
                noised_latent_with_context, pack_context_shape = pack((noised_context, noised_latent), 'b * n d')
 
                signal_levels_with_context = F.pad(signal_levels, (curr_time_steps, 0), value = self.max_steps - 1)
 
-                pred = self.forward(
+                pred, agent_embed = self.forward(
                    latents = noised_latent_with_context,
                    signal_levels = signal_levels_with_context,
                    step_sizes = step_size,
                    latent_is_noised = True,
-                    return_pred_only = True
+                    return_pred_only = True,
+                    return_agent_tokens = True
                )
 
                _, pred = unpack(pred, pack_context_shape, 'b * n d')
@@ -1310,6 +1317,16 @@
 
                noised_latent += flow * (step_size / self.max_steps)
 
+            # predict the reward from the agent token embedding on the last denoising step
+
+            if return_rewards_per_frame:
+                reward_logits = self.to_reward_pred(agent_embed[:, -1:])
+                pred_reward = self.reward_encoder.bins_to_scalar_value(reward_logits, normalize = True)
+
+                decoded_rewards = cat((decoded_rewards, pred_reward), dim = 1)
+
+            # concat the denoised latent
+
            latents = cat((latents, noised_latent), dim = 1)
 
        # returning video
@@ -1318,7 +1335,10 @@
        return_decoded_video = default(return_decoded_video, has_tokenizer)
 
        if not return_decoded_video:
-            return denoised_latents
+            if not return_rewards_per_frame:
+                return denoised_latents
+
+            return denoised_latents, decoded_rewards
 
        generated_video = self.video_tokenizer.decode(
            latents,
@@ -1326,7 +1346,10 @@
            width = image_width
        )
 
-        return generated_video
+        if not return_rewards_per_frame:
+            return generated_video
+
+        return generated_video, decoded_rewards
 
    def forward(
        self,
@@ -1340,7 +1363,8 @@
        rewards = None, # (b t)
        return_pred_only = False,
        latent_is_noised = False,
-        return_all_losses = False
+        return_all_losses = False,
+        return_agent_tokens = False
    ):
 
        # handle video or latents
@@ -1416,16 +1440,16 @@
 
            # if no shortcut training, step sizes are just 1 and noising is all steps, where each step is 1 / d_min
            # in the original shortcut paper, they actually set d = 0 for some reason, look into that later, as there is no mention in the dreamer paper of doing this
 
-            step_sizes_log2 = torch.zeros((batch,), device = device).long() # zero because zero is equivalent to step size of 1
-            signal_levels = torch.randint(0, self.max_steps, (batch, time), device = device)
+            step_sizes_log2 = zeros((batch,), device = device).long() # zero because zero is equivalent to step size of 1
+            signal_levels = randint(0, self.max_steps, (batch, time), device = device)
 
        else:
            # now we follow eq (4)
 
-            step_sizes_log2 = torch.randint(1, self.num_step_sizes_log2, (batch,), device = device)
+            step_sizes_log2 = randint(1, self.num_step_sizes_log2, (batch,), device = device)
 
            num_step_sizes = 2 ** step_sizes_log2
-            signal_levels = torch.randint(0, self.max_steps, (batch, time)) // num_step_sizes[:, None] * num_step_sizes[:, None] # times are discretized to step sizes
+            signal_levels = randint(0, self.max_steps, (batch, time), device = device) // num_step_sizes[:, None] * num_step_sizes[:, None] # times are discretized to step sizes
 
        # times is from 0 to 1
@@ -1434,7 +1458,7 @@
        if not latent_is_noised:
 
            # get the noise
 
-            noise = torch.randn_like(latents)
+            noise = randn_like(latents)
 
            # noise from 0 as noise to 1 as data
@@ -1566,7 +1590,10 @@
        pred, encoded_agent_tokens = get_prediction(noised_latents, signal_levels, step_sizes_log2, agent_tokens, return_agent_tokens = True)
 
        if return_pred_only:
-            return pred
+            if not return_agent_tokens:
+                return pred
+
+            return pred, encoded_agent_tokens
 
        # determine the target for the loss
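A note on the reward decode above, for readers skimming the diff: `bins_to_scalar_value` collapses the reward head's categorical logits back to a scalar by taking an expectation over the symexp-spaced bin values built in `SymExpTwoHot.__init__`. A minimal standalone sketch of that decode, with illustrative bin count and range (hypothetical defaults; the repo's actual `normalize = True` path is elided):

```python
import torch

# bins are linearly spaced in symlog space, then mapped back out with symexp,
# mirroring the `values.sign() * (torch.exp(values.abs()) - 1.)` line in the diff

def make_symexp_bins(num_bins = 255, low = -20., high = 20.):  # hypothetical defaults
    values = torch.linspace(low, high, num_bins)
    return values.sign() * (torch.exp(values.abs()) - 1.)

def bins_to_scalar_value(logits, bin_values):
    # expectation under the softmax distribution recovers a scalar reward
    return (logits.softmax(dim = -1) * bin_values).sum(dim = -1)

bin_values = make_symexp_bins()
reward_logits = torch.randn(2, 1, 255)                # (batch, one frame, num bins)
pred_reward = bins_to_scalar_value(reward_logits, bin_values)
assert pred_reward.shape == (2, 1)                    # one scalar reward per frame
```

This is why `decoded_rewards` in `generate` can simply be `cat`'d along the time dimension: each generated frame contributes a `(batch, 1)` scalar.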
diff --git a/pyproject.toml b/pyproject.toml
index 58b8a7c..8cc7f77 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.5"
+version = "0.0.6"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
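Stepping back to the `generate` hunks above: the denoising loop is a plain Euler integration of the predicted flow over `num_steps` shortcut steps, with the already-generated context lightly re-noised per the paragraph after eq (8). A toy, self-contained rendition of that schedule (`predict_flow` is a hypothetical stand-in for the model's forward pass; sizes are illustrative):

```python
import torch

max_steps = 64                        # finest discretization, d_min = 1 / 64
num_steps = 8                         # must be a power of 2
step_size = max_steps // num_steps    # each sampling step spans 8 signal levels

def predict_flow(x, signal_level, context):  # hypothetical stand-in for self.forward
    return -x                                # placeholder flow field

context = torch.randn(2, 3, 8)        # previously generated frames (batch, time, dim)
x = torch.randn(2, 1, 8)              # current frame, starting as pure noise

for step in range(num_steps):
    # lightly re-noise the context, as in latents.lerp(randn_like(latents), 0.1)
    noised_context = context.lerp(torch.randn_like(context), 0.1)

    flow = predict_flow(x, step * step_size, noised_context)
    x = x + flow * (step_size / max_steps)   # Euler step; total integration time is 1
```

The power-of-two constraint on `num_steps` mirrors the `log2(num_steps).is_integer()` assert in `generate`.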
diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py
index bebff04..671f6dc 100644
--- a/tests/test_dreamer.py
+++ b/tests/test_dreamer.py
@@ -90,14 +90,16 @@ def test_e2e(
 
     # generating
 
-    generated_video = dynamics.generate(
+    generated_video, generated_rewards = dynamics.generate(
         time_steps = 10,
         image_height = 128,
         image_width = 128,
-        batch_size = 2
+        batch_size = 2,
+        return_rewards_per_frame = True
     )
 
     assert generated_video.shape == (2, 3, 10, 128, 128)
+    assert generated_rewards.shape == (2, 10)
 
     # rl
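Downstream usage for clarity: the new flag threads through `generate` exactly as exercised by the updated test. A sketch, assuming `dynamics` is a `DynamicsModel` wired to a `VideoTokenizer` as in `test_e2e`:

```python
# perceived per-frame rewards come back alongside the decoded video

video, rewards = dynamics.generate(
    time_steps = 10,
    image_height = 128,
    image_width = 128,
    batch_size = 2,
    return_rewards_per_frame = True
)

assert video.shape == (2, 3, 10, 128, 128)   # (batch, channels, time, height, width)
assert rewards.shape == (2, 10)              # one perceived reward per generated frame
```

When `return_rewards_per_frame = False` (the default), the return signature is unchanged, so existing callers are unaffected.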