fix the rnn time caching issue
This commit is contained in:
parent
ac1c12f743
commit
690ecf07dc
@ -1,6 +1,7 @@
|
||||
from dreamer4.dreamer4 import (
|
||||
VideoTokenizer,
|
||||
DynamicsWorldModel
|
||||
DynamicsWorldModel,
|
||||
AxialSpaceTimeTransformer
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from __future__ import annotations
|
||||
from typing import Callable
|
||||
|
||||
import math
|
||||
from math import ceil, log2
|
||||
@ -76,7 +77,7 @@ WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'rewards', 'discrete_
|
||||
|
||||
AttentionIntermediates = namedtuple('AttentionIntermediates', ('next_kv_cache', 'normed_inputs'))
|
||||
|
||||
TransformerIntermediates = namedtuple('TransformerIntermediates', ('next_kv_cache', 'normed_time_inputs', 'normed_space_inputs'))
|
||||
TransformerIntermediates = namedtuple('TransformerIntermediates', ('next_kv_cache', 'normed_time_inputs', 'normed_space_inputs', 'next_rnn_hiddens'))
|
||||
|
||||
MaybeTensor = Tensor | None
|
||||
|
||||
@ -1453,6 +1454,29 @@ class SwiGLUFeedforward(Module):
|
||||
|
||||
return self.proj_out(x)
|
||||
|
||||
# rnn
|
||||
|
||||
class GRULayer(Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
dim_out
|
||||
):
|
||||
super().__init__()
|
||||
self.norm = nn.RMSNorm(dim)
|
||||
self.gru = nn.GRU(dim, dim_out, batch_first = True)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
prev_hiddens = None
|
||||
):
|
||||
x = self.norm(x)
|
||||
|
||||
x, hiddens = self.gru(x, prev_hiddens)
|
||||
|
||||
return x, hiddens
|
||||
|
||||
# axial space time transformer
|
||||
|
||||
class AxialSpaceTimeTransformer(Module):
|
||||
@ -1531,7 +1555,7 @@ class AxialSpaceTimeTransformer(Module):
|
||||
hyper_conn(branch = SwiGLUFeedforward(dim = dim, **ff_kwargs))
|
||||
]))
|
||||
|
||||
rnn_layers.append(hyper_conn(branch = nn.Sequential(nn.RMSNorm(dim), nn.GRU(dim, dim, batch_first = True))) if is_time_block and rnn_time else None)
|
||||
rnn_layers.append(hyper_conn(branch = GRULayer(dim, dim)) if is_time_block and rnn_time else None)
|
||||
|
||||
self.layers = ModuleList(layers)
|
||||
self.rnn_layers = ModuleList(rnn_layers)
|
||||
@ -1557,8 +1581,8 @@ class AxialSpaceTimeTransformer(Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
tokens, # (b t s d)
|
||||
kv_cache: Tensor | None = None, # (y 2 b h t d)
|
||||
tokens, # (b t s d)
|
||||
cache: TransformerIntermediates | None = None,
|
||||
return_intermediates = False
|
||||
|
||||
): # (b t s d) | (y 2 b h t d)
|
||||
@ -1567,6 +1591,14 @@ class AxialSpaceTimeTransformer(Module):
|
||||
|
||||
assert tokens.ndim == 4
|
||||
|
||||
# destruct intermediates to cache for attention and rnn respectively
|
||||
|
||||
kv_cache = rnn_prev_hiddens = None
|
||||
|
||||
if exists(cache):
|
||||
kv_cache = cache.next_kv_cache
|
||||
rnn_prev_hiddens = cache.next_rnn_hiddens
|
||||
|
||||
# attend functions for space and time
|
||||
|
||||
has_kv_cache = exists(kv_cache)
|
||||
@ -1581,6 +1613,7 @@ class AxialSpaceTimeTransformer(Module):
|
||||
# prepare cache
|
||||
|
||||
time_attn_kv_caches = []
|
||||
rnn_hiddens = []
|
||||
|
||||
if has_kv_cache:
|
||||
past_tokens, tokens = tokens[:, :-1], tokens[:, -1:]
|
||||
@ -1595,6 +1628,10 @@ class AxialSpaceTimeTransformer(Module):
|
||||
|
||||
iter_kv_cache = iter(kv_cache)
|
||||
|
||||
rnn_prev_hiddens = default(rnn_prev_hiddens, (None,))
|
||||
|
||||
iter_rnn_prev_hiddens = iter(rnn_prev_hiddens)
|
||||
|
||||
# rotary
|
||||
|
||||
rotary_pos_emb = self.time_rotary(rotary_seq_len, offset = rotary_pos_offset)
|
||||
@ -1625,10 +1662,12 @@ class AxialSpaceTimeTransformer(Module):
|
||||
|
||||
tokens, inverse_pack_batch = pack_one(tokens, '* t d')
|
||||
|
||||
tokens, rnn_hiddens = maybe_rnn(tokens) # todo, handle rnn cache
|
||||
tokens, layer_rnn_hiddens = maybe_rnn(tokens, next(iter_rnn_prev_hiddens, None)) # todo, handle rnn cache
|
||||
|
||||
tokens = inverse_pack_batch(tokens)
|
||||
|
||||
rnn_hiddens.append(layer_rnn_hiddens)
|
||||
|
||||
# when is a axial time attention block, should be causal
|
||||
|
||||
attend_fn = time_attend if layer_is_time else space_attend
|
||||
@ -1685,7 +1724,8 @@ class AxialSpaceTimeTransformer(Module):
|
||||
intermediates = TransformerIntermediates(
|
||||
stack(time_attn_kv_caches),
|
||||
safe_stack(normed_time_attn_inputs),
|
||||
safe_stack(normed_space_attn_inputs)
|
||||
safe_stack(normed_space_attn_inputs),
|
||||
safe_stack(rnn_hiddens)
|
||||
)
|
||||
|
||||
return out, intermediates
|
||||
@ -1717,11 +1757,6 @@ class VideoTokenizer(Module):
|
||||
encoder_add_decor_aux_loss = False,
|
||||
decor_auxx_loss_weight = 0.1,
|
||||
decorr_sample_frac = 0.25,
|
||||
nd_rotary_kwargs: dict = dict(
|
||||
rope_min_freq = 1.,
|
||||
rope_max_freq = 10000.,
|
||||
rope_p_zero_freqs = 0.
|
||||
),
|
||||
num_residual_streams = 1,
|
||||
):
|
||||
super().__init__()
|
||||
@ -1938,7 +1973,7 @@ class VideoTokenizer(Module):
|
||||
|
||||
# encoder attention
|
||||
|
||||
tokens, (_, time_attn_normed_inputs, space_attn_normed_inputs) = self.encoder_transformer(tokens, return_intermediates = True)
|
||||
tokens, (_, time_attn_normed_inputs, space_attn_normed_inputs, _) = self.encoder_transformer(tokens, return_intermediates = True)
|
||||
|
||||
# latent bottleneck
|
||||
|
||||
@ -2011,15 +2046,15 @@ class DynamicsWorldModel(Module):
|
||||
attn_dim_head = 64,
|
||||
attn_softclamp_value = 50.,
|
||||
ff_kwargs: dict = dict(),
|
||||
use_time_rnn = True,
|
||||
loss_weight_fn: Callable = ramp_weight,
|
||||
num_future_predictions = 8, # they do multi-token prediction of 8 steps forward
|
||||
prob_no_shortcut_train = None, # probability of no shortcut training, defaults to 1 / num_step_sizes
|
||||
add_reward_embed_to_agent_token = False,
|
||||
add_reward_embed_dropout = 0.1,
|
||||
num_discrete_actions: int | tuple[int, ...] = 0,
|
||||
num_continuous_actions = 0,
|
||||
continuous_norm_stats = None,
|
||||
multi_token_pred_len = 8,
|
||||
multi_token_pred_len = 8, # they do multi-token prediction of 8 steps forward
|
||||
value_head_mlp_depth = 3,
|
||||
policy_head_mlp_depth = 3,
|
||||
latent_flow_loss_weight = 1.,
|
||||
@ -2229,7 +2264,7 @@ class DynamicsWorldModel(Module):
|
||||
num_special_spatial_tokens = num_agents,
|
||||
time_block_every = time_block_every,
|
||||
final_norm = False,
|
||||
rnn_time = False,
|
||||
rnn_time = use_time_rnn,
|
||||
**transformer_kwargs
|
||||
)
|
||||
|
||||
@ -2376,7 +2411,7 @@ class DynamicsWorldModel(Module):
|
||||
step_size = 4,
|
||||
max_timesteps = 16,
|
||||
env_is_vectorized = False,
|
||||
use_time_kv_cache = True,
|
||||
use_time_cache = True,
|
||||
store_agent_embed = True,
|
||||
store_old_action_unembeds = True,
|
||||
):
|
||||
@ -2415,7 +2450,7 @@ class DynamicsWorldModel(Module):
|
||||
|
||||
# maybe time cache (attention kv cache and rnn hiddens)
|
||||
|
||||
time_kv_cache = None
|
||||
time_cache = None
|
||||
|
||||
step_index = 0
|
||||
|
||||
@ -2424,14 +2459,14 @@ class DynamicsWorldModel(Module):
|
||||
|
||||
latents = self.video_tokenizer(video, return_latents = True)
|
||||
|
||||
_, (agent_embed, next_time_kv_cache) = self.forward(
|
||||
_, (agent_embed, next_time_cache) = self.forward(
|
||||
latents = latents,
|
||||
signal_levels = self.max_steps - 1,
|
||||
step_sizes = step_size,
|
||||
rewards = rewards,
|
||||
discrete_actions = discrete_actions,
|
||||
continuous_actions = continuous_actions,
|
||||
time_kv_cache = time_kv_cache,
|
||||
time_cache = time_cache,
|
||||
latent_is_noised = True,
|
||||
return_pred_only = True,
|
||||
return_intermediates = True
|
||||
@ -2439,8 +2474,8 @@ class DynamicsWorldModel(Module):
|
||||
|
||||
# time cache
|
||||
|
||||
if use_time_kv_cache:
|
||||
time_kv_cache = next_time_kv_cache
|
||||
if use_time_cache:
|
||||
time_cache = next_time_cache
|
||||
|
||||
# get one agent
|
||||
|
||||
@ -2832,13 +2867,13 @@ class DynamicsWorldModel(Module):
|
||||
image_width = None,
|
||||
return_decoded_video = None,
|
||||
context_signal_noise = 0.1, # they do a noising of the past, this was from an old diffusion world modeling paper from EPFL iirc
|
||||
time_kv_cache: Tensor | None = None,
|
||||
use_time_kv_cache = True,
|
||||
time_cache: Tensor | None = None,
|
||||
use_time_cache = True,
|
||||
return_rewards_per_frame = False,
|
||||
return_agent_actions = False,
|
||||
return_log_probs_and_values = False,
|
||||
return_for_policy_optimization = False,
|
||||
return_time_kv_cache = False,
|
||||
return_time_cache = False,
|
||||
store_agent_embed = True,
|
||||
store_old_action_unembeds = True
|
||||
|
||||
@ -2927,7 +2962,7 @@ class DynamicsWorldModel(Module):
|
||||
# (2) decoding anything off agent embedding (rewards, actions, etc)
|
||||
|
||||
take_extra_step = (
|
||||
use_time_kv_cache or
|
||||
use_time_cache or
|
||||
return_rewards_per_frame or
|
||||
store_agent_embed or
|
||||
return_agent_actions
|
||||
@ -2968,7 +3003,7 @@ class DynamicsWorldModel(Module):
|
||||
|
||||
signal_levels_with_context = F.pad(signal_levels, (curr_time_steps, 0), value = self.max_steps - 1)
|
||||
|
||||
pred, (agent_embed, next_time_kv_cache) = self.forward(
|
||||
pred, (agent_embed, next_time_cache) = self.forward(
|
||||
latents = noised_latent_with_context,
|
||||
signal_levels = signal_levels_with_context,
|
||||
step_sizes = step_size,
|
||||
@ -2978,15 +3013,15 @@ class DynamicsWorldModel(Module):
|
||||
discrete_actions = decoded_discrete_actions,
|
||||
continuous_actions = decoded_continuous_actions,
|
||||
proprio = noised_proprio_with_context,
|
||||
time_kv_cache = time_kv_cache,
|
||||
time_cache = time_cache,
|
||||
latent_is_noised = True,
|
||||
latent_has_view_dim = True,
|
||||
return_pred_only = True,
|
||||
return_intermediates = True,
|
||||
)
|
||||
|
||||
if use_time_kv_cache and is_last_step:
|
||||
time_kv_cache = next_time_kv_cache
|
||||
if use_time_cache and is_last_step:
|
||||
time_cache = next_time_cache
|
||||
|
||||
# early break if taking an extra step for agent embedding off cleaned latents for decoding
|
||||
|
||||
@ -3135,10 +3170,10 @@ class DynamicsWorldModel(Module):
|
||||
if not has_at_least_one(return_rewards_per_frame, return_agent_actions, has_proprio):
|
||||
out = video if return_decoded_video else latents
|
||||
|
||||
if not return_time_kv_cache:
|
||||
if not return_time_cache:
|
||||
return out
|
||||
|
||||
return out, time_kv_cache
|
||||
return out, time_cache
|
||||
|
||||
# returning agent actions, rewards, and log probs + values for policy optimization
|
||||
|
||||
@ -3168,10 +3203,10 @@ class DynamicsWorldModel(Module):
|
||||
|
||||
gen.values = decoded_values
|
||||
|
||||
if not return_time_kv_cache:
|
||||
if not return_time_cache:
|
||||
return gen
|
||||
|
||||
return gen, time_kv_cache
|
||||
return gen, time_cache
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -3190,7 +3225,7 @@ class DynamicsWorldModel(Module):
|
||||
discrete_action_types = None, # (na)
|
||||
continuous_action_types = None, # (na)
|
||||
proprio = None, # (b t dp)
|
||||
time_kv_cache = None,
|
||||
time_cache = None,
|
||||
return_pred_only = False,
|
||||
latent_is_noised = False,
|
||||
return_all_losses = False,
|
||||
@ -3410,7 +3445,7 @@ class DynamicsWorldModel(Module):
|
||||
|
||||
# main function, needs to be defined as such for shortcut training - additional calls for consistency loss
|
||||
|
||||
def get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = False, return_time_kv_cache = False):
|
||||
def get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = False, return_time_cache = False):
|
||||
|
||||
# latents to spatial tokens
|
||||
|
||||
@ -3463,7 +3498,7 @@ class DynamicsWorldModel(Module):
|
||||
|
||||
# attention
|
||||
|
||||
tokens, (next_time_kv_cache, *_) = self.transformer(tokens, kv_cache = time_kv_cache, return_intermediates = True)
|
||||
tokens, intermediates = self.transformer(tokens, cache = time_cache, return_intermediates = True)
|
||||
|
||||
# unpack
|
||||
|
||||
@ -3487,10 +3522,10 @@ class DynamicsWorldModel(Module):
|
||||
if not return_agent_tokens:
|
||||
return pred
|
||||
|
||||
if not return_time_kv_cache:
|
||||
if not return_time_cache:
|
||||
return pred, agent_tokens
|
||||
|
||||
return pred, (agent_tokens, next_time_kv_cache)
|
||||
return pred, (agent_tokens, intermediates)
|
||||
|
||||
# curry into get_prediction what does not change during first call as well as the shortcut ones
|
||||
|
||||
@ -3498,13 +3533,13 @@ class DynamicsWorldModel(Module):
|
||||
|
||||
# forward the network
|
||||
|
||||
pred, (encoded_agent_tokens, next_time_kv_cache) = _get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, return_agent_tokens = True, return_time_kv_cache = True)
|
||||
pred, (encoded_agent_tokens, intermediates) = _get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, return_agent_tokens = True, return_time_cache = True)
|
||||
|
||||
if return_pred_only:
|
||||
if not return_intermediates:
|
||||
return pred
|
||||
|
||||
return pred, (encoded_agent_tokens, next_time_kv_cache)
|
||||
return pred, (encoded_agent_tokens, intermediates)
|
||||
|
||||
# pack the predictions to calculate flow for different modalities all at once
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "dreamer4"
|
||||
version = "0.1.18"
|
||||
version = "0.1.19"
|
||||
description = "Dreamer 4"
|
||||
authors = [
|
||||
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
||||
|
||||
@ -15,7 +15,7 @@ def exists(v):
|
||||
@param('condition_on_actions', (False, True))
|
||||
@param('num_residual_streams', (1, 4))
|
||||
@param('add_reward_embed_to_agent_token', (False, True))
|
||||
@param('use_time_kv_cache', (False, True))
|
||||
@param('use_time_cache', (False, True))
|
||||
@param('var_len', (False, True))
|
||||
def test_e2e(
|
||||
pred_orig_latent,
|
||||
@ -28,7 +28,7 @@ def test_e2e(
|
||||
condition_on_actions,
|
||||
num_residual_streams,
|
||||
add_reward_embed_to_agent_token,
|
||||
use_time_kv_cache,
|
||||
use_time_cache,
|
||||
var_len
|
||||
):
|
||||
from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
|
||||
@ -123,7 +123,7 @@ def test_e2e(
|
||||
image_width = 128,
|
||||
batch_size = 2,
|
||||
return_rewards_per_frame = True,
|
||||
use_time_kv_cache = use_time_kv_cache
|
||||
use_time_cache = use_time_cache
|
||||
)
|
||||
|
||||
assert generations.video.shape == (2, 3, 10, 128, 128)
|
||||
@ -617,9 +617,9 @@ def test_cache_generate():
|
||||
num_residual_streams = 1
|
||||
)
|
||||
|
||||
generated, time_kv_cache = dynamics.generate(1, return_time_kv_cache = True)
|
||||
generated, time_kv_cache = dynamics.generate(1, time_kv_cache = time_kv_cache, return_time_kv_cache = True)
|
||||
generated, time_kv_cache = dynamics.generate(1, time_kv_cache = time_kv_cache, return_time_kv_cache = True)
|
||||
generated, time_cache = dynamics.generate(1, return_time_cache = True)
|
||||
generated, time_cache = dynamics.generate(1, time_cache = time_cache, return_time_cache = True)
|
||||
generated, time_cache = dynamics.generate(1, time_cache = time_cache, return_time_cache = True)
|
||||
|
||||
@param('vectorized', (False, True))
|
||||
@param('use_pmpo', (False, True))
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user