diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index d494f77..1479999 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -528,11 +528,11 @@ class DynamicsModel(Module):
         dim_latent,
         num_signal_levels = 500,
         num_step_sizes = 32,
-        num_spatial_tokens = 32,  # latents were projected into spatial tokens, and presumably pooled back for the final prediction (or one special one does the x-prediction)
-        num_register_tokens = 8,  # they claim register tokens led to better temporal consistency
+        num_spatial_tokens = 32,       # latents were projected into spatial tokens, and presumably pooled back for the final prediction (or one special one does the x-prediction)
+        num_register_tokens = 8,       # they claim register tokens led to better temporal consistency
         depth = 4,
-        pred_orig_latent = True, # directly predicting the original x0 data yield better results, rather than velocity (x-space vs v-space)
-        time_block_every = 4,    # every 4th block is time
+        pred_is_clean_latents = True,  # directly predicting the original x0 data yields better results, rather than velocity (x-space vs v-space)
+        time_block_every = 4,          # every 4th block is time
         attn_kwargs: dict = dict(
             dim_head = 64,
             heads = 8,
@@ -561,7 +561,7 @@ class DynamicsModel(Module):
         self.signal_levels_embed = nn.Embedding(num_signal_levels, dim_half)
         self.step_sizes_embed = nn.Embedding(num_step_sizes, dim_half)
 
-        self.pred_orig_latent = pred_orig_latent
+        self.pred_is_clean_latents = pred_is_clean_latents
 
         # they sum all the actions into a single token
 
@@ -611,19 +611,15 @@ class DynamicsModel(Module):
 
         noise = torch.randn_like(latents)
 
-        interp = rearrange(signal_levels.float() / self.num_signal_levels, 'b t -> b t 1')
+        times = signal_levels.float() / self.num_signal_levels
 
-        orig_latents = latents
+        times = rearrange(times, 'b t -> b t 1')
 
-        latents = noise.lerp(latents, interp)
+        flow = latents - noise
 
-        # allow for original velocity pred
-        # x-space as in paper is in else clause
+        latents = noise.lerp(latents, times)
 
-        if not self.pred_orig_latent:
-            pred_target = flow = latents - noise
-        else:
-            pred_target = latents
+        noised_latents = latents
 
         # latents to spatial tokens
 
@@ -677,7 +673,15 @@ class DynamicsModel(Module):
         if not flow_matching:
             return pred
 
-        return F.mse_loss(pred, pred_target)
+        # x-space vs v-space
+
+        if self.pred_is_clean_latents:
+            denoised_latent = pred
+            pred_flow = (denoised_latent - noised_latents) / (1. - times)
+        else:
+            pred_flow = pred
+
+        return F.mse_loss(pred_flow, flow)
 
 # dreamer
 
diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py
index d64ac69..9a04d92 100644
--- a/tests/test_dreamer.py
+++ b/tests/test_dreamer.py
@@ -2,9 +2,9 @@ import pytest
 param = pytest.mark.parametrize
 import torch
 
-@param('pred_orig_latent', (False, True))
+@param('pred_is_clean_latents', (False, True))
 def test_e2e(
-    pred_orig_latent
+    pred_is_clean_latents
 ):
     from dreamer4.dreamer4 import VideoTokenizer, DynamicsModel
 
@@ -17,7 +17,7 @@ def test_e2e(
     latents = tokenizer(x, return_latents = True)
     assert latents.shape[-1] == 32
 
-    dynamics = DynamicsModel(512, dim_latent = 32, num_signal_levels = 500, num_step_sizes = 32, pred_orig_latent = pred_orig_latent)
+    dynamics = DynamicsModel(512, dim_latent = 32, num_signal_levels = 500, num_step_sizes = 32, pred_is_clean_latents = pred_is_clean_latents)
 
     signal_levels = torch.randint(0, 500, (2, 4))
     step_sizes = torch.randint(0, 32, (2, 4))
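
A quick standalone sanity check (not part of the diff) of the x-space -> v-space conversion the new loss relies on. With x_t = noise.lerp(x0, t) = (1 - t) * noise + t * x0 and flow = x0 - noise, a clean-latent prediction converts back to a flow via (x0_pred - x_t) / (1 - t), so a perfect x-prediction recovers the flow target exactly. Tensor shapes below are illustrative only:

import torch

x0 = torch.randn(2, 4, 32)        # clean latents (batch, time, dim_latent)
noise = torch.randn_like(x0)      # gaussian noise
t = torch.rand(2, 4, 1) * 0.9     # signal level in [0, 1) - signal_levels / num_signal_levels never reaches 1, so the division below is safe

x_t = noise.lerp(x0, t)           # noised latents, (1 - t) * noise + t * x0
flow = x0 - noise                 # velocity (v-space) target

# converting a perfect x-space prediction back to v-space recovers the flow,
# since x0 - x_t = (1 - t) * (x0 - noise)
recovered_flow = (x0 - x_t) / (1. - t)

assert torch.allclose(recovered_flow, flow, atol = 1e-5)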