try to fix ci
commit 6f1a7a24ed
parent e316499047
@@ -1701,6 +1701,7 @@ class DynamicsWorldModel(Module):
         attn_kwargs: dict = dict(
             heads = 8,
         ),
+        transformer_kwargs: dict = dict(),
         attn_dim_head = 64,
         attn_softclamp_value = 50.,
         ff_kwargs: dict = dict(),
@@ -1888,7 +1889,9 @@ class DynamicsWorldModel(Module):
             ff_kwargs = ff_kwargs,
             num_residual_streams = num_residual_streams,
             num_special_spatial_tokens = num_agents,
-            final_norm = False
+            time_block_every = time_block_every,
+            final_norm = False,
+            **transformer_kwargs
         )

         # ppo related
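For orientation, a minimal usage sketch of the signature change above (not part of the commit): only the keyword names visible in the two hunks -- attn_kwargs, attn_dim_head, attn_softclamp_value, ff_kwargs, transformer_kwargs -- are taken from the diff; the remaining constructor arguments and the concrete values are assumptions.

    # hypothetical sketch, mirroring the __init__ keywords shown in the first hunk
    model = DynamicsWorldModel(
        ...,                             # other required arguments, not shown in this diff
        attn_dim_head = 64,
        attn_softclamp_value = 50.,
        attn_kwargs = dict(heads = 8),
        ff_kwargs = dict(),
        transformer_kwargs = dict()      # new: extra kwargs forwarded to the inner Transformer via **transformer_kwargs
    )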
@@ -457,8 +457,9 @@ def test_tokenizer_trainer():

     tokenizer = VideoTokenizer(
         16,
-        encoder_depth = 4,
-        decoder_depth = 4,
+        encoder_depth = 1,
+        decoder_depth = 1,
+        time_block_every = 1,
         dim_latent = 16,
         patch_size = 32,
         attn_dim_head = 16,
@@ -508,8 +509,9 @@ def test_bc_trainer(

     tokenizer = VideoTokenizer(
         16,
-        encoder_depth = 4,
-        decoder_depth = 4,
+        encoder_depth = 1,
+        decoder_depth = 1,
+        time_block_every = 1,
         dim_latent = 16,
         patch_size = 32,
         attn_dim_head = 16,
@@ -523,7 +525,8 @@ def test_bc_trainer(
         max_steps = 64,
         num_tasks = 4,
         num_latent_tokens = 1,
-        depth = 4,
+        depth = 1,
+        time_block_every = 1,
         num_spatial_tokens = 1,
         pred_orig_latent = True,
         num_discrete_actions = 4,
@@ -551,7 +554,8 @@ def test_cache_generate():
         max_steps = 64,
         num_tasks = 4,
         num_latent_tokens = 4,
-        depth = 4,
+        depth = 1,
+        time_block_every = 1,
         num_spatial_tokens = 1,
         pred_orig_latent = True,
         num_discrete_actions = 4,
@@ -570,4 +574,4 @@ def test_online_rl(
 ):
     from dreamer4.mocks import MockEnv

-    mock_env = MockEnv((256, 256), vectorized = vectorized, batch_size = 4)
+    mock_env = MockEnv((256, 256), vectorized = vectorized, num_envs = 4)
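For reference, the updated mock-environment call in isolation (a sketch, not part of the commit): the import path, observation shape, and the num_envs keyword come from the hunk above; the concrete value for vectorized is illustrative.

    from dreamer4.mocks import MockEnv

    # the former batch_size keyword is now named num_envs
    mock_env = MockEnv((256, 256), vectorized = True, num_envs = 4)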