must be able to do phase1 and phase2 training
This commit is contained in:
parent
e2d86a4543
commit
31f4363be7
@ -9,6 +9,7 @@ import torch
|
||||
@param('add_task_embeds', (False, True))
|
||||
@param('num_spatial_tokens', (2, 8))
|
||||
@param('signal_and_step_passed_in', (False, True))
|
||||
@param('condition_on_actions', (False, True))
|
||||
@param('add_reward_embed_to_agent_token', (False, True))
|
||||
def test_e2e(
|
||||
pred_orig_latent,
|
||||
@ -18,6 +19,7 @@ def test_e2e(
|
||||
add_task_embeds,
|
||||
num_spatial_tokens,
|
||||
signal_and_step_passed_in,
|
||||
condition_on_actions,
|
||||
add_reward_embed_to_agent_token
|
||||
):
|
||||
from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
|
||||
@ -80,7 +82,9 @@ def test_e2e(
|
||||
if add_task_embeds:
|
||||
tasks = torch.randint(0, 4, (2,))
|
||||
|
||||
actions = torch.randint(0, 4, (2, 4, 1))
|
||||
actions = None
|
||||
if condition_on_actions:
|
||||
actions = torch.randint(0, 4, (2, 4, 1))
|
||||
|
||||
flow_loss = dynamics(
|
||||
**dynamics_input,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user