able to accept raw video for dynamics model, if tokenizer passed in
commit 895a867a66 (parent 8373cb13ec)
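The test added in this commit exercises the new path end to end. Below is a minimal usage sketch assembled from the hunks that follow; it is not part of the commit, and any `DynamicsModel` constructor arguments not visible in this diff are simply omitted here (assumed to have workable defaults).

import torch
from dreamer4.dreamer4 import VideoTokenizer, DynamicsModel

tokenizer = VideoTokenizer(512, dim_latent = 32, patch_size = 32)

dynamics = DynamicsModel(
    512,
    video_tokenizer = tokenizer,    # new optional argument added by this commit
    dim_latent = 32,
    num_signal_levels = 500,
    num_step_sizes = 32
)

video = torch.randn(2, 3, 4, 256, 256)          # presumably (batch, channels, time, height, width)
signal_levels = torch.randint(0, 500, (2, 4))   # (b t)
step_sizes = torch.randint(0, 32, (2, 4))       # (b t)

# raw video is tokenized to latents internally (tokenizer in eval mode, under no_grad)
flow_loss = dynamics(video = video, signal_levels = signal_levels, step_sizes = step_sizes)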
@@ -698,6 +698,7 @@ class DynamicsModel(Module):
         self,
         dim,
         dim_latent,
+        video_tokenizer: VideoTokenizer | None = None,
         num_signal_levels = 500,
         num_step_sizes = 32,
         num_spatial_tokens = 32, # latents were projected into spatial tokens, and presumably pooled back for the final prediction (or one special one does the x-prediction)
@@ -714,6 +715,10 @@ class DynamicsModel(Module):
     ):
         super().__init__()
 
+        # can accept raw video if tokenizer is passed in
+
+        self.video_tokenizer = video_tokenizer
+
         # spatial and register tokens
 
         self.latents_to_spatial_tokens = Sequential(
@@ -769,12 +774,34 @@ class DynamicsModel(Module):
             Linear(dim, dim_latent)
         )
 
+    def parameters(self):
+        params = super().parameters()
+
+        if not exists(self.video_tokenizer):
+            return params
+
+        return list(set(params) - set(self.video_tokenizer.parameters()))
+
     def forward(
         self,
-        latents, # (b t d)
+        *,
+        video = None,
+        latents = None, # (b t d)
         signal_levels = None, # (b t)
         step_sizes = None # (b t)
     ):
+        # handle video or latents
+
+        assert exists(video) ^ exists(latents)
+
+        if exists(video):
+            assert exists(self.video_tokenizer), 'video_tokenizer must be passed in if training from raw video on dynamics model'
+
+            with torch.no_grad():
+                self.video_tokenizer.eval()
+                latents = self.video_tokenizer(video, return_latents = True)
+
+        # flow related
+
         assert not (exists(signal_levels) ^ exists(step_sizes))
 
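Two details worth noting about the hunk above: the tokenizer is only ever run frozen inside `forward` (eval mode under `torch.no_grad()`), and the `parameters()` override subtracts the tokenizer's weights from the module's parameter set, so an optimizer built the usual way only updates the dynamics model. A short sketch continuing the usage example at the top of this page; the optimizer choice and learning rate are illustrative, not part of this commit.

from torch.optim import Adam

# only dynamics-model weights remain after the override filters out the tokenizer's parameters
optimizer = Adam(dynamics.parameters(), lr = 3e-4)

flow_loss = dynamics(video = video, signal_levels = signal_levels, step_sizes = step_sizes)
flow_loss.backward()
optimizer.step()
optimizer.zero_grad()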
@@ -3,26 +3,29 @@ param = pytest.mark.parametrize
 import torch
 
 @param('pred_orig_latent', (False, True))
-@param('gqa', (False, True))
+@param('grouped_query_attn', (False, True))
+@param('dynamics_with_video_input', (False, True))
 def test_e2e(
     pred_orig_latent,
-    gqa
+    grouped_query_attn,
+    dynamics_with_video_input
 ):
     from dreamer4.dreamer4 import VideoTokenizer, DynamicsModel
 
     tokenizer = VideoTokenizer(512, dim_latent = 32, patch_size = 32)
-    x = torch.randn(2, 3, 4, 256, 256)
+    video = torch.randn(2, 3, 4, 256, 256)
 
-    loss = tokenizer(x)
+    loss = tokenizer(video)
     assert loss.numel() == 1
 
-    latents = tokenizer(x, return_latents = True)
+    latents = tokenizer(video, return_latents = True)
     assert latents.shape[-1] == 32
 
-    query_heads, heads = (16, 4) if gqa else (8, 8)
+    query_heads, heads = (16, 4) if grouped_query_attn else (8, 8)
 
     dynamics = DynamicsModel(
         512,
+        video_tokenizer = tokenizer,
         dim_latent = 32,
         num_signal_levels = 500,
         num_step_sizes = 32,
@@ -36,7 +39,12 @@ def test_e2e(
     signal_levels = torch.randint(0, 500, (2, 4))
     step_sizes = torch.randint(0, 32, (2, 4))
 
-    flow_loss = dynamics(latents, signal_levels = signal_levels, step_sizes = step_sizes)
+    if dynamics_with_video_input:
+        dynamics_input = dict(video = video)
+    else:
+        dynamics_input = dict(latents = latents)
+
+    flow_loss = dynamics(**dynamics_input, signal_levels = signal_levels, step_sizes = step_sizes)
     assert flow_loss.numel() == 1
 
 def test_symexp_two_hot():
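With the third `@param` added, `test_e2e` now runs 2 x 2 x 2 = 8 parametrized cases, half of them feeding raw video into the dynamics model. A quick way to run just this test from the repo root (a sketch, assuming a standard pytest setup):

import pytest

# select every parametrization of test_e2e
pytest.main(["-k", "test_e2e"])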