x-space and v-space prediction in dynamics model
This commit is contained in:
parent
8b66b703e0
commit
49082d8629
@ -509,6 +509,7 @@ class DynamicsModel(Module):
|
|||||||
num_spatial_tokens = 32, # latents were projected into spatial tokens, and presumably pooled back for the final prediction (or one special one does the x-prediction)
|
num_spatial_tokens = 32, # latents were projected into spatial tokens, and presumably pooled back for the final prediction (or one special one does the x-prediction)
|
||||||
num_register_tokens = 8, # they claim register tokens led to better temporal consistency
|
num_register_tokens = 8, # they claim register tokens led to better temporal consistency
|
||||||
depth = 4,
|
depth = 4,
|
||||||
|
pred_orig_latent = True, # directly predicting the original x0 data yield better results, rather than velocity (x-space vs v-space)
|
||||||
time_block_every = 4, # every 4th block is time
|
time_block_every = 4, # every 4th block is time
|
||||||
attn_kwargs: dict = dict(
|
attn_kwargs: dict = dict(
|
||||||
dim_head = 64,
|
dim_head = 64,
|
||||||
@ -532,9 +533,14 @@ class DynamicsModel(Module):
|
|||||||
assert divisible_by(dim, 2)
|
assert divisible_by(dim, 2)
|
||||||
dim_half = dim // 2
|
dim_half = dim // 2
|
||||||
|
|
||||||
|
self.num_signal_levels = num_signal_levels
|
||||||
|
self.num_step_sizes = num_step_sizes
|
||||||
|
|
||||||
self.signal_levels_embed = nn.Embedding(num_signal_levels, dim_half)
|
self.signal_levels_embed = nn.Embedding(num_signal_levels, dim_half)
|
||||||
self.step_sizes_embed = nn.Embedding(num_step_sizes, dim_half)
|
self.step_sizes_embed = nn.Embedding(num_step_sizes, dim_half)
|
||||||
|
|
||||||
|
self.pred_orig_latent = pred_orig_latent
|
||||||
|
|
||||||
# they sum all the actions into a single token
|
# they sum all the actions into a single token
|
||||||
|
|
||||||
self.action_learned_embed = Parameter(torch.randn(dim) * 1e-2)
|
self.action_learned_embed = Parameter(torch.randn(dim) * 1e-2)
|
||||||
@ -573,6 +579,32 @@ class DynamicsModel(Module):
|
|||||||
step_sizes = None # (b t)
|
step_sizes = None # (b t)
|
||||||
):
|
):
|
||||||
|
|
||||||
|
assert not (exists(signal_levels) ^ exists(step_sizes))
|
||||||
|
|
||||||
|
flow_matching = exists(signal_levels)
|
||||||
|
|
||||||
|
# flow matching if `signal_levels` passed in
|
||||||
|
|
||||||
|
if flow_matching:
|
||||||
|
|
||||||
|
noise = torch.randn_like(latents)
|
||||||
|
|
||||||
|
interp = rearrange(signal_levels.float() / self.num_signal_levels, 'b t -> b t 1')
|
||||||
|
|
||||||
|
orig_latents = latents
|
||||||
|
|
||||||
|
latents = noise.lerp(latents, interp)
|
||||||
|
|
||||||
|
# allow for original velocity pred
|
||||||
|
# x-space as in paper is in else clause
|
||||||
|
|
||||||
|
if not self.pred_orig_latent:
|
||||||
|
pred_target = flow = latents - noise
|
||||||
|
else:
|
||||||
|
pred_target = latents
|
||||||
|
|
||||||
|
# latents to spatial tokens
|
||||||
|
|
||||||
space_tokens = self.latents_to_spatial_tokens(latents)
|
space_tokens = self.latents_to_spatial_tokens(latents)
|
||||||
|
|
||||||
# pack to tokens
|
# pack to tokens
|
||||||
@ -584,14 +616,13 @@ class DynamicsModel(Module):
|
|||||||
|
|
||||||
# determine signal + step size embed for their diffusion forcing + shortcut
|
# determine signal + step size embed for their diffusion forcing + shortcut
|
||||||
|
|
||||||
assert not (exists(signal_levels) ^ exists(step_sizes))
|
|
||||||
|
|
||||||
if exists(signal_levels):
|
if exists(signal_levels):
|
||||||
signal_embed = self.signal_levels_embed(signal_levels)
|
signal_embed = self.signal_levels_embed(signal_levels)
|
||||||
step_size_embed = self.step_sizes_embed(step_sizes)
|
step_size_embed = self.step_sizes_embed(step_sizes)
|
||||||
|
|
||||||
flow_token = cat((signal_embed, step_size_embed), dim = -1)
|
flow_token = cat((signal_embed, step_size_embed), dim = -1)
|
||||||
flow_token = rearrange(flow_token, 'b t d -> b t d')
|
flow_token = rearrange(flow_token, 'b t d -> b t d')
|
||||||
|
|
||||||
else:
|
else:
|
||||||
flow_token = registers[..., 0:0, :]
|
flow_token = registers[..., 0:0, :]
|
||||||
|
|
||||||
@ -619,7 +650,12 @@ class DynamicsModel(Module):
|
|||||||
|
|
||||||
pooled = reduce(space_tokens, 'b t s d -> b t d', 'mean')
|
pooled = reduce(space_tokens, 'b t s d -> b t d', 'mean')
|
||||||
|
|
||||||
return self.to_pred(pooled)
|
pred = self.to_pred(pooled)
|
||||||
|
|
||||||
|
if not flow_matching:
|
||||||
|
return pred
|
||||||
|
|
||||||
|
return F.mse_loss(pred, pred_target)
|
||||||
|
|
||||||
# dreamer
|
# dreamer
|
||||||
|
|
||||||
|
|||||||
@ -1,7 +1,11 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
param = pytest.mark.parametrize
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
def test_e2e():
|
@param('pred_orig_latent', (False, True))
|
||||||
|
def test_e2e(
|
||||||
|
pred_orig_latent
|
||||||
|
):
|
||||||
from dreamer4.dreamer4 import VideoTokenizer, DynamicsModel
|
from dreamer4.dreamer4 import VideoTokenizer, DynamicsModel
|
||||||
|
|
||||||
tokenizer = VideoTokenizer(512, dim_latent = 32, patch_size = 32)
|
tokenizer = VideoTokenizer(512, dim_latent = 32, patch_size = 32)
|
||||||
@ -13,10 +17,10 @@ def test_e2e():
|
|||||||
latents = tokenizer(x, return_latents = True)
|
latents = tokenizer(x, return_latents = True)
|
||||||
assert latents.shape[-1] == 32
|
assert latents.shape[-1] == 32
|
||||||
|
|
||||||
dynamics = DynamicsModel(512, dim_latent = 32, num_signal_levels = 500, num_step_sizes = 32)
|
dynamics = DynamicsModel(512, dim_latent = 32, num_signal_levels = 500, num_step_sizes = 32, pred_orig_latent = pred_orig_latent)
|
||||||
|
|
||||||
signal_levels = torch.randint(0, 500, (2, 4))
|
signal_levels = torch.randint(0, 500, (2, 4))
|
||||||
step_sizes = torch.randint(0, 32, (2, 4))
|
step_sizes = torch.randint(0, 32, (2, 4))
|
||||||
|
|
||||||
pred = dynamics(latents, signal_levels = signal_levels, step_sizes = step_sizes)
|
flow_loss = dynamics(latents, signal_levels = signal_levels, step_sizes = step_sizes)
|
||||||
assert pred.shape == latents.shape
|
assert flow_loss.numel() == 1
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user