must be able to do phase1 and phase2 training
This commit is contained in:
parent
e2d86a4543
commit
31f4363be7
@ -9,6 +9,7 @@ import torch
|
|||||||
@param('add_task_embeds', (False, True))
|
@param('add_task_embeds', (False, True))
|
||||||
@param('num_spatial_tokens', (2, 8))
|
@param('num_spatial_tokens', (2, 8))
|
||||||
@param('signal_and_step_passed_in', (False, True))
|
@param('signal_and_step_passed_in', (False, True))
|
||||||
|
@param('condition_on_actions', (False, True))
|
||||||
@param('add_reward_embed_to_agent_token', (False, True))
|
@param('add_reward_embed_to_agent_token', (False, True))
|
||||||
def test_e2e(
|
def test_e2e(
|
||||||
pred_orig_latent,
|
pred_orig_latent,
|
||||||
@ -18,6 +19,7 @@ def test_e2e(
|
|||||||
add_task_embeds,
|
add_task_embeds,
|
||||||
num_spatial_tokens,
|
num_spatial_tokens,
|
||||||
signal_and_step_passed_in,
|
signal_and_step_passed_in,
|
||||||
|
condition_on_actions,
|
||||||
add_reward_embed_to_agent_token
|
add_reward_embed_to_agent_token
|
||||||
):
|
):
|
||||||
from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
|
from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
|
||||||
@ -80,7 +82,9 @@ def test_e2e(
|
|||||||
if add_task_embeds:
|
if add_task_embeds:
|
||||||
tasks = torch.randint(0, 4, (2,))
|
tasks = torch.randint(0, 4, (2,))
|
||||||
|
|
||||||
actions = torch.randint(0, 4, (2, 4, 1))
|
actions = None
|
||||||
|
if condition_on_actions:
|
||||||
|
actions = torch.randint(0, 4, (2, 4, 1))
|
||||||
|
|
||||||
flow_loss = dynamics(
|
flow_loss = dynamics(
|
||||||
**dynamics_input,
|
**dynamics_input,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user