diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index 1479999..d494f77 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -528,11 +528,11 @@ class DynamicsModel(Module):
         dim_latent,
         num_signal_levels = 500,
         num_step_sizes = 32,
-        num_spatial_tokens = 32,       # latents were projected into spatial tokens, and presumably pooled back for the final prediction (or one special one does the x-prediction)
-        num_register_tokens = 8,       # they claim register tokens led to better temporal consistency
+        num_spatial_tokens = 32,   # latents were projected into spatial tokens, and presumably pooled back for the final prediction (or one special one does the x-prediction)
+        num_register_tokens = 8,   # they claim register tokens led to better temporal consistency
         depth = 4,
-        pred_is_clean_latents = True,  # directly predicting the original x0 data yield better results, rather than velocity (x-space vs v-space)
-        time_block_every = 4,          # every 4th block is time
+        pred_orig_latent = True,   # directly predicting the original x0 data yields better results, rather than velocity (x-space vs v-space)
+        time_block_every = 4,      # every 4th block is time
         attn_kwargs: dict = dict(
             dim_head = 64,
             heads = 8,
@@ -561,7 +561,7 @@ class DynamicsModel(Module):
         self.signal_levels_embed = nn.Embedding(num_signal_levels, dim_half)
         self.step_sizes_embed = nn.Embedding(num_step_sizes, dim_half)
 
-        self.pred_is_clean_latents = pred_is_clean_latents
+        self.pred_orig_latent = pred_orig_latent
 
         # they sum all the actions into a single token
 
@@ -611,15 +611,19 @@ class DynamicsModel(Module):
 
         noise = torch.randn_like(latents)
 
-        times = signal_levels.float() / self.num_signal_levels
+        interp = rearrange(signal_levels.float() / self.num_signal_levels, 'b t -> b t 1')
 
-        times = rearrange(times, 'b t -> b t 1')
+        orig_latents = latents
 
-        flow = latents - noise
+        latents = noise.lerp(latents, interp)
 
-        latents = noise.lerp(latents, times)
+        # allow for either velocity prediction (v-space)
+        # or predicting the original clean latents (x-space, as in the paper)
 
-        noised_latents = latents
+        if not self.pred_orig_latent:
+            pred_target = flow = orig_latents - noise
+        else:
+            pred_target = orig_latents
 
         # latents to spatial tokens
 
@@ -673,15 +677,7 @@ class DynamicsModel(Module):
         if not flow_matching:
             return pred
 
-        # x-space vs v-space
-
-        if self.pred_is_clean_latents:
-            denoised_latent = pred
-            pred_flow = (denoised_latent - noised_latents) / (1. - times)
-        else:
-            pred_flow = pred
-
-        return F.mse_loss(pred_flow, flow)
+        return F.mse_loss(pred, pred_target)
 
     # dreamer
diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py
index 9a04d92..d64ac69 100644
--- a/tests/test_dreamer.py
+++ b/tests/test_dreamer.py
@@ -2,9 +2,9 @@ import pytest
 param = pytest.mark.parametrize
 
 import torch
 
-@param('pred_is_clean_latents', (False, True))
+@param('pred_orig_latent', (False, True))
 def test_e2e(
-    pred_is_clean_latents
+    pred_orig_latent
 ):
     from dreamer4.dreamer4 import VideoTokenizer, DynamicsModel
@@ -17,7 +17,7 @@ def test_e2e(
     latents = tokenizer(x, return_latents = True)
     assert latents.shape[-1] == 32
 
-    dynamics = DynamicsModel(512, dim_latent = 32, num_signal_levels = 500, num_step_sizes = 32, pred_is_clean_latents = pred_is_clean_latents)
+    dynamics = DynamicsModel(512, dim_latent = 32, num_signal_levels = 500, num_step_sizes = 32, pred_orig_latent = pred_orig_latent)
 
     signal_levels = torch.randint(0, 500, (2, 4))
     step_sizes = torch.randint(0, 32, (2, 4))
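For reviewers, a minimal standalone sketch of the two objectives this patch toggles between, assuming the same linear noise-to-data interpolation as the hunk above. The flow_matching_loss helper, its argument names, and the stand-in model are hypothetical, not from the repo:

    # hypothetical illustration, not part of the patch

    import torch
    import torch.nn.functional as F

    def flow_matching_loss(model, latents, times, pred_orig_latent = True):
        # latents: (b, t, d) clean data, times: (b, t, 1) in [0, 1)

        noise = torch.randn_like(latents)

        # linear path: x_t = (1 - t) * noise + t * data
        noised = noise.lerp(latents, times)

        pred = model(noised, times)

        if pred_orig_latent:
            target = latents          # x-space: regress the clean latents directly
        else:
            target = latents - noise  # v-space: regress the constant path velocity

        return F.mse_loss(pred, target)

    # stand-in network and dummy shapes, purely to show the call
    model = lambda noised, times: noised
    loss = flow_matching_loss(model, torch.randn(2, 4, 32), torch.rand(2, 4, 1))

    # an x-space prediction still recovers a velocity at sampling time,
    #     v = (x0_pred - x_t) / (1 - t)
    # which is the conversion the removed if-branch used to perform inside the loss

One common motivation for the x-space target (hedging here, the diff comment only claims it "yields better results"): the clean latents stay on a fixed scale at every signal level, whereas a velocity target mixes data and noise scales.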