From 36ccb0850062470573eca288aee6b8507c6f1983 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Tue, 7 Oct 2025 08:36:46 -0700 Subject: [PATCH] allow for step_sizes to be passed in, log2 is not that intuitive --- .github/workflows/test.yml | 5 +++-- dreamer4/dreamer4.py | 42 +++++++++++++++++++++++++++++++------- pyproject.toml | 2 +- tests/test_dreamer.py | 4 ++-- 4 files changed, 41 insertions(+), 12 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index d97c9fb..477de16 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -14,8 +14,9 @@ jobs: python-version: "3.10" - name: Install dependencies run: | - python -m pip install --upgrade pip - python -m pip install -e .[test] + python -m pip install uv + python -m uv pip install --upgrade pip + python -m uv pip install -e .[test] - name: Test with pytest run: | python -m pytest tests/ diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index 1fe95d4..351db4e 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -1264,7 +1264,6 @@ class DynamicsModel(Module): # derive step size step_size = self.max_steps // num_steps - step_size_log2 = tensor(log2(step_size), dtype = torch.long, device = self.device) # denoising # teacher forcing to start with @@ -1288,7 +1287,8 @@ class DynamicsModel(Module): pred = self.forward( latents = noised_latent_with_context, signal_levels = signal_levels_with_context, - step_sizes_log2 = step_size_log2, + step_sizes = step_size, + latent_is_noised = True, return_pred_only = True ) @@ -1330,10 +1330,12 @@ class DynamicsModel(Module): video = None, latents = None, # (b t n d) | (b t d) signal_levels = None, # () | (b) | (b t) + step_sizes = None, # () | (b) step_sizes_log2 = None, # () | (b) tasks = None, # (b) rewards = None, # (b t) return_pred_only = False, + latent_is_noised = False, return_all_losses = False ): # handle video or latents @@ -1354,24 +1356,50 @@ class DynamicsModel(Module): batch, time, device = *latents.shape[:2], latents.device - # shape related + # signal and step size related input conforming if exists(signal_levels): + if isinstance(signal_levels, int): + signal_levels = tensor(signal_levels, device = self.device) + if signal_levels.ndim == 0: signal_levels = repeat(signal_levels, '-> b', b = batch) if signal_levels.ndim == 1: signal_levels = repeat(signal_levels, 'b -> b t', t = time) - if exists(step_sizes_log2) and step_sizes_log2.ndim == 0: - step_sizes_log2 = repeat(step_sizes_log2, '-> b', b = batch) + if exists(step_sizes): + if isinstance(step_sizes, int): + step_sizes = tensor(step_sizes, device = self.device) + + if step_sizes.ndim == 0: + step_sizes = repeat(step_sizes, '-> b', b = batch) + + if exists(step_sizes_log2): + if isinstance(step_sizes_log2, int): + step_sizes_log2 = tensor(step_sizes_log2, device = self.device) + + if step_sizes_log2.ndim == 0: + step_sizes_log2 = repeat(step_sizes_log2, '-> b', b = batch) + + # handle step sizes -> step size log2 + + assert not (exists(step_sizes) and exists(step_sizes_log2)) + + if exists(step_sizes): + step_sizes_log2_maybe_float = torch.log2(step_sizes) + step_sizes_log2 = step_sizes_log2_maybe_float.long() + + assert (step_sizes_log2 == step_sizes_log2_maybe_float).all(), f'`step_sizes` must be powers of 2' # flow related assert not (exists(signal_levels) ^ exists(step_sizes_log2)) is_inference = exists(signal_levels) - return_pred_only = is_inference + no_shortcut_train = not is_inference + + return_pred_only = return_pred_only or latent_is_noised # if neither signal levels or step sizes passed in, assume training # generate them randomly for training @@ -1399,7 +1427,7 @@ class DynamicsModel(Module): times = self.get_times_from_signal_level(signal_levels, latents) - if not is_inference: + if not latent_is_noised: # get the noise noise = torch.randn_like(latents) diff --git a/pyproject.toml b/pyproject.toml index 3fa384a..ede90a7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dreamer4" -version = "0.0.2" +version = "0.0.3" description = "Dreamer 4" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" } diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py index d5e735d..bebff04 100644 --- a/tests/test_dreamer.py +++ b/tests/test_dreamer.py @@ -67,8 +67,8 @@ def test_e2e( signal_levels = step_sizes_log2 = None if signal_and_step_passed_in: - signal_levels = torch.randint(0, 64, (2, 4)) - step_sizes_log2 = torch.randint(1, 6, (2,)) + signal_levels = torch.randint(0, 32, (2, 4)) + step_sizes_log2 = torch.randint(1, 5, (2,)) if dynamics_with_video_input: dynamics_input = dict(video = video)