try to fix ci

This commit is contained in:
lucidrains 2025-10-21 11:47:39 -07:00
parent e316499047
commit 6f1a7a24ed
2 changed files with 15 additions and 8 deletions

View File

@ -1701,6 +1701,7 @@ class DynamicsWorldModel(Module):
attn_kwargs: dict = dict( attn_kwargs: dict = dict(
heads = 8, heads = 8,
), ),
transformer_kwargs: dict = dict(),
attn_dim_head = 64, attn_dim_head = 64,
attn_softclamp_value = 50., attn_softclamp_value = 50.,
ff_kwargs: dict = dict(), ff_kwargs: dict = dict(),
@ -1888,7 +1889,9 @@ class DynamicsWorldModel(Module):
ff_kwargs = ff_kwargs, ff_kwargs = ff_kwargs,
num_residual_streams = num_residual_streams, num_residual_streams = num_residual_streams,
num_special_spatial_tokens = num_agents, num_special_spatial_tokens = num_agents,
final_norm = False time_block_every = time_block_every,
final_norm = False,
**transformer_kwargs
) )
# ppo related # ppo related

View File

@ -457,8 +457,9 @@ def test_tokenizer_trainer():
tokenizer = VideoTokenizer( tokenizer = VideoTokenizer(
16, 16,
encoder_depth = 4, encoder_depth = 1,
decoder_depth = 4, decoder_depth = 1,
time_block_every = 1,
dim_latent = 16, dim_latent = 16,
patch_size = 32, patch_size = 32,
attn_dim_head = 16, attn_dim_head = 16,
@ -508,8 +509,9 @@ def test_bc_trainer(
tokenizer = VideoTokenizer( tokenizer = VideoTokenizer(
16, 16,
encoder_depth = 4, encoder_depth = 1,
decoder_depth = 4, decoder_depth = 1,
time_block_every = 1,
dim_latent = 16, dim_latent = 16,
patch_size = 32, patch_size = 32,
attn_dim_head = 16, attn_dim_head = 16,
@ -523,7 +525,8 @@ def test_bc_trainer(
max_steps = 64, max_steps = 64,
num_tasks = 4, num_tasks = 4,
num_latent_tokens = 1, num_latent_tokens = 1,
depth = 4, depth = 1,
time_block_every = 1,
num_spatial_tokens = 1, num_spatial_tokens = 1,
pred_orig_latent = True, pred_orig_latent = True,
num_discrete_actions = 4, num_discrete_actions = 4,
@ -551,7 +554,8 @@ def test_cache_generate():
max_steps = 64, max_steps = 64,
num_tasks = 4, num_tasks = 4,
num_latent_tokens = 4, num_latent_tokens = 4,
depth = 4, depth = 1,
time_block_every = 1,
num_spatial_tokens = 1, num_spatial_tokens = 1,
pred_orig_latent = True, pred_orig_latent = True,
num_discrete_actions = 4, num_discrete_actions = 4,
@ -570,4 +574,4 @@ def test_online_rl(
): ):
from dreamer4.mocks import MockEnv from dreamer4.mocks import MockEnv
mock_env = MockEnv((256, 256), vectorized = vectorized, batch_size = 4) mock_env = MockEnv((256, 256), vectorized = vectorized, num_envs = 4)