all set for generating the perceived rewards once the RL components fall into place

lucidrains 2025-10-08 06:33:28 -07:00
parent f7bdaddbbb
commit 187edc1414
3 changed files with 54 additions and 25 deletions

View File

@@ -7,10 +7,9 @@ from collections import namedtuple
from functools import partial
import torch
-from torch import nn
import torch.nn.functional as F
from torch.nn import Module, ModuleList, Parameter, Sequential, Linear, RMSNorm, Identity
-from torch import cat, stack, arange, tensor, Tensor, is_tensor
+from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
import torchvision
from torchvision.models import VGG16_Weights
@@ -147,7 +146,7 @@ class LPIPSLoss(Module):
# batch randperm
-batch_randperm = torch.randn(pred.shape[:2], device = pred.device).argsort(dim = -1)
+batch_randperm = randn(pred.shape[:2], device = pred.device).argsort(dim = -1)
rand_frames = batch_randperm[..., :self.sampled_frames]
batch_arange = arange(batch, device = device)
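
For reference, the randperm-via-argsort pattern above in isolation, as a minimal standalone sketch (not the repo's code; k stands in for self.sampled_frames):

    import torch

    # sorting i.i.d. gaussian noise row-wise yields an independent random
    # permutation per batch element; the first k indices pick k distinct frames
    batch, time, k = 4, 12, 3                                       # illustrative sizes
    batch_randperm = torch.randn(batch, time).argsort(dim = -1)     # (b, t) random permutation per row
    rand_frames = batch_randperm[..., :k]                           # (b, k) distinct random frame indices
    batch_arange = torch.arange(batch)[..., None]                   # pairs with rand_frames for advanced indexing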
@@ -182,7 +181,7 @@ class SymExpTwoHot(Module):
super().__init__()
min_value, max_value = range
-values = torch.linspace(min_value, max_value, num_bins)
+values = linspace(min_value, max_value, num_bins)
values = values.sign() * (torch.exp(values.abs()) - 1.)
self.num_bins = num_bins
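
A minimal sketch of what the linspace-then-symexp construction above produces, and how such bins are commonly collapsed back to a scalar. The expectation decode below is an assumption in the usual DreamerV3 two-hot style, not the repo's exact SymExpTwoHot:

    import torch

    def symexp_bins(min_value, max_value, num_bins):
        values = torch.linspace(min_value, max_value, num_bins)    # evenly spaced in symlog space
        return values.sign() * (torch.exp(values.abs()) - 1.)      # symexp: dense near 0, wide at the tails

    def bins_to_scalar(logits, bins):
        # expected bin value under the softmax, one common two-hot decode
        return (logits.softmax(dim = -1) * bins).sum(dim = -1)

    bins = symexp_bins(-5., 5., 65)
    print(bins_to_scalar(torch.randn(2, 65), bins).shape)           # torch.Size([2])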
@@ -739,12 +738,12 @@ class VideoTokenizer(Module):
assert num_latent_tokens >= 1
self.num_latent_tokens = num_latent_tokens
-self.latent_tokens = Parameter(torch.randn(num_latent_tokens, dim) * 1e-2)
+self.latent_tokens = Parameter(randn(num_latent_tokens, dim) * 1e-2)
# mae masking - Kaiming He paper from long ago
self.per_image_patch_mask_prob = per_image_patch_mask_prob
-self.mask_token = Parameter(torch.randn(dim) * 1e-2)
+self.mask_token = Parameter(randn(dim) * 1e-2)
# patch and unpatch
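
The MAE-style masking referenced above, as a rough standalone sketch. The learned mask token and per-image patch mask probability come from the diff; the masking helper itself is hypothetical:

    import torch

    def mask_patches(tokens, mask_token, mask_prob):
        # tokens: (batch, patches, dim); randomly replace a fraction of patch tokens
        # with the learned mask token, returning the boolean mask for use in the loss
        mask = torch.rand(tokens.shape[:2], device = tokens.device) < mask_prob
        return torch.where(mask[..., None], mask_token, tokens), mask

    mask_token = torch.randn(512) * 1e-2
    masked, mask = mask_patches(torch.randn(2, 64, 512), mask_token, mask_prob = 0.25)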
@@ -1255,7 +1254,9 @@ class DynamicsModel(Module):
image_height = None,
image_width = None,
return_decoded_video = None,
-context_signal_noise = 0.1 # they do a noising of the past, this was from an old diffusion world modeling paper from EPFL iirc
+context_signal_noise = 0.1, # they do a noising of the past, this was from an old diffusion world modeling paper from EPFL iirc
+return_rewards_per_frame = False
): # (b t n d) | (b c t h w)
assert log2(num_steps).is_integer(), f'number of steps {num_steps} must be a power of 2'
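
A rough sketch of the sampling schedule this assert sets up, assuming step_size works out to max_steps // num_steps (that relationship is not shown in this hunk): the signal level climbs in multiples of step_size, and each denoising step advances the latent by step_size / max_steps of the predicted flow.

    import torch

    max_steps, num_steps = 64, 8                       # illustrative values only
    step_size = max_steps // num_steps                 # assumed relationship

    x = torch.randn(2, 1, 16, 512)                     # one frame's latent, starting as pure noise
    for step in range(num_steps):
        signal_level = step * step_size                # integer signal level fed to the model
        flow = torch.zeros_like(x)                     # stand-in for the model's predicted flow
        x = x + flow * (step_size / max_steps)         # Euler update toward data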
@@ -1270,30 +1271,36 @@ class DynamicsModel(Module):
# denoising
# teacher forcing to start with
-latents = torch.empty((batch_size, 0, *latent_shape), device = self.device)
+latents = empty((batch_size, 0, *latent_shape), device = self.device)
+# maybe return rewards
+if return_rewards_per_frame:
+decoded_rewards = empty((batch_size, 0), device = self.device, dtype = torch.float32)
# while all the frames of the video (per latent) is not generated
while latents.shape[1] < time_steps:
curr_time_steps = latents.shape[1]
-noised_latent = torch.randn((batch_size, 1, *latent_shape), device = self.device)
+noised_latent = randn((batch_size, 1, *latent_shape), device = self.device)
for step in range(num_steps):
-signal_levels = torch.full((batch_size, 1), step * step_size, dtype = torch.long, device = self.device)
+signal_levels = full((batch_size, 1), step * step_size, dtype = torch.long, device = self.device)
-noised_context = latents.lerp(torch.randn_like(latents), context_signal_noise) # the paragraph after eq (8)
+noised_context = latents.lerp(randn_like(latents), context_signal_noise) # the paragraph after eq (8)
noised_latent_with_context, pack_context_shape = pack((noised_context, noised_latent), 'b * n d')
signal_levels_with_context = F.pad(signal_levels, (curr_time_steps, 0), value = self.max_steps - 1)
-pred = self.forward(
+pred, agent_embed = self.forward(
latents = noised_latent_with_context,
signal_levels = signal_levels_with_context,
step_sizes = step_size,
latent_is_noised = True,
-return_pred_only = True
+return_pred_only = True,
+return_agent_tokens = True
)
_, pred = unpack(pred, pack_context_shape, 'b * n d')
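
The lerp on the context above, written out as a small standalone check (not the repo's code): with context_signal_noise = 0.1, the previously generated frames are kept at 90% strength and blended with 10% fresh gaussian noise before being packed in front of the frame currently being denoised.

    import torch

    latents = torch.randn(2, 4, 16, 512)               # (b, t, n, d) frames generated so far
    noise = torch.randn_like(latents)
    context_signal_noise = 0.1

    noised_context = latents.lerp(noise, context_signal_noise)
    assert torch.allclose(noised_context, (1 - context_signal_noise) * latents + context_signal_noise * noise)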
@@ -1310,6 +1317,16 @@ class DynamicsModel(Module):
noised_latent += flow * (step_size / self.max_steps)
+# take care of the rewards by predicting on the agent token embedding on the last denoising step
+if return_rewards_per_frame:
+reward_logits = self.to_reward_pred(agent_embed[:, -1:])
+pred_reward = self.reward_encoder.bins_to_scalar_value(reward_logits, normalize = True)
+decoded_rewards = cat((decoded_rewards, pred_reward), dim = 1)
# concat the denoised latent
latents = cat((latents, noised_latent), dim = 1)
# returning video
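
A hedged sketch of what the new reward path above does per generated frame. to_reward_pred and bins_to_scalar_value are the repo's names; the linear head, bin range, and expectation decode below are assumptions. The agent token embedding from the final denoising step is projected to bin logits, collapsed to one scalar, and appended to the running (batch, frames) reward tensor.

    import torch
    from torch import nn

    dim, num_bins = 512, 65
    to_reward_pred = nn.Linear(dim, num_bins)                              # hypothetical reward head
    bins = torch.linspace(-5., 5., num_bins)
    bins = bins.sign() * (torch.exp(bins.abs()) - 1.)                      # symexp-spaced bin values

    agent_embed = torch.randn(2, 1, dim)                                   # last frame's agent token embedding
    reward_logits = to_reward_pred(agent_embed)                            # (b, 1, num_bins)
    pred_reward = (reward_logits.softmax(dim = -1) * bins).sum(dim = -1)   # (b, 1) scalar reward

    decoded_rewards = torch.empty(2, 0)
    decoded_rewards = torch.cat((decoded_rewards, pred_reward), dim = 1)   # grows one column per generated frame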
@@ -1318,7 +1335,10 @@ class DynamicsModel(Module):
return_decoded_video = default(return_decoded_video, has_tokenizer)
if not return_decoded_video:
-return denoised_latents
+if not return_rewards_per_frame:
+return denoised_latents
+return denoised_latents, decoded_rewards
generated_video = self.video_tokenizer.decode(
latents,
@@ -1326,7 +1346,10 @@ class DynamicsModel(Module):
width = image_width
)
-return generated_video
+if not return_rewards_per_frame:
+return generated_video
+return generated_video, decoded_rewards
def forward(
self,
@@ -1340,7 +1363,8 @@ class DynamicsModel(Module):
rewards = None, # (b t)
return_pred_only = False,
latent_is_noised = False,
-return_all_losses = False
+return_all_losses = False,
+return_agent_tokens = False
):
# handle video or latents
@@ -1416,16 +1440,16 @@ class DynamicsModel(Module):
# if no shortcut training, step sizes are just 1 and noising is all steps, where each step is 1 / d_min
# in original shortcut paper, they actually set d = 0 for some reason, look into that later, as there is no mention in the dreamer paper of doing this
-step_sizes_log2 = torch.zeros((batch,), device = device).long() # zero because zero is equivalent to step size of 1
-signal_levels = torch.randint(0, self.max_steps, (batch, time), device = device)
+step_sizes_log2 = zeros((batch,), device = device).long() # zero because zero is equivalent to step size of 1
+signal_levels = randint(0, self.max_steps, (batch, time), device = device)
else:
# now we follow eq (4)
-step_sizes_log2 = torch.randint(1, self.num_step_sizes_log2, (batch,), device = device)
+step_sizes_log2 = randint(1, self.num_step_sizes_log2, (batch,), device = device)
num_step_sizes = 2 ** step_sizes_log2
-signal_levels = torch.randint(0, self.max_steps, (batch, time)) // num_step_sizes[:, None] * num_step_sizes[:, None] # times are discretized to step sizes
+signal_levels = randint(0, self.max_steps, (batch, time)) // num_step_sizes[:, None] * num_step_sizes[:, None] # times are discretized to step sizes
# times is from 0 to 1
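
A worked example of the discretization in the else branch above (ranges and shapes are illustrative): flooring to a multiple of the sampled step size means a sequence trained with step size 4 only ever sees signal levels 0, 4, 8, and so on.

    import torch

    max_steps, batch, time = 16, 3, 5
    step_sizes_log2 = torch.randint(1, 4, (batch,))                 # d in {2, 4, 8}, one per sequence
    num_step_sizes = 2 ** step_sizes_log2
    signal_levels = torch.randint(0, max_steps, (batch, time))
    signal_levels = signal_levels // num_step_sizes[:, None] * num_step_sizes[:, None]

    assert torch.all(signal_levels % num_step_sizes[:, None] == 0)  # every level is a multiple of its step size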
@@ -1434,7 +1458,7 @@ class DynamicsModel(Module):
if not latent_is_noised:
# get the noise
-noise = torch.randn_like(latents)
+noise = randn_like(latents)
# noise from 0 as noise to 1 as data
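
A minimal sketch of the noising convention the comment above describes, in the standard rectified flow form (the repo's exact interpolation and loss target may differ):

    import torch

    latents = torch.randn(2, 8, 16, 512)                 # clean latents, (b, t, n, d)
    noise = torch.randn_like(latents)
    times = torch.rand(2, 8)[..., None, None]            # per-frame signal level in [0, 1]

    noised_latents = noise.lerp(latents, times)          # 0 -> pure noise, 1 -> data
    flow_target = latents - noise                        # velocity target for flow matching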
@@ -1566,7 +1590,10 @@ class DynamicsModel(Module):
pred, encoded_agent_tokens = get_prediction(noised_latents, signal_levels, step_sizes_log2, agent_tokens, return_agent_tokens = True)
if return_pred_only:
-return pred
+if not return_agent_tokens:
+return pred
+return pred, encoded_agent_tokens
# determine the target for the loss

View File

@@ -1,6 +1,6 @@
[project]
name = "dreamer4"
version = "0.0.5"
version = "0.0.6"
description = "Dreamer 4"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

View File

@@ -90,14 +90,16 @@ def test_e2e(
# generating
-generated_video = dynamics.generate(
+generated_video, generated_rewards = dynamics.generate(
time_steps = 10,
image_height = 128,
image_width = 128,
-batch_size = 2
+batch_size = 2,
+return_rewards_per_frame = True
)
assert generated_video.shape == (2, 3, 10, 128, 128)
+assert generated_rewards.shape == (2, 10)
# rl