all set for generating the perceived rewards once the RL components fall into place
This commit is contained in:
parent
f7bdaddbbb
commit
187edc1414
@ -7,10 +7,9 @@ from collections import namedtuple
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Module, ModuleList, Parameter, Sequential, Linear, RMSNorm, Identity
|
||||
from torch import cat, stack, arange, tensor, Tensor, is_tensor
|
||||
from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
|
||||
|
||||
import torchvision
|
||||
from torchvision.models import VGG16_Weights
|
||||
@ -147,7 +146,7 @@ class LPIPSLoss(Module):
|
||||
|
||||
# batch randperm
|
||||
|
||||
batch_randperm = torch.randn(pred.shape[:2], device = pred.device).argsort(dim = -1)
|
||||
batch_randperm = randn(pred.shape[:2], device = pred.device).argsort(dim = -1)
|
||||
rand_frames = batch_randperm[..., :self.sampled_frames]
|
||||
|
||||
batch_arange = arange(batch, device = device)
|
||||
@ -182,7 +181,7 @@ class SymExpTwoHot(Module):
|
||||
super().__init__()
|
||||
|
||||
min_value, max_value = range
|
||||
values = torch.linspace(min_value, max_value, num_bins)
|
||||
values = linspace(min_value, max_value, num_bins)
|
||||
values = values.sign() * (torch.exp(values.abs()) - 1.)
|
||||
|
||||
self.num_bins = num_bins
|
||||
@ -739,12 +738,12 @@ class VideoTokenizer(Module):
|
||||
|
||||
assert num_latent_tokens >= 1
|
||||
self.num_latent_tokens = num_latent_tokens
|
||||
self.latent_tokens = Parameter(torch.randn(num_latent_tokens, dim) * 1e-2)
|
||||
self.latent_tokens = Parameter(randn(num_latent_tokens, dim) * 1e-2)
|
||||
|
||||
# mae masking - Kaiming He paper from long ago
|
||||
|
||||
self.per_image_patch_mask_prob = per_image_patch_mask_prob
|
||||
self.mask_token = Parameter(torch.randn(dim) * 1e-2)
|
||||
self.mask_token = Parameter(randn(dim) * 1e-2)
|
||||
|
||||
# patch and unpatch
|
||||
|
||||
@ -1255,7 +1254,9 @@ class DynamicsModel(Module):
|
||||
image_height = None,
|
||||
image_width = None,
|
||||
return_decoded_video = None,
|
||||
context_signal_noise = 0.1 # they do a noising of the past, this was from an old diffusion world modeling paper from EPFL iirc
|
||||
context_signal_noise = 0.1, # they do a noising of the past, this was from an old diffusion world modeling paper from EPFL iirc
|
||||
return_rewards_per_frame = False
|
||||
|
||||
): # (b t n d) | (b c t h w)
|
||||
|
||||
assert log2(num_steps).is_integer(), f'number of steps {num_steps} must be a power of 2'
|
||||
@ -1270,30 +1271,36 @@ class DynamicsModel(Module):
|
||||
# denoising
|
||||
# teacher forcing to start with
|
||||
|
||||
latents = torch.empty((batch_size, 0, *latent_shape), device = self.device)
|
||||
latents = empty((batch_size, 0, *latent_shape), device = self.device)
|
||||
|
||||
# maybe return rewards
|
||||
|
||||
if return_rewards_per_frame:
|
||||
decoded_rewards = empty((batch_size, 0), device = self.device, dtype = torch.float32)
|
||||
|
||||
# while all the frames of the video (per latent) is not generated
|
||||
|
||||
while latents.shape[1] < time_steps:
|
||||
|
||||
curr_time_steps = latents.shape[1]
|
||||
noised_latent = torch.randn((batch_size, 1, *latent_shape), device = self.device)
|
||||
noised_latent = randn((batch_size, 1, *latent_shape), device = self.device)
|
||||
|
||||
for step in range(num_steps):
|
||||
signal_levels = torch.full((batch_size, 1), step * step_size, dtype = torch.long, device = self.device)
|
||||
signal_levels = full((batch_size, 1), step * step_size, dtype = torch.long, device = self.device)
|
||||
|
||||
noised_context = latents.lerp(torch.randn_like(latents), context_signal_noise) # the paragraph after eq (8)
|
||||
noised_context = latents.lerp(randn_like(latents), context_signal_noise) # the paragraph after eq (8)
|
||||
|
||||
noised_latent_with_context, pack_context_shape = pack((noised_context, noised_latent), 'b * n d')
|
||||
|
||||
signal_levels_with_context = F.pad(signal_levels, (curr_time_steps, 0), value = self.max_steps - 1)
|
||||
|
||||
pred = self.forward(
|
||||
pred, agent_embed = self.forward(
|
||||
latents = noised_latent_with_context,
|
||||
signal_levels = signal_levels_with_context,
|
||||
step_sizes = step_size,
|
||||
latent_is_noised = True,
|
||||
return_pred_only = True
|
||||
return_pred_only = True,
|
||||
return_agent_tokens = True
|
||||
)
|
||||
|
||||
_, pred = unpack(pred, pack_context_shape, 'b * n d')
|
||||
@ -1310,6 +1317,16 @@ class DynamicsModel(Module):
|
||||
|
||||
noised_latent += flow * (step_size / self.max_steps)
|
||||
|
||||
# take care of the rewards by predicting on the agent token embedding on the last denoising step
|
||||
|
||||
if return_rewards_per_frame:
|
||||
reward_logits = self.to_reward_pred(agent_embed[:, -1:])
|
||||
pred_reward = self.reward_encoder.bins_to_scalar_value(reward_logits, normalize = True)
|
||||
|
||||
decoded_rewards = cat((decoded_rewards, pred_reward), dim = 1)
|
||||
|
||||
# concat the denoised latent
|
||||
|
||||
latents = cat((latents, noised_latent), dim = 1)
|
||||
|
||||
# returning video
|
||||
@ -1318,7 +1335,10 @@ class DynamicsModel(Module):
|
||||
return_decoded_video = default(return_decoded_video, has_tokenizer)
|
||||
|
||||
if not return_decoded_video:
|
||||
return denoised_latents
|
||||
if not return_rewards_per_frame:
|
||||
return denoised_latents
|
||||
|
||||
return denoised_latents, decoded_rewards
|
||||
|
||||
generated_video = self.video_tokenizer.decode(
|
||||
latents,
|
||||
@ -1326,7 +1346,10 @@ class DynamicsModel(Module):
|
||||
width = image_width
|
||||
)
|
||||
|
||||
return generated_video
|
||||
if not return_rewards_per_frame:
|
||||
return generated_video
|
||||
|
||||
return generated_video, decoded_rewards
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -1340,7 +1363,8 @@ class DynamicsModel(Module):
|
||||
rewards = None, # (b t)
|
||||
return_pred_only = False,
|
||||
latent_is_noised = False,
|
||||
return_all_losses = False
|
||||
return_all_losses = False,
|
||||
return_agent_tokens = False
|
||||
):
|
||||
# handle video or latents
|
||||
|
||||
@ -1416,16 +1440,16 @@ class DynamicsModel(Module):
|
||||
# if no shortcut training, step sizes are just 1 and noising is all steps, where each step is 1 / d_min
|
||||
# in original shortcut paper, they actually set d = 0 for some reason, look into that later, as there is no mention in the dreamer paper of doing this
|
||||
|
||||
step_sizes_log2 = torch.zeros((batch,), device = device).long() # zero because zero is equivalent to step size of 1
|
||||
signal_levels = torch.randint(0, self.max_steps, (batch, time), device = device)
|
||||
step_sizes_log2 = zeros((batch,), device = device).long() # zero because zero is equivalent to step size of 1
|
||||
signal_levels = randint(0, self.max_steps, (batch, time), device = device)
|
||||
else:
|
||||
|
||||
# now we follow eq (4)
|
||||
|
||||
step_sizes_log2 = torch.randint(1, self.num_step_sizes_log2, (batch,), device = device)
|
||||
step_sizes_log2 = randint(1, self.num_step_sizes_log2, (batch,), device = device)
|
||||
num_step_sizes = 2 ** step_sizes_log2
|
||||
|
||||
signal_levels = torch.randint(0, self.max_steps, (batch, time)) // num_step_sizes[:, None] * num_step_sizes[:, None] # times are discretized to step sizes
|
||||
signal_levels = randint(0, self.max_steps, (batch, time)) // num_step_sizes[:, None] * num_step_sizes[:, None] # times are discretized to step sizes
|
||||
|
||||
# times is from 0 to 1
|
||||
|
||||
@ -1434,7 +1458,7 @@ class DynamicsModel(Module):
|
||||
if not latent_is_noised:
|
||||
# get the noise
|
||||
|
||||
noise = torch.randn_like(latents)
|
||||
noise = randn_like(latents)
|
||||
|
||||
# noise from 0 as noise to 1 as data
|
||||
|
||||
@ -1566,7 +1590,10 @@ class DynamicsModel(Module):
|
||||
pred, encoded_agent_tokens = get_prediction(noised_latents, signal_levels, step_sizes_log2, agent_tokens, return_agent_tokens = True)
|
||||
|
||||
if return_pred_only:
|
||||
return pred
|
||||
if not return_agent_tokens:
|
||||
return pred
|
||||
|
||||
return pred, encoded_agent_tokens
|
||||
|
||||
# determine the target for the loss
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "dreamer4"
|
||||
version = "0.0.5"
|
||||
version = "0.0.6"
|
||||
description = "Dreamer 4"
|
||||
authors = [
|
||||
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
||||
|
||||
@ -90,14 +90,16 @@ def test_e2e(
|
||||
|
||||
# generating
|
||||
|
||||
generated_video = dynamics.generate(
|
||||
generated_video, generated_rewards = dynamics.generate(
|
||||
time_steps = 10,
|
||||
image_height = 128,
|
||||
image_width = 128,
|
||||
batch_size = 2
|
||||
batch_size = 2,
|
||||
return_rewards_per_frame = True
|
||||
)
|
||||
|
||||
assert generated_video.shape == (2, 3, 10, 128, 128)
|
||||
assert generated_rewards.shape == (2, 10)
|
||||
|
||||
# rl
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user