From 6f1a7a24ed001bcd1d72c8aee84d52f188fa3542 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Tue, 21 Oct 2025 11:47:39 -0700 Subject: [PATCH] try to fix ci --- dreamer4/dreamer4.py | 5 ++++- tests/test_dreamer.py | 18 +++++++++++------- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index de090e5..9902535 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -1701,6 +1701,7 @@ class DynamicsWorldModel(Module): attn_kwargs: dict = dict( heads = 8, ), + transformer_kwargs: dict = dict(), attn_dim_head = 64, attn_softclamp_value = 50., ff_kwargs: dict = dict(), @@ -1888,7 +1889,9 @@ class DynamicsWorldModel(Module): ff_kwargs = ff_kwargs, num_residual_streams = num_residual_streams, num_special_spatial_tokens = num_agents, - final_norm = False + time_block_every = time_block_every, + final_norm = False, + **transformer_kwargs ) # ppo related diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py index 3acf1e5..7a715f5 100644 --- a/tests/test_dreamer.py +++ b/tests/test_dreamer.py @@ -457,8 +457,9 @@ def test_tokenizer_trainer(): tokenizer = VideoTokenizer( 16, - encoder_depth = 4, - decoder_depth = 4, + encoder_depth = 1, + decoder_depth = 1, + time_block_every = 1, dim_latent = 16, patch_size = 32, attn_dim_head = 16, @@ -508,8 +509,9 @@ def test_bc_trainer( tokenizer = VideoTokenizer( 16, - encoder_depth = 4, - decoder_depth = 4, + encoder_depth = 1, + decoder_depth = 1, + time_block_every = 1, dim_latent = 16, patch_size = 32, attn_dim_head = 16, @@ -523,7 +525,8 @@ def test_bc_trainer( max_steps = 64, num_tasks = 4, num_latent_tokens = 1, - depth = 4, + depth = 1, + time_block_every = 1, num_spatial_tokens = 1, pred_orig_latent = True, num_discrete_actions = 4, @@ -551,7 +554,8 @@ def test_cache_generate(): max_steps = 64, num_tasks = 4, num_latent_tokens = 4, - depth = 4, + depth = 1, + time_block_every = 1, num_spatial_tokens = 1, pred_orig_latent = True, num_discrete_actions = 4, @@ -570,4 +574,4 @@ def test_online_rl( ): from dreamer4.mocks import MockEnv - mock_env = MockEnv((256, 256), vectorized = vectorized, batch_size = 4) \ No newline at end of file + mock_env = MockEnv((256, 256), vectorized = vectorized, num_envs = 4) \ No newline at end of file