From 31f4363be772fa8c01921ba3010fafb9415fb10a Mon Sep 17 00:00:00 2001 From: lucidrains Date: Thu, 9 Oct 2025 08:04:36 -0700 Subject: [PATCH] must be able to do phase1 and phase2 training --- tests/test_dreamer.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py index 09f758b..bd0f028 100644 --- a/tests/test_dreamer.py +++ b/tests/test_dreamer.py @@ -9,6 +9,7 @@ import torch @param('add_task_embeds', (False, True)) @param('num_spatial_tokens', (2, 8)) @param('signal_and_step_passed_in', (False, True)) +@param('condition_on_actions', (False, True)) @param('add_reward_embed_to_agent_token', (False, True)) def test_e2e( pred_orig_latent, @@ -18,6 +19,7 @@ def test_e2e( add_task_embeds, num_spatial_tokens, signal_and_step_passed_in, + condition_on_actions, add_reward_embed_to_agent_token ): from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel @@ -80,7 +82,9 @@ def test_e2e( if add_task_embeds: tasks = torch.randint(0, 4, (2,)) - actions = torch.randint(0, 4, (2, 4, 1)) + actions = None + if condition_on_actions: + actions = torch.randint(0, 4, (2, 4, 1)) flow_loss = dynamics( **dynamics_input,