all set for generating the perceived rewards once the RL components fall into place
commit 187edc1414
parent f7bdaddbbb
@@ -7,10 +7,9 @@ from collections import namedtuple
from functools import partial

import torch
-from torch import nn
import torch.nn.functional as F
from torch.nn import Module, ModuleList, Parameter, Sequential, Linear, RMSNorm, Identity
-from torch import cat, stack, arange, tensor, Tensor, is_tensor
+from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace

import torchvision
from torchvision.models import VGG16_Weights
@@ -147,7 +146,7 @@ class LPIPSLoss(Module):

# batch randperm

-batch_randperm = torch.randn(pred.shape[:2], device = pred.device).argsort(dim = -1)
+batch_randperm = randn(pred.shape[:2], device = pred.device).argsort(dim = -1)
rand_frames = batch_randperm[..., :self.sampled_frames]

batch_arange = arange(batch, device = device)
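For reference, the argsort-of-noise trick in the hunk above produces an independent random permutation per batch row without a Python loop, and slicing its first entries samples frames without replacement. A minimal sketch (names and sizes below are illustrative, not from the repo):

    import torch

    batch, frames, sampled_frames = 4, 16, 3

    # one gaussian score per (batch, frame); argsort of iid noise is a uniform random permutation per row
    scores = torch.randn(batch, frames)
    batch_randperm = scores.argsort(dim = -1)            # (batch, frames), each row a permutation of 0..frames-1

    # keep the first few entries of each permutation -> random frame indices without replacement
    rand_frames = batch_randperm[..., :sampled_frames]   # (batch, sampled_frames)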
@@ -182,7 +181,7 @@ class SymExpTwoHot(Module):
super().__init__()

min_value, max_value = range
-values = torch.linspace(min_value, max_value, num_bins)
+values = linspace(min_value, max_value, num_bins)
values = values.sign() * (torch.exp(values.abs()) - 1.)

self.num_bins = num_bins
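The bin construction above is the symexp transform (the inverse of symlog) applied to evenly spaced points, so bins are dense near zero and spread out for large magnitudes; a two-hot target then splits probability mass between the two nearest bins. A rough sketch of the idea, assuming a plain two-hot encode (the helper below is illustrative and may differ from the module's actual method):

    import torch

    def symexp(x):
        # inverse of symlog(x) = sign(x) * log(1 + |x|)
        return x.sign() * (torch.exp(x.abs()) - 1.)

    num_bins = 11
    bins = symexp(torch.linspace(-5., 5., num_bins))   # bin spacing grows away from zero

    def two_hot(value, bins):
        # place the scalar between its two neighboring bins, weighted by proximity
        value = value.clamp(bins[0], bins[-1])
        idx_hi = torch.searchsorted(bins, value).clamp(1, len(bins) - 1)
        idx_lo = idx_hi - 1
        weight_hi = (value - bins[idx_lo]) / (bins[idx_hi] - bins[idx_lo])
        target = torch.zeros(len(bins))
        target[idx_lo] = 1. - weight_hi
        target[idx_hi] = weight_hi
        return target

    print(two_hot(torch.tensor(2.5), bins))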
@@ -739,12 +738,12 @@ class VideoTokenizer(Module):

assert num_latent_tokens >= 1
self.num_latent_tokens = num_latent_tokens
-self.latent_tokens = Parameter(torch.randn(num_latent_tokens, dim) * 1e-2)
+self.latent_tokens = Parameter(randn(num_latent_tokens, dim) * 1e-2)

# mae masking - Kaiming He paper from long ago

self.per_image_patch_mask_prob = per_image_patch_mask_prob
-self.mask_token = Parameter(torch.randn(dim) * 1e-2)
+self.mask_token = Parameter(randn(dim) * 1e-2)

# patch and unpatch

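For context, MAE-style masking (He et al., Masked Autoencoders) hides a random subset of patch tokens and substitutes a single learned mask embedding in their place before encoding. A minimal sketch of that substitution, with illustrative shapes and names rather than the repo's exact code:

    import torch
    from torch import nn

    dim, num_patches, batch = 512, 196, 2
    mask_token = nn.Parameter(torch.randn(dim) * 1e-2)   # small init, as in the commit

    patch_tokens = torch.randn(batch, num_patches, dim)

    # bernoulli mask per patch, True = replace with the learned mask token
    mask_prob = 0.5
    mask = torch.rand(batch, num_patches) < mask_prob

    masked_tokens = torch.where(mask[..., None], mask_token, patch_tokens)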
@@ -1255,7 +1254,9 @@ class DynamicsModel(Module):
image_height = None,
image_width = None,
return_decoded_video = None,
-context_signal_noise = 0.1 # they do a noising of the past, this was from an old diffusion world modeling paper from EPFL iirc
+context_signal_noise = 0.1, # they do a noising of the past, this was from an old diffusion world modeling paper from EPFL iirc
+return_rewards_per_frame = False

): # (b t n d) | (b c t h w)

assert log2(num_steps).is_integer(), f'number of steps {num_steps} must be a power of 2'
@@ -1270,30 +1271,36 @@ class DynamicsModel(Module):
# denoising
# teacher forcing to start with

-latents = torch.empty((batch_size, 0, *latent_shape), device = self.device)
+latents = empty((batch_size, 0, *latent_shape), device = self.device)

+# maybe return rewards
+
+if return_rewards_per_frame:
+    decoded_rewards = empty((batch_size, 0), device = self.device, dtype = torch.float32)
+
# while all the frames of the video (per latent) is not generated

while latents.shape[1] < time_steps:

curr_time_steps = latents.shape[1]
-noised_latent = torch.randn((batch_size, 1, *latent_shape), device = self.device)
+noised_latent = randn((batch_size, 1, *latent_shape), device = self.device)

for step in range(num_steps):
-signal_levels = torch.full((batch_size, 1), step * step_size, dtype = torch.long, device = self.device)
+signal_levels = full((batch_size, 1), step * step_size, dtype = torch.long, device = self.device)

-noised_context = latents.lerp(torch.randn_like(latents), context_signal_noise) # the paragraph after eq (8)
+noised_context = latents.lerp(randn_like(latents), context_signal_noise) # the paragraph after eq (8)

noised_latent_with_context, pack_context_shape = pack((noised_context, noised_latent), 'b * n d')

signal_levels_with_context = F.pad(signal_levels, (curr_time_steps, 0), value = self.max_steps - 1)

-pred = self.forward(
+pred, agent_embed = self.forward(
latents = noised_latent_with_context,
signal_levels = signal_levels_with_context,
step_sizes = step_size,
latent_is_noised = True,
-return_pred_only = True
+return_pred_only = True,
+return_agent_tokens = True
)

_, pred = unpack(pred, pack_context_shape, 'b * n d')
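As a quick reminder of the context-noising line above, `a.lerp(b, w)` computes `a + w * (b - a)`, so with `context_signal_noise = 0.1` each already-generated context latent keeps 90% of its value and picks up 10% fresh Gaussian noise before being fed back as conditioning. A tiny sketch with illustrative shapes:

    import torch

    context_signal_noise = 0.1
    latents = torch.randn(2, 4, 8, 16)                   # (b, t, n, d) already-generated context

    noise = torch.randn_like(latents)
    noised_context = latents.lerp(noise, context_signal_noise)

    # equivalent explicit form
    assert torch.allclose(noised_context, latents * (1 - context_signal_noise) + noise * context_signal_noise)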
@@ -1310,6 +1317,16 @@ class DynamicsModel(Module):

noised_latent += flow * (step_size / self.max_steps)

+# take care of the rewards by predicting on the agent token embedding on the last denoising step
+
+if return_rewards_per_frame:
+    reward_logits = self.to_reward_pred(agent_embed[:, -1:])
+    pred_reward = self.reward_encoder.bins_to_scalar_value(reward_logits, normalize = True)
+
+    decoded_rewards = cat((decoded_rewards, pred_reward), dim = 1)
+
+# concat the denoised latent
+
latents = cat((latents, noised_latent), dim = 1)

# returning video
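The per-frame reward is read off the agent token embedding after the final denoising step: a linear head produces logits over the symexp bins, and `bins_to_scalar_value` collapses them back to a scalar. A plausible sketch of that decode under the two-hot scheme (the repo's method may differ in details such as what `normalize = True` controls):

    import torch

    num_bins = 11
    bin_values = torch.linspace(-5., 5., num_bins)
    bin_values = bin_values.sign() * (torch.exp(bin_values.abs()) - 1.)   # symexp-spaced bin centers

    reward_logits = torch.randn(2, 1, num_bins)            # (b, 1, num_bins) from the reward head

    probs = reward_logits.softmax(dim = -1)                # normalize -> distribution over bins
    pred_reward = (probs * bin_values).sum(dim = -1)       # (b, 1) expected bin value, concatenated along time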
@@ -1318,7 +1335,10 @@ class DynamicsModel(Module):
return_decoded_video = default(return_decoded_video, has_tokenizer)

if not return_decoded_video:
-    return denoised_latents
+    if not return_rewards_per_frame:
+        return denoised_latents
+
+    return denoised_latents, decoded_rewards

generated_video = self.video_tokenizer.decode(
latents,
@@ -1326,7 +1346,10 @@ class DynamicsModel(Module):
width = image_width
)

-return generated_video
+if not return_rewards_per_frame:
+    return generated_video
+
+return generated_video, decoded_rewards

def forward(
self,
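Taken together, `generate` now has four possible return values depending on the two flags. A hedged usage sketch, assuming `dynamics` is a DynamicsModel wired up with a video tokenizer as in test_e2e (shapes follow the `(b t n d) | (b c t h w)` comment and the updated test):

    # decoded video plus one predicted reward per generated frame
    video, rewards = dynamics.generate(
        time_steps = 10,
        image_height = 128,
        image_width = 128,
        batch_size = 2,
        return_rewards_per_frame = True
    )
    # video: (2, 3, 10, 128, 128), rewards: (2, 10)

    # latent-only rollout, skipping the pixel decode
    latents, rewards = dynamics.generate(
        time_steps = 10,
        batch_size = 2,
        return_decoded_video = False,
        return_rewards_per_frame = True
    )
    # latents: (2, 10, num_latent_tokens, dim), rewards: (2, 10)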
@@ -1340,7 +1363,8 @@ class DynamicsModel(Module):
rewards = None, # (b t)
return_pred_only = False,
latent_is_noised = False,
-return_all_losses = False
+return_all_losses = False,
+return_agent_tokens = False
):
# handle video or latents

@@ -1416,16 +1440,16 @@ class DynamicsModel(Module):
# if no shortcut training, step sizes are just 1 and noising is all steps, where each step is 1 / d_min
# in original shortcut paper, they actually set d = 0 for some reason, look into that later, as there is no mention in the dreamer paper of doing this

-step_sizes_log2 = torch.zeros((batch,), device = device).long() # zero because zero is equivalent to step size of 1
+step_sizes_log2 = zeros((batch,), device = device).long() # zero because zero is equivalent to step size of 1
-signal_levels = torch.randint(0, self.max_steps, (batch, time), device = device)
+signal_levels = randint(0, self.max_steps, (batch, time), device = device)
else:

# now we follow eq (4)

-step_sizes_log2 = torch.randint(1, self.num_step_sizes_log2, (batch,), device = device)
+step_sizes_log2 = randint(1, self.num_step_sizes_log2, (batch,), device = device)
num_step_sizes = 2 ** step_sizes_log2

-signal_levels = torch.randint(0, self.max_steps, (batch, time)) // num_step_sizes[:, None] * num_step_sizes[:, None] # times are discretized to step sizes
+signal_levels = randint(0, self.max_steps, (batch, time)) // num_step_sizes[:, None] * num_step_sizes[:, None] # times are discretized to step sizes

# times is from 0 to 1

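A quick worked example of the discretization above: floor-dividing and then multiplying by the sampled step size snaps each signal level onto that step size's grid, which shortcut training needs so noise levels line up with the coarser steps:

    import torch

    max_steps = 64
    step_sizes_log2 = torch.tensor([1, 3])        # per-sample log2 step sizes, as in the else branch
    num_step_sizes = 2 ** step_sizes_log2         # -> tensor([2, 8])

    signal_levels = torch.randint(0, max_steps, (2, 5))
    snapped = signal_levels // num_step_sizes[:, None] * num_step_sizes[:, None]

    # row 0 only contains multiples of 2, row 1 only multiples of 8
    assert ((snapped % num_step_sizes[:, None]) == 0).all()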
@@ -1434,7 +1458,7 @@ class DynamicsModel(Module):
if not latent_is_noised:
# get the noise

-noise = torch.randn_like(latents)
+noise = randn_like(latents)

# noise from 0 as noise to 1 as data

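The comment above fixes the convention: signal level 0 means pure noise and the maximum level means clean data. This hunk does not show the interpolation itself, but under that convention the standard flow-matching construction would look roughly like the sketch below (an assumption for illustration, not the repo's exact line):

    import torch

    max_steps = 64
    latents = torch.randn(2, 4, 8, 16)                  # clean data latents (b, t, n, d)
    noise = torch.randn_like(latents)

    signal_levels = torch.randint(0, max_steps, (2, 4)) # per-frame discrete signal level
    times = signal_levels.float() / max_steps           # map to [0, 1), 0 = noise, 1 = data

    t = times[..., None, None]                          # broadcast over token and feature dims
    noised_latents = noise.lerp(latents, t)             # (1 - t) * noise + t * data
    flow_target = latents - noise                       # velocity the model would regress toward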
@@ -1566,7 +1590,10 @@ class DynamicsModel(Module):
pred, encoded_agent_tokens = get_prediction(noised_latents, signal_levels, step_sizes_log2, agent_tokens, return_agent_tokens = True)

if return_pred_only:
-    return pred
+    if not return_agent_tokens:
+        return pred
+
+    return pred, encoded_agent_tokens

# determine the target for the loss

@@ -1,6 +1,6 @@
[project]
name = "dreamer4"
-version = "0.0.5"
+version = "0.0.6"
description = "Dreamer 4"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -90,14 +90,16 @@ def test_e2e(

# generating

-generated_video = dynamics.generate(
+generated_video, generated_rewards = dynamics.generate(
time_steps = 10,
image_height = 128,
image_width = 128,
-batch_size = 2
+batch_size = 2,
+return_rewards_per_frame = True
)

assert generated_video.shape == (2, 3, 10, 128, 128)
+assert generated_rewards.shape == (2, 10)

# rl
