must be able to do phase1 and phase2 training

This commit is contained in:
lucidrains 2025-10-09 08:04:36 -07:00
parent e2d86a4543
commit 31f4363be7

View File

@ -9,6 +9,7 @@ import torch
@param('add_task_embeds', (False, True))
@param('num_spatial_tokens', (2, 8))
@param('signal_and_step_passed_in', (False, True))
@param('condition_on_actions', (False, True))
@param('add_reward_embed_to_agent_token', (False, True))
def test_e2e(
pred_orig_latent,
@ -18,6 +19,7 @@ def test_e2e(
add_task_embeds,
num_spatial_tokens,
signal_and_step_passed_in,
condition_on_actions,
add_reward_embed_to_agent_token
):
from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
@ -80,7 +82,9 @@ def test_e2e(
if add_task_embeds:
tasks = torch.randint(0, 4, (2,))
actions = torch.randint(0, 4, (2, 4, 1))
actions = None
if condition_on_actions:
actions = torch.randint(0, 4, (2, 4, 1))
flow_loss = dynamics(
**dynamics_input,