From 283d59d75a25a9ee6e610f129aa5c3a811dc0c22 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Tue, 21 Oct 2025 09:50:07 -0700 Subject: [PATCH] oops --- dreamer4/dreamer4.py | 1 + pyproject.toml | 2 +- tests/test_dreamer.py | 11 ++++++----- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index 2cd34d2..de090e5 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -1251,6 +1251,7 @@ class AxialSpaceTimeTransformer(Module): final_norm = True ): super().__init__() + assert depth >= time_block_every, f'depth must be at least {time_block_every}' # hyper connections diff --git a/pyproject.toml b/pyproject.toml index 0c68faf..bab94eb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dreamer4" -version = "0.0.55" +version = "0.0.56" description = "Dreamer 4" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" } diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py index b623759..94a395d 100644 --- a/tests/test_dreamer.py +++ b/tests/test_dreamer.py @@ -30,8 +30,8 @@ def test_e2e( tokenizer = VideoTokenizer( 16, - encoder_depth = 1, - decoder_depth = 1, + encoder_depth = 4, + decoder_depth = 4, dim_latent = 16, patch_size = 32, attn_dim_head = 16, @@ -197,8 +197,8 @@ def test_action_with_world_model(): 512, dim_latent = 32, patch_size = 32, - encoder_depth = 1, - decoder_depth = 1, + encoder_depth = 4, + decoder_depth = 4, attn_heads = 8, image_height = 256, image_width = 256, @@ -212,6 +212,7 @@ def test_action_with_world_model(): num_agents = 1, video_tokenizer = tokenizer, dim_latent = 32, + depth = 4, num_discrete_actions = 4 ) @@ -483,7 +484,7 @@ def test_cache_generate(): max_steps = 64, num_tasks = 4, num_latent_tokens = 4, - depth = 1, + depth = 4, num_spatial_tokens = 1, pred_orig_latent = True, num_discrete_actions = 4,