fix the rnn time caching issue
parent ac1c12f743
commit 690ecf07dc
@@ -1,6 +1,7 @@
 from dreamer4.dreamer4 import (
     VideoTokenizer,
-    DynamicsWorldModel
+    DynamicsWorldModel,
+    AxialSpaceTimeTransformer
 )
 
@@ -1,4 +1,5 @@
 from __future__ import annotations
+from typing import Callable
 
 import math
 from math import ceil, log2
@@ -76,7 +77,7 @@ WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'rewards', 'discrete_
 
 AttentionIntermediates = namedtuple('AttentionIntermediates', ('next_kv_cache', 'normed_inputs'))
 
-TransformerIntermediates = namedtuple('TransformerIntermediates', ('next_kv_cache', 'normed_time_inputs', 'normed_space_inputs'))
+TransformerIntermediates = namedtuple('TransformerIntermediates', ('next_kv_cache', 'normed_time_inputs', 'normed_space_inputs', 'next_rnn_hiddens'))
 
 MaybeTensor = Tensor | None
 
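The added next_rnn_hiddens field is the crux of the fix: attention can resume from a kv cache alone, but a GRU is stateful, so resuming it on only the newest token with a fresh zero hidden state silently changes the output. A minimal standalone demonstration of that failure mode, using a plain nn.GRU with hypothetical shapes rather than anything from this repo:

import torch
from torch import nn

gru = nn.GRU(32, 32, batch_first = True).eval()

seq = torch.randn(1, 8, 32)

with torch.no_grad():
    # full-sequence reference
    ref, _ = gru(seq)

    # cache the hidden state after the first 7 steps
    _, hiddens = gru(seq[:, :-1])

    # correct: resume the last step from the cached hiddens
    resumed, _ = gru(seq[:, -1:], hiddens)

    # buggy: resume with no hiddens, as if a kv cache were all that mattered
    from_zero, _ = gru(seq[:, -1:])

assert torch.allclose(resumed, ref[:, -1:], atol = 1e-5)
assert not torch.allclose(from_zero, ref[:, -1:], atol = 1e-5)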
@@ -1453,6 +1454,29 @@ class SwiGLUFeedforward(Module):
 
         return self.proj_out(x)
 
+# rnn
+
+class GRULayer(Module):
+    def __init__(
+        self,
+        dim,
+        dim_out
+    ):
+        super().__init__()
+        self.norm = nn.RMSNorm(dim)
+        self.gru = nn.GRU(dim, dim_out, batch_first = True)
+
+    def forward(
+        self,
+        x,
+        prev_hiddens = None
+    ):
+        x = self.norm(x)
+
+        x, hiddens = self.gru(x, prev_hiddens)
+
+        return x, hiddens
+
 # axial space time transformer
 
 class AxialSpaceTimeTransformer(Module):
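A quick way to convince yourself the new GRULayer composes with caching: processing a sequence in one shot should match processing it token by token while threading the returned hiddens back in. A hypothetical check, not part of the repo's tests, assuming GRULayer is importable from dreamer4.dreamer4:

import torch
from dreamer4.dreamer4 import GRULayer  # assumes the class above is exposed here

layer = GRULayer(512, 512).eval()

seq = torch.randn(1, 16, 512)  # (batch, time, dim)

with torch.no_grad():
    full, _ = layer(seq)

    hiddens = None
    stepwise = []

    for t in range(seq.shape[1]):
        out, hiddens = layer(seq[:, t:t + 1], prev_hiddens = hiddens)
        stepwise.append(out)

assert torch.allclose(full, torch.cat(stepwise, dim = 1), atol = 1e-5)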
@@ -1531,7 +1555,7 @@ class AxialSpaceTimeTransformer(Module):
                 hyper_conn(branch = SwiGLUFeedforward(dim = dim, **ff_kwargs))
             ]))
 
-            rnn_layers.append(hyper_conn(branch = nn.Sequential(nn.RMSNorm(dim), nn.GRU(dim, dim, batch_first = True))) if is_time_block and rnn_time else None)
+            rnn_layers.append(hyper_conn(branch = GRULayer(dim, dim)) if is_time_block and rnn_time else None)
 
         self.layers = ModuleList(layers)
         self.rnn_layers = ModuleList(rnn_layers)
@@ -1557,8 +1581,8 @@ class AxialSpaceTimeTransformer(Module):
 
     def forward(
         self,
         tokens, # (b t s d)
-        kv_cache: Tensor | None = None, # (y 2 b h t d)
+        cache: TransformerIntermediates | None = None,
         return_intermediates = False
 
     ): # (b t s d) | (y 2 b h t d)
@@ -1567,6 +1591,14 @@ class AxialSpaceTimeTransformer(Module):
 
         assert tokens.ndim == 4
 
+        # destruct intermediates to cache for attention and rnn respectively
+
+        kv_cache = rnn_prev_hiddens = None
+
+        if exists(cache):
+            kv_cache = cache.next_kv_cache
+            rnn_prev_hiddens = cache.next_rnn_hiddens
+
         # attend functions for space and time
 
         has_kv_cache = exists(kv_cache)
@@ -1581,6 +1613,7 @@ class AxialSpaceTimeTransformer(Module):
         # prepare cache
 
         time_attn_kv_caches = []
+        rnn_hiddens = []
 
         if has_kv_cache:
             past_tokens, tokens = tokens[:, :-1], tokens[:, -1:]
@@ -1595,6 +1628,10 @@ class AxialSpaceTimeTransformer(Module):
 
             iter_kv_cache = iter(kv_cache)
 
+        rnn_prev_hiddens = default(rnn_prev_hiddens, (None,))
+
+        iter_rnn_prev_hiddens = iter(rnn_prev_hiddens)
+
         # rotary
 
         rotary_pos_emb = self.time_rotary(rotary_seq_len, offset = rotary_pos_offset)
@@ -1625,10 +1662,12 @@ class AxialSpaceTimeTransformer(Module):
 
                 tokens, inverse_pack_batch = pack_one(tokens, '* t d')
 
-                tokens, rnn_hiddens = maybe_rnn(tokens) # todo, handle rnn cache
+                tokens, layer_rnn_hiddens = maybe_rnn(tokens, next(iter_rnn_prev_hiddens, None)) # todo, handle rnn cache
 
                 tokens = inverse_pack_batch(tokens)
 
+                rnn_hiddens.append(layer_rnn_hiddens)
+
             # when is a axial time attention block, should be causal
 
             attend_fn = time_attend if layer_is_time else space_attend
@@ -1685,7 +1724,8 @@ class AxialSpaceTimeTransformer(Module):
         intermediates = TransformerIntermediates(
             stack(time_attn_kv_caches),
             safe_stack(normed_time_attn_inputs),
-            safe_stack(normed_space_attn_inputs)
+            safe_stack(normed_space_attn_inputs),
+            safe_stack(rnn_hiddens)
         )
 
         return out, intermediates
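The per-layer threading above follows a simple pattern: previous hiddens arrive as an optional stack, are consumed one layer at a time through an iterator whose fallback is None on the first, uncached call, and each layer's fresh hiddens are collected for the next intermediates tuple. A distilled sketch with hypothetical names (the exists/default helpers mirror the repo's):

import torch
from torch import nn

def exists(v):
    return v is not None

def default(v, d):
    return v if exists(v) else d

# three stacked GRUs stand in for the transformer's rnn_layers

rnn_layers = nn.ModuleList([nn.GRU(64, 64, batch_first = True) for _ in range(3)])

def forward_with_rnn_cache(tokens, rnn_prev_hiddens = None):
    # on the first, uncached call the iterator immediately falls back to None
    iter_prev = iter(default(rnn_prev_hiddens, (None,)))

    next_hiddens = []

    for rnn in rnn_layers:
        tokens, hiddens = rnn(tokens, next(iter_prev, None))
        next_hiddens.append(hiddens)

    return tokens, torch.stack(next_hiddens)

out, cache = forward_with_rnn_cache(torch.randn(2, 4, 64))
out, cache = forward_with_rnn_cache(torch.randn(2, 1, 64), cache)  # resumes statefully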
@@ -1717,11 +1757,6 @@ class VideoTokenizer(Module):
         encoder_add_decor_aux_loss = False,
         decor_auxx_loss_weight = 0.1,
         decorr_sample_frac = 0.25,
-        nd_rotary_kwargs: dict = dict(
-            rope_min_freq = 1.,
-            rope_max_freq = 10000.,
-            rope_p_zero_freqs = 0.
-        ),
         num_residual_streams = 1,
     ):
         super().__init__()
@@ -1938,7 +1973,7 @@ class VideoTokenizer(Module):
 
         # encoder attention
 
-        tokens, (_, time_attn_normed_inputs, space_attn_normed_inputs) = self.encoder_transformer(tokens, return_intermediates = True)
+        tokens, (_, time_attn_normed_inputs, space_attn_normed_inputs, _) = self.encoder_transformer(tokens, return_intermediates = True)
 
         # latent bottleneck
 
@@ -2011,15 +2046,15 @@ class DynamicsWorldModel(Module):
         attn_dim_head = 64,
         attn_softclamp_value = 50.,
         ff_kwargs: dict = dict(),
+        use_time_rnn = True,
         loss_weight_fn: Callable = ramp_weight,
-        num_future_predictions = 8, # they do multi-token prediction of 8 steps forward
         prob_no_shortcut_train = None, # probability of no shortcut training, defaults to 1 / num_step_sizes
         add_reward_embed_to_agent_token = False,
         add_reward_embed_dropout = 0.1,
         num_discrete_actions: int | tuple[int, ...] = 0,
         num_continuous_actions = 0,
         continuous_norm_stats = None,
-        multi_token_pred_len = 8,
+        multi_token_pred_len = 8, # they do multi-token prediction of 8 steps forward
         value_head_mlp_depth = 3,
         policy_head_mlp_depth = 3,
         latent_flow_loss_weight = 1.,
@@ -2229,7 +2264,7 @@ class DynamicsWorldModel(Module):
             num_special_spatial_tokens = num_agents,
             time_block_every = time_block_every,
             final_norm = False,
-            rnn_time = False,
+            rnn_time = use_time_rnn,
             **transformer_kwargs
         )
 
@@ -2376,7 +2411,7 @@ class DynamicsWorldModel(Module):
         step_size = 4,
         max_timesteps = 16,
         env_is_vectorized = False,
-        use_time_kv_cache = True,
+        use_time_cache = True,
         store_agent_embed = True,
         store_old_action_unembeds = True,
     ):
@@ -2415,7 +2450,7 @@ class DynamicsWorldModel(Module):
 
         # maybe time kv cache
 
-        time_kv_cache = None
+        time_cache = None
 
         step_index = 0
 
@@ -2424,14 +2459,14 @@ class DynamicsWorldModel(Module):
 
             latents = self.video_tokenizer(video, return_latents = True)
 
-            _, (agent_embed, next_time_kv_cache) = self.forward(
+            _, (agent_embed, next_time_cache) = self.forward(
                 latents = latents,
                 signal_levels = self.max_steps - 1,
                 step_sizes = step_size,
                 rewards = rewards,
                 discrete_actions = discrete_actions,
                 continuous_actions = continuous_actions,
-                time_kv_cache = time_kv_cache,
+                time_cache = time_cache,
                 latent_is_noised = True,
                 return_pred_only = True,
                 return_intermediates = True
@@ -2439,8 +2474,8 @@ class DynamicsWorldModel(Module):
 
             # time kv cache
 
-            if use_time_kv_cache:
-                time_kv_cache = next_time_kv_cache
+            if use_time_cache:
+                time_cache = next_time_cache
 
             # get one agent
 
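From the caller's side nothing changes but the names: the object handed back is now the full TransformerIntermediates (attention kv cache plus rnn hiddens) rather than a bare kv tensor, and it is threaded through opaquely. A sketch, assuming a constructed DynamicsWorldModel named dynamics as in the tests:

# `dynamics` is assumed to be a constructed DynamicsWorldModel

time_cache = None

for _ in range(3):
    generated, time_cache = dynamics.generate(
        1,
        time_cache = time_cache,
        return_time_cache = True
    )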
@@ -2832,13 +2867,13 @@ class DynamicsWorldModel(Module):
         image_width = None,
         return_decoded_video = None,
         context_signal_noise = 0.1, # they do a noising of the past, this was from an old diffusion world modeling paper from EPFL iirc
-        time_kv_cache: Tensor | None = None,
-        use_time_kv_cache = True,
+        time_cache: Tensor | None = None,
+        use_time_cache = True,
         return_rewards_per_frame = False,
         return_agent_actions = False,
         return_log_probs_and_values = False,
         return_for_policy_optimization = False,
-        return_time_kv_cache = False,
+        return_time_cache = False,
         store_agent_embed = True,
         store_old_action_unembeds = True
 
@@ -2927,7 +2962,7 @@ class DynamicsWorldModel(Module):
         # (2) decoding anything off agent embedding (rewards, actions, etc)
 
         take_extra_step = (
-            use_time_kv_cache or
+            use_time_cache or
             return_rewards_per_frame or
             store_agent_embed or
             return_agent_actions
@@ -2968,7 +3003,7 @@ class DynamicsWorldModel(Module):
 
             signal_levels_with_context = F.pad(signal_levels, (curr_time_steps, 0), value = self.max_steps - 1)
 
-            pred, (agent_embed, next_time_kv_cache) = self.forward(
+            pred, (agent_embed, next_time_cache) = self.forward(
                 latents = noised_latent_with_context,
                 signal_levels = signal_levels_with_context,
                 step_sizes = step_size,
@@ -2978,15 +3013,15 @@ class DynamicsWorldModel(Module):
                 discrete_actions = decoded_discrete_actions,
                 continuous_actions = decoded_continuous_actions,
                 proprio = noised_proprio_with_context,
-                time_kv_cache = time_kv_cache,
+                time_cache = time_cache,
                 latent_is_noised = True,
                 latent_has_view_dim = True,
                 return_pred_only = True,
                 return_intermediates = True,
             )
 
-            if use_time_kv_cache and is_last_step:
-                time_kv_cache = next_time_kv_cache
+            if use_time_cache and is_last_step:
+                time_cache = next_time_cache
 
             # early break if taking an extra step for agent embedding off cleaned latents for decoding
 
@@ -3135,10 +3170,10 @@ class DynamicsWorldModel(Module):
         if not has_at_least_one(return_rewards_per_frame, return_agent_actions, has_proprio):
             out = video if return_decoded_video else latents
 
-            if not return_time_kv_cache:
+            if not return_time_cache:
                 return out
 
-            return out, time_kv_cache
+            return out, time_cache
 
         # returning agent actions, rewards, and log probs + values for policy optimization
 
@@ -3168,10 +3203,10 @@ class DynamicsWorldModel(Module):
 
             gen.values = decoded_values
 
-        if not return_time_kv_cache:
+        if not return_time_cache:
             return gen
 
-        return gen, time_kv_cache
+        return gen, time_cache
 
     def forward(
         self,
@@ -3190,7 +3225,7 @@ class DynamicsWorldModel(Module):
         discrete_action_types = None, # (na)
         continuous_action_types = None, # (na)
         proprio = None, # (b t dp)
-        time_kv_cache = None,
+        time_cache = None,
         return_pred_only = False,
         latent_is_noised = False,
         return_all_losses = False,
@@ -3410,7 +3445,7 @@ class DynamicsWorldModel(Module):
 
         # main function, needs to be defined as such for shortcut training - additional calls for consistency loss
 
-        def get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = False, return_time_kv_cache = False):
+        def get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = False, return_time_cache = False):
 
             # latents to spatial tokens
 
@@ -3463,7 +3498,7 @@ class DynamicsWorldModel(Module):
 
             # attention
 
-            tokens, (next_time_kv_cache, *_) = self.transformer(tokens, kv_cache = time_kv_cache, return_intermediates = True)
+            tokens, intermediates = self.transformer(tokens, cache = time_cache, return_intermediates = True)
 
             # unpack
 
@@ -3487,10 +3522,10 @@ class DynamicsWorldModel(Module):
             if not return_agent_tokens:
                 return pred
 
-            if not return_time_kv_cache:
+            if not return_time_cache:
                 return pred, agent_tokens
 
-            return pred, (agent_tokens, next_time_kv_cache)
+            return pred, (agent_tokens, intermediates)
 
         # curry into get_prediction what does not change during first call as well as the shortcut ones
 
@@ -3498,13 +3533,13 @@ class DynamicsWorldModel(Module):
 
         # forward the network
 
-        pred, (encoded_agent_tokens, next_time_kv_cache) = _get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, return_agent_tokens = True, return_time_kv_cache = True)
+        pred, (encoded_agent_tokens, intermediates) = _get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, return_agent_tokens = True, return_time_cache = True)
 
         if return_pred_only:
             if not return_intermediates:
                 return pred
 
-            return pred, (encoded_agent_tokens, next_time_kv_cache)
+            return pred, (encoded_agent_tokens, intermediates)
 
         # pack the predictions to calculate flow for different modalities all at once
 

@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.1.18"
+version = "0.1.19"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

@@ -15,7 +15,7 @@ def exists(v):
 @param('condition_on_actions', (False, True))
 @param('num_residual_streams', (1, 4))
 @param('add_reward_embed_to_agent_token', (False, True))
-@param('use_time_kv_cache', (False, True))
+@param('use_time_cache', (False, True))
 @param('var_len', (False, True))
 def test_e2e(
     pred_orig_latent,
@@ -28,7 +28,7 @@ def test_e2e(
     condition_on_actions,
     num_residual_streams,
     add_reward_embed_to_agent_token,
-    use_time_kv_cache,
+    use_time_cache,
     var_len
 ):
     from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
@@ -123,7 +123,7 @@ def test_e2e(
         image_width = 128,
         batch_size = 2,
         return_rewards_per_frame = True,
-        use_time_kv_cache = use_time_kv_cache
+        use_time_cache = use_time_cache
     )
 
     assert generations.video.shape == (2, 3, 10, 128, 128)
@@ -617,9 +617,9 @@ def test_cache_generate():
         num_residual_streams = 1
     )
 
-    generated, time_kv_cache = dynamics.generate(1, return_time_kv_cache = True)
-    generated, time_kv_cache = dynamics.generate(1, time_kv_cache = time_kv_cache, return_time_kv_cache = True)
-    generated, time_kv_cache = dynamics.generate(1, time_kv_cache = time_kv_cache, return_time_kv_cache = True)
+    generated, time_cache = dynamics.generate(1, return_time_cache = True)
+    generated, time_cache = dynamics.generate(1, time_cache = time_cache, return_time_cache = True)
+    generated, time_cache = dynamics.generate(1, time_cache = time_cache, return_time_cache = True)
 
 @param('vectorized', (False, True))
 @param('use_pmpo', (False, True))