import pytest
param = pytest.mark.parametrize
import torch

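# end to end test - tokenize a video, train the dynamics world model on the latents (or raw video), then sample generations across a grid of configurations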
@param('pred_orig_latent', (False, True))
@param('grouped_query_attn', (False, True))
@param('dynamics_with_video_input', (False, True))
@param('prob_no_shortcut_train', (None, 0., 1.))
@param('add_task_embeds', (False, True))
@param('num_spatial_tokens', (2, 8))
@param('signal_and_step_passed_in', (False, True))
@param('condition_on_actions', (False, True))
@param('num_residual_streams', (1, 4))
@param('add_reward_embed_to_agent_token', (False, True))
@param('use_time_kv_cache', (False, True))
def test_e2e(
    pred_orig_latent,
    grouped_query_attn,
    dynamics_with_video_input,
    prob_no_shortcut_train,
    add_task_embeds,
    num_spatial_tokens,
    signal_and_step_passed_in,
    condition_on_actions,
    num_residual_streams,
    add_reward_embed_to_agent_token,
    use_time_kv_cache
):
    from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel

    tokenizer = VideoTokenizer(
        16,
        encoder_depth = 4,
        decoder_depth = 4,
        dim_latent = 16,
        patch_size = 32,
        attn_dim_head = 16,
        num_latent_tokens = 4,
        num_residual_streams = num_residual_streams
    )

    video = torch.randn(2, 3, 4, 256, 256)

    loss = tokenizer(video)
    assert loss.numel() == 1

    latents = tokenizer(video, return_latents = True)
    assert latents.shape[-1] == 16

    recon = tokenizer.decode(latents, 256, 256)
    assert recon.shape == video.shape

    query_heads, heads = (16, 4) if grouped_query_attn else (8, 8)

    dynamics = DynamicsWorldModel(
        dim = 16,
        video_tokenizer = tokenizer,
        dim_latent = 16,
        max_steps = 64,
        num_tasks = 4,
        num_latent_tokens = 4,
        depth = 4,
        num_spatial_tokens = num_spatial_tokens,
        pred_orig_latent = pred_orig_latent,
        num_discrete_actions = 4,
        attn_dim_head = 16,
        attn_kwargs = dict(
            heads = heads,
            query_heads = query_heads,
        ),
        prob_no_shortcut_train = prob_no_shortcut_train,
        add_reward_embed_to_agent_token = add_reward_embed_to_agent_token,
        num_residual_streams = num_residual_streams
    )

    signal_levels = step_sizes_log2 = None

    if signal_and_step_passed_in:
        signal_levels = torch.randint(0, 32, (2, 4))
        step_sizes_log2 = torch.randint(1, 5, (2,))

    if dynamics_with_video_input:
        dynamics_input = dict(video = video)
    else:
        dynamics_input = dict(latents = latents)

    tasks = None
    if add_task_embeds:
        tasks = torch.randint(0, 4, (2,))

    actions = None
    if condition_on_actions:
        actions = torch.randint(0, 4, (2, 3, 1))

    flow_loss = dynamics(
        **dynamics_input,
        tasks = tasks,
        signal_levels = signal_levels,
        step_sizes_log2 = step_sizes_log2,
        discrete_actions = actions,
        add_autoregressive_action_loss = True
    )

    assert flow_loss.numel() == 1

    # generating

    generations = dynamics.generate(
        time_steps = 10,
        image_height = 128,
        image_width = 128,
        batch_size = 2,
        return_rewards_per_frame = True,
        use_time_kv_cache = use_time_kv_cache
    )

    assert generations.video.shape == (2, 3, 10, 128, 128)
    assert generations.rewards.shape == (2, 10)

    # rl

    rewards = torch.randn((2, 4)) * 100.

    flow_loss = dynamics(
        **dynamics_input,
        tasks = tasks,
        rewards = rewards
    )

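# symexp two-hot encoding should round trip scalar values and embed them to the given dimension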
def test_symexp_two_hot():
    import torch
    from dreamer4.dreamer4 import SymExpTwoHot

    two_hot_encoder = SymExpTwoHot(
        (-3., 3.),
        num_bins = 20,
        learned_embedding = True,
        dim_embed = 512
    )

    values = torch.randn((10))

    two_hot_encoded = two_hot_encoder(values)
    recon_values = two_hot_encoder.bins_to_scalar_value(two_hot_encoded)

    assert torch.allclose(recon_values, values, atol = 1e-6)

    reward_embeds = two_hot_encoder.embed(two_hot_encoded)
    assert reward_embeds.shape == (10, 512)

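# the flex attention path and the plain attention fallback from get_attend_fn should produce the same output (cuda only)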
@pytest.mark.skipif(not torch.cuda.is_available(), reason = 'no cuda')
@param('causal', (False, True))
@param('softclamp_value', (50., None))
@param('num_agent_tokens', (0, 1))
@param('causal_block_size', (1, 8))
@param('block_size_per_special', (1, 8))
@param('special_attend_only_itself', (False, True))
def test_attend_factory(
    causal,
    softclamp_value,
    num_agent_tokens,
    causal_block_size,
    block_size_per_special,
    special_attend_only_itself
):

    from dreamer4.dreamer4 import get_attend_fn

    q = torch.randn(1, 8, 1024, 512).cuda()
    k = torch.randn(1, 4, 1024, 512).cuda()
    v = torch.randn(1, 4, 1024, 512).cuda()

    attend_kwargs = dict(
        seq_len = 1024,
        k_seq_len = 1024,
        causal = causal,
        causal_block_size = causal_block_size,
        softclamp_value = softclamp_value,
        device = q.device,
        num_agent_tokens = num_agent_tokens,
        block_size_per_special = block_size_per_special,
        special_attend_only_itself = special_attend_only_itself
    )

    attend = get_attend_fn(True, **attend_kwargs)
    flex_out = attend(q, k, v)

    attend = get_attend_fn(False, **attend_kwargs)
    out = attend(q, k, v)

    assert torch.allclose(flex_out, out, atol = 1e-6)

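# generate rollouts with per-frame rewards, agent actions, log probs and values, then take an actor / critic step from that experience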
def test_action_with_world_model():
    from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel

    tokenizer = VideoTokenizer(
        512,
        dim_latent = 32,
        patch_size = 32,
        encoder_depth = 4,
        decoder_depth = 4,
        attn_heads = 8,
        image_height = 256,
        image_width = 256,
        attn_kwargs = dict(
            query_heads = 16
        )
    )

    dynamics = DynamicsWorldModel(
        512,
        num_agents = 1,
        video_tokenizer = tokenizer,
        dim_latent = 32,
        depth = 4,
        num_discrete_actions = 4
    )

    rewards = torch.randn(1, 4)
    discrete_actions = torch.randint(0, 4, (1, 4, 1))

    gen = dynamics.generate(
        16,
        batch_size = 4,
        return_rewards_per_frame = True,
        return_agent_actions = True,
        return_log_probs_and_values = True
    )

    assert gen.video.shape == (4, 3, 16, 256, 256)
    assert gen.rewards.shape == (4, 16)

    discrete_actions, continuous_actions = gen.actions

    assert discrete_actions.shape == (4, 16, 1)
    assert continuous_actions is None

    discrete_log_probs, _ = gen.log_probs

    assert discrete_log_probs.shape == (4, 16, 1)
    assert gen.values.shape == (4, 16)

    # take a reinforcement learning step

    actor_loss, critic_loss = dynamics.learn_from_experience(gen)

    actor_loss.backward(retain_graph = True)
    critic_loss.backward()

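# ActionEmbedder - embedding, unembedding, sampling, and log prob / entropy calculation for discrete and continuous actions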
def test_action_embedder():
    from dreamer4.dreamer4 import ActionEmbedder

    # 1 discrete action with 4 choices

    embedder = ActionEmbedder(
        512,
        num_discrete_actions = 4
    )

    actions = torch.randint(0, 4, (2, 3, 1))
    action_embed = embedder(discrete_actions = actions)

    assert action_embed.shape == (2, 3, 512)

    # 2 discrete actions with 4 choices each

    embedder = ActionEmbedder(
        512,
        num_discrete_actions = (4, 4)
    )

    actions = torch.randint(0, 4, (2, 3, 2))
    action_embed = embedder(discrete_actions = actions)

    assert action_embed.shape == (2, 3, 512)

    # picking out only the second discrete action

    actions = torch.randint(0, 4, (2, 3, 1))
    action_embed = embedder(discrete_actions = actions, discrete_action_types = 1)

    assert action_embed.shape == (2, 3, 512)

    # 2 continuous actions

    embedder = ActionEmbedder(
        512,
        num_continuous_actions = 2,
        continuous_norm_stats = ((0., 2.), (1., 1.)) # (mean, std) for normalizing each action
    )

    actions = torch.randn((2, 3, 2))
    action_embed = embedder(continuous_actions = actions)

    assert action_embed.shape == (2, 3, 512)

    # 2 discrete actions with 4 choices each and 2 continuous actions

    embedder = ActionEmbedder(
        512,
        num_discrete_actions = (4, 4),
        num_continuous_actions = 2
    )

    discrete_actions = torch.randint(0, 4, (2, 3, 2))
    continuous_actions = torch.randn(2, 3, 2)

    action_embed = embedder(discrete_actions = discrete_actions, continuous_actions = continuous_actions)
    assert action_embed.shape == (2, 3, 512)

    # picking out one discrete and one continuous

    discrete_actions = torch.randint(0, 4, (2, 3, 1))
    continuous_actions = torch.randn(2, 3, 1)

    action_embed = embedder(discrete_actions = discrete_actions, continuous_actions = continuous_actions, discrete_action_types = 1, continuous_action_types = 0)

    assert action_embed.shape == (2, 3, 512)

    # unembed

    embedder = ActionEmbedder(
        512,
        num_discrete_actions = (4, 4),
        num_continuous_actions = 2,
        can_unembed = True
    )

    discrete_actions = torch.randint(0, 4, (2, 3, 2))
    continuous_actions = torch.randn(2, 3, 2)

    action_embed = embedder(discrete_actions = discrete_actions, continuous_actions = continuous_actions)

    discrete_logits, continuous_mean_log_var = embedder.unembed(action_embed)

    assert discrete_logits.shape == (2, 3, 8)
    assert continuous_mean_log_var.shape == (2, 3, 2, 2)

    # return discrete split by number of actions

    discrete_logits, continuous_mean_log_var = embedder.unembed(action_embed, return_split_discrete = True)
    assert discrete_logits[0].shape == discrete_logits[1].shape == (2, 3, 4)

    # unembed subset of actions

    discrete_logits, continuous_mean_log_var = embedder.unembed(action_embed, discrete_action_types = 1, continuous_action_types = 0)

    assert discrete_logits.shape == (2, 3, 4)
    assert continuous_mean_log_var.shape == (2, 3, 1, 2)

    # sample actions

    sampled_discrete_actions, sampled_continuous_actions = embedder.sample(action_embed, discrete_action_types = 1, continuous_action_types = 0)

    assert sampled_discrete_actions.shape == (2, 3, 1)
    assert sampled_continuous_actions.shape == (2, 3, 1)

    # log probs

    assert discrete_logits.shape == (2, 3, 4)
    assert continuous_mean_log_var.shape == (2, 3, 1, 2)

    discrete_log_probs, continuous_log_probs = embedder.log_probs(
        action_embed,
        discrete_targets = discrete_actions,
        continuous_targets = continuous_actions,
        parallel_discrete_calc = False
    )

    assert discrete_log_probs.shape == (2, 3, 2)
    assert continuous_log_probs.shape == (2, 3, 2)

    _, (discrete_entropies, continuous_entropies) = embedder.log_probs(
        action_embed,
        discrete_targets = discrete_actions,
        continuous_targets = continuous_actions,
        parallel_discrete_calc = True,
        return_entropies = True
    )

    assert discrete_entropies.shape == (2, 3, 2)
    assert continuous_entropies.shape == (2, 3, 2)

    parallel_discrete_log_probs, _ = embedder.log_probs(
        action_embed,
        discrete_targets = discrete_actions,
        continuous_targets = continuous_actions,
        parallel_discrete_calc = True
    )

    assert torch.allclose(discrete_log_probs, parallel_discrete_log_probs, atol = 1e-5)

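# multi token prediction targets and the multi-head unembedding that goes with them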
def test_mtp():
    from dreamer4.dreamer4 import create_multi_token_prediction_targets

    rewards = torch.randn(3, 16) # (b t)

    reward_targets, mask = create_multi_token_prediction_targets(rewards, 3) # say three token lookahead

    assert reward_targets.shape == (3, 15, 3)
    assert mask.shape == (3, 15, 3)

    actions = torch.randint(0, 10, (3, 16, 2))
    action_targets, mask = create_multi_token_prediction_targets(actions, 3)

    assert action_targets.shape == (3, 15, 3, 2)
    assert mask.shape == (3, 15, 3)

    from dreamer4.dreamer4 import ActionEmbedder

    embedder = ActionEmbedder(
        512,
        num_discrete_actions = (4, 4),
        num_continuous_actions = 2,
        can_unembed = True,
        num_unembed_preds = 8
    )

    discrete_actions = torch.randint(0, 4, (2, 3, 2))
    continuous_actions = torch.randn(2, 3, 2)

    action_embed = torch.randn(2, 16, 512)
    discrete_logits, continuous_logits = embedder.unembed(action_embed)

    assert discrete_logits.shape == (8, 2, 16, 8)

    discrete_logits, continuous_logits = embedder.unembed(action_embed, pred_head_index = 0)

    assert discrete_logits.shape == (2, 16, 8)

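# with beta = 0. the normalizer presumably tracks only the most recent losses, so a second pass over the same losses should normalize to 1.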
def test_loss_normalizer():
    from torch import cat
    from dreamer4.dreamer4 import LossNormalizer

    loss_normalizer = LossNormalizer(4, beta = 0.)

    losses = torch.rand(4)

    _ = loss_normalizer(losses)
    normed_losses = loss_normalizer(losses)

    assert (normed_losses == 1.).all()

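# a single VideoTokenizerTrainer step on a mock video dataset, on cpu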
def test_tokenizer_trainer():
    from dreamer4.trainers import VideoTokenizerTrainer
    from dreamer4.dreamer4 import VideoTokenizer
    from torch.utils.data import Dataset

    class MockDataset(Dataset):
        def __len__(self):
            return 2

        def __getitem__(self, idx):
            return torch.randn(3, 2, 64, 64)

    dataset = MockDataset()

    tokenizer = VideoTokenizer(
        16,
        encoder_depth = 1,
        decoder_depth = 1,
        time_block_every = 1,
        dim_latent = 16,
        patch_size = 32,
        attn_dim_head = 16,
        num_latent_tokens = 4
    )

    trainer = VideoTokenizerTrainer(
        tokenizer,
        dataset = dataset,
        num_train_steps = 1,
        batch_size = 1,
        cpu = True,
        max_grad_norm = 0.5
    )

    trainer()

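# behavior cloning trainer on a mock dataset, with and without actions and rewards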
@param('with_actions', (True, False))
@param('with_rewards', (True, False))
def test_bc_trainer(
    with_actions,
    with_rewards
):
    from dreamer4.trainers import BehaviorCloneTrainer
    from dreamer4.dreamer4 import DynamicsWorldModel, VideoTokenizer

    from torch.utils.data import Dataset

    class MockDataset(Dataset):
        def __len__(self):
            return 2

        def __getitem__(self, idx):
            state = torch.randn(3, 2, 64, 64)

            pkg = dict(video = state)

            if with_actions:
                pkg.update(discrete_actions = torch.randint(0, 4, (2, 1)))

            if with_rewards:
                pkg.update(rewards = torch.randn(2,))

            return pkg

    dataset = MockDataset()

    tokenizer = VideoTokenizer(
        16,
        encoder_depth = 1,
        decoder_depth = 1,
        time_block_every = 1,
        dim_latent = 16,
        patch_size = 32,
        attn_dim_head = 16,
        num_latent_tokens = 1
    )

    world_model = DynamicsWorldModel(
        video_tokenizer = tokenizer,
        dim = 16,
        dim_latent = 16,
        max_steps = 64,
        num_tasks = 4,
        num_latent_tokens = 1,
        depth = 1,
        time_block_every = 1,
        num_spatial_tokens = 1,
        pred_orig_latent = True,
        num_discrete_actions = 4,
        attn_dim_head = 16,
        prob_no_shortcut_train = 0.1,
        num_residual_streams = 1
    )

    trainer = BehaviorCloneTrainer(
        world_model,
        dataset = dataset,
        batch_size = 1,
        num_train_steps = 1,
        cpu = True
    )

    trainer()

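# training the policy from the world model's own generations (dreams)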
def test_dream_trainer():
    from dreamer4.dreamer4 import DynamicsWorldModel

    world_model = DynamicsWorldModel(
        dim = 16,
        dim_latent = 16,
        max_steps = 64,
        num_tasks = 4,
        num_latent_tokens = 1,
        depth = 1,
        time_block_every = 1,
        num_spatial_tokens = 1,
        pred_orig_latent = True,
        num_discrete_actions = 4,
        attn_dim_head = 16,
        prob_no_shortcut_train = 0.1,
        num_residual_streams = 1
    )

    # training from self-generations (dreams)

    from dreamer4.trainers import DreamTrainer

    dream_trainer = DreamTrainer(
        world_model,
        batch_size = 2,
        num_train_steps = 1,
        cpu = True,
    )

    dream_trainer()

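# generation should be resumable by passing the returned time kv cache back in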
def test_cache_generate():
    from dreamer4.dreamer4 import DynamicsWorldModel

    dynamics = DynamicsWorldModel(
        dim = 16,
        dim_latent = 16,
        max_steps = 64,
        num_tasks = 4,
        num_latent_tokens = 4,
        depth = 1,
        time_block_every = 1,
        num_spatial_tokens = 1,
        pred_orig_latent = True,
        num_discrete_actions = 4,
        attn_dim_head = 16,
        prob_no_shortcut_train = 0.1,
        num_residual_streams = 1
    )

    generated, time_kv_cache = dynamics.generate(1, return_time_kv_cache = True)
    generated, time_kv_cache = dynamics.generate(1, time_kv_cache = time_kv_cache, return_time_kv_cache = True)
    generated, time_kv_cache = dynamics.generate(1, time_kv_cache = time_kv_cache, return_time_kv_cache = True)

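# online rl - roll out against a mock environment, then learn from the collected experience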
def test_online_rl():
    from dreamer4.dreamer4 import DynamicsWorldModel, VideoTokenizer

    tokenizer = VideoTokenizer(
        16,
        encoder_depth = 1,
        decoder_depth = 1,
        time_block_every = 1,
        dim_latent = 16,
        patch_size = 32,
        attn_dim_head = 16,
        num_latent_tokens = 1
    )

    world_model_and_policy = DynamicsWorldModel(
        video_tokenizer = tokenizer,
        dim = 16,
        dim_latent = 16,
        max_steps = 64,
        num_tasks = 4,
        num_latent_tokens = 1,
        depth = 1,
        time_block_every = 1,
        num_spatial_tokens = 1,
        pred_orig_latent = True,
        num_discrete_actions = 4,
        attn_dim_head = 16,
        prob_no_shortcut_train = 0.1,
        num_residual_streams = 1
    )

    from dreamer4.mocks import MockEnv
    mock_env = MockEnv((256, 256), vectorized = False, num_envs = 4)

    one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16)

    actor_loss, critic_loss = world_model_and_policy.learn_from_experience(one_experience)

    actor_loss.backward()
    critic_loss.backward()