diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py index 09f758b..bd0f028 100644 --- a/tests/test_dreamer.py +++ b/tests/test_dreamer.py @@ -9,6 +9,7 @@ import torch @param('add_task_embeds', (False, True)) @param('num_spatial_tokens', (2, 8)) @param('signal_and_step_passed_in', (False, True)) +@param('condition_on_actions', (False, True)) @param('add_reward_embed_to_agent_token', (False, True)) def test_e2e( pred_orig_latent, @@ -18,6 +19,7 @@ def test_e2e( add_task_embeds, num_spatial_tokens, signal_and_step_passed_in, + condition_on_actions, add_reward_embed_to_agent_token ): from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel @@ -80,7 +82,9 @@ def test_e2e( if add_task_embeds: tasks = torch.randint(0, 4, (2,)) - actions = torch.randint(0, 4, (2, 4, 1)) + actions = None + if condition_on_actions: + actions = torch.randint(0, 4, (2, 4, 1)) flow_loss = dynamics( **dynamics_input,