Compare commits

...

11 Commits
0.1.12 ... main

4 changed files with 343 additions and 87 deletions

View File

@@ -1,6 +1,7 @@
 from dreamer4.dreamer4 import (
     VideoTokenizer,
-    DynamicsWorldModel
+    DynamicsWorldModel,
+    AxialSpaceTimeTransformer
 )

View File

@@ -1,4 +1,5 @@
 from __future__ import annotations
+from typing import Callable

 import math
 from math import ceil, log2
@@ -70,13 +71,19 @@ except ImportError:
 LinearNoBias = partial(Linear, bias = False)

+VideoTokenizerIntermediates = namedtuple('VideoTokenizerIntermediates', ('losses', 'recon'))
+
 TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips', 'time_decorr', 'space_decorr'))

-WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'rewards', 'discrete_actions', 'continuous_actions'))
+WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'rewards', 'discrete_actions', 'continuous_actions', 'state_pred'))

 AttentionIntermediates = namedtuple('AttentionIntermediates', ('next_kv_cache', 'normed_inputs'))

-TransformerIntermediates = namedtuple('TransformerIntermediates', ('next_kv_cache', 'normed_time_inputs', 'normed_space_inputs'))
+TransformerIntermediates = namedtuple('TransformerIntermediates', ('next_kv_cache', 'normed_time_inputs', 'normed_space_inputs', 'next_rnn_hiddens'))
+
+Predictions = namedtuple('Predictions', ('flow', 'proprioception', 'state'))
+
+Embeds = namedtuple('Embeds', ['agent', 'state_pred'])

 MaybeTensor = Tensor | None
@@ -95,7 +102,7 @@ class Experience:
     lens: MaybeTensor = None
     is_truncated: MaybeTensor = None
     agent_index: int = 0
-    is_from_world_model: bool = True
+    is_from_world_model: bool | Tensor = True

     def cpu(self):
         return self.to(torch.device('cpu'))
@@ -123,6 +130,9 @@ def combine_experiences(
         if not exists(exp.is_truncated):
             exp.is_truncated = full((batch,), True, device = device)

+        if isinstance(exp.is_from_world_model, bool):
+            exp.is_from_world_model = tensor(exp.is_from_world_model)
+
     # convert to dictionary

     exps_dict = [asdict(exp) for exp in exps]
@@ -144,11 +154,15 @@
     for field_values in all_field_values:

-        if is_tensor(first(field_values)):
+        first_value = first(field_values)
+
+        if is_tensor(first_value):
             field_values = pad_tensors_at_dim_to_max_len(field_values, dims = (1, 2))
-            new_field_value = cat(field_values)
+
+            cat_or_stack = cat if first_value.ndim > 0 else stack
+            new_field_value = cat_or_stack(field_values)
         else:
             new_field_value = first(list(set(field_values)))
@@ -189,6 +203,13 @@ def sample_prob(prob):
 def is_power_two(num):
     return log2(num).is_integer()

+def maybe(fn):
+    def inner(t, *args, **kwargs):
+        if not exists(t) or not exists(fn):
+            return None
+        return fn(t)
+    return inner
+
 # tensor helpers

 def is_empty(t):
@@ -1284,7 +1305,8 @@ class Attention(Module):
         pre_rmsnorm = True,
         gate_values = True,
         rmsnorm_query = False, # a paper claims that it is better to just norm only the keys https://openreview.net/forum?id=HkztQWZfl2
-        rmsnorm_key = True
+        rmsnorm_key = True,
+        value_residual = True
     ):
         super().__init__()

         self.norm = RMSNorm(dim) if pre_rmsnorm else Identity()
@@ -1323,6 +1345,14 @@
         self.q_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = query_heads) if rmsnorm_query else nn.Identity()
         self.k_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = heads) if rmsnorm_key else nn.Identity()

+        # value residual
+
+        self.to_learned_value_residual_mix = nn.Sequential(
+            nn.Linear(dim, heads),
+            Rearrange('b n h -> b h n 1'),
+            nn.Sigmoid()
+        ) if value_residual else None
+
     def muon_parameters(self):
         # omit the queries and keys for now given what we learned from kimi 2 paper
@@ -1337,6 +1367,7 @@
         kv_cache = None,
         return_intermediates = False,
         rotary_pos_emb = None,
+        residual_values = None, # (b n h d)
         attend_fn: Callable | None = None
     ):
         tokens, inverse_packed_batch = pack_one(tokens, '* n d')
@@ -1349,6 +1380,17 @@
         q, k, v = map(self.split_heads, (q, k, v))

+        # handle maybe value residual
+
+        if exists(residual_values):
+            residual_values = rearrange(residual_values, '... n h d -> (...) h n d')
+
+            assert exists(self.to_learned_value_residual_mix)
+            learned_mix = self.to_learned_value_residual_mix(tokens)
+
+            v = v.lerp(residual_values, learned_mix)
+
         # qk rmsnorm

         q = self.q_heads_rmsnorm(q)
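The value residual above blends values projected once from the initial tokens back into each later layer's values, with a per-head, per-token gate predicted from the token stream. A minimal standalone sketch of that mixing, with illustrative names and shapes (plain torch, not the module's exact wiring):

import torch
from torch import nn

batch, seq, dim, heads, dim_head = 2, 16, 64, 4, 16

tokens = torch.randn(batch, seq, dim)
v = torch.randn(batch, heads, seq, dim_head)              # current layer's values
first_layer_v = torch.randn(batch, heads, seq, dim_head)  # values projected once from the initial tokens

to_mix = nn.Sequential(nn.Linear(dim, heads), nn.Sigmoid())  # per-head gate in [0, 1]

mix = to_mix(tokens)                      # (b, n, h)
mix = mix.transpose(1, 2).unsqueeze(-1)   # (b, h, n, 1), matching the Rearrange('b n h -> b h n 1') above

v = v.lerp(first_layer_v, mix)            # (1 - mix) * v + mix * first_layer_v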
@@ -1425,6 +1467,29 @@ class SwiGLUFeedforward(Module):
         return self.proj_out(x)

+# rnn
+
+class GRULayer(Module):
+    def __init__(
+        self,
+        dim,
+        dim_out
+    ):
+        super().__init__()
+        self.norm = nn.RMSNorm(dim)
+        self.gru = nn.GRU(dim, dim_out, batch_first = True)
+
+    def forward(
+        self,
+        x,
+        prev_hiddens = None
+    ):
+        x = self.norm(x)
+
+        x, hiddens = self.gru(x, prev_hiddens)
+        return x, hiddens
+
 # axial space time transformer

 class AxialSpaceTimeTransformer(Module):
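For orientation, in the time blocks further down this GRU is run along the time axis with the spatial positions folded into the batch (which is what the pack_one(tokens, '* t d') call does there). A self-contained usage sketch under that assumption, using plain torch reshapes in place of the repo's einops helpers and assuming a recent torch with nn.RMSNorm:

import torch
from torch import nn

batch, time, space, dim = 2, 8, 16, 64

norm = nn.RMSNorm(dim)
gru = nn.GRU(dim, dim, batch_first = True)

tokens = torch.randn(batch, time, space, dim)                     # (b t s d)

x = tokens.permute(0, 2, 1, 3).reshape(batch * space, time, dim)  # fold space into batch -> (b*s, t, d)
out, hiddens = gru(norm(x))                                       # hiddens can be fed back in on a later call

out = out.reshape(batch, space, time, dim).permute(0, 2, 1, 3)    # back to (b t s d)
assert out.shape == tokens.shape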
@@ -1432,6 +1497,7 @@
         self,
         dim,
         depth,
+        attn_heads = 8,
         attn_dim_head = 64,
         attn_softclamp_value = 50.,
         time_block_every = 4,
@@ -1440,7 +1506,9 @@
         num_residual_streams = 1,
         num_special_spatial_tokens = 1,
         special_attend_only_itself = False, # this is set to True for the video tokenizer decoder (latents can only attend to itself while spatial modalities attend to the latents and everything)
-        final_norm = True
+        final_norm = True,
+        value_residual = True, # https://arxiv.org/abs/2410.17897 - but with learned mixing from OSS
+        rnn_time = True
     ):
         super().__init__()
         assert depth >= time_block_every, f'depth must be at least {time_block_every}'
@@ -1461,6 +1529,24 @@
         self.time_rotary = Rotary1D(attn_dim_head)

+        # project initial for value residuals
+
+        self.value_residual = value_residual
+
+        if value_residual:
+            dim_inner = attn_dim_head * attn_heads
+
+            self.to_value_residual = nn.Sequential(
+                nn.RMSNorm(dim),
+                nn.Linear(dim, dim_inner, bias = False),
+                Rearrange('... (h d) -> ... h d', h = attn_heads)
+            )
+
+        # a gru layer across time
+
+        self.rnn_time = rnn_time
+
+        rnn_layers = []
+
         # transformer

         layers = []
@@ -1472,17 +1558,21 @@
             is_time_block = divisible_by(layer_index, time_block_every)
             is_time.append(is_time_block)

-            rearrange_to_attend = Rearrange('b t s d -> b s t d') if is_time_block else Identity()
-            rearrange_from_attend = Rearrange('b s t d -> b t s d') if is_time_block else Identity()
+            rearrange_to_attend = Rearrange('b t s ... -> b s t ...') if is_time_block else Identity()
+            rearrange_from_attend = Rearrange('b s t ... -> b t s ...') if is_time_block else Identity()

             layers.append(ModuleList([
                 rearrange_to_attend,
                 rearrange_from_attend,
-                hyper_conn(branch = Attention(dim = dim, dim_head = attn_dim_head, **attn_kwargs)),
+                hyper_conn(branch = Attention(dim = dim, heads = attn_heads, dim_head = attn_dim_head, value_residual = value_residual, **attn_kwargs)),
                 hyper_conn(branch = SwiGLUFeedforward(dim = dim, **ff_kwargs))
             ]))

+            rnn_layers.append(hyper_conn(branch = GRULayer(dim, dim)) if is_time_block and rnn_time else None)
+
         self.layers = ModuleList(layers)
+        self.rnn_layers = ModuleList(rnn_layers)
+
         self.is_time = is_time

         # final norm
@@ -1505,7 +1595,7 @@
     def forward(
         self,
         tokens, # (b t s d)
-        kv_cache: Tensor | None = None, # (y 2 b h t d)
+        cache: TransformerIntermediates | None = None,
         return_intermediates = False
     ): # (b t s d) | (y 2 b h t d)
@@ -1514,6 +1604,14 @@
         assert tokens.ndim == 4

+        # destruct intermediates to cache for attention and rnn respectively
+
+        kv_cache = rnn_prev_hiddens = None
+
+        if exists(cache):
+            kv_cache = cache.next_kv_cache
+            rnn_prev_hiddens = cache.next_rnn_hiddens
+
         # attend functions for space and time

         has_kv_cache = exists(kv_cache)
@@ -1528,7 +1626,7 @@
         # prepare cache

         time_attn_kv_caches = []
+        rnn_hiddens = []

         if has_kv_cache:
             past_tokens, tokens = tokens[:, :-1], tokens[:, -1:]
@@ -1543,10 +1641,21 @@
         iter_kv_cache = iter(kv_cache)

+        rnn_prev_hiddens = default(rnn_prev_hiddens, (None,))
+        iter_rnn_prev_hiddens = iter(rnn_prev_hiddens)
+
         # rotary

         rotary_pos_emb = self.time_rotary(rotary_seq_len, offset = rotary_pos_offset)

+        # value residual
+
+        residual_values = None
+
+        if self.value_residual:
+            residual_values = self.to_value_residual(tokens)
+
         # normed attention inputs

         normed_time_attn_inputs = []
@@ -1556,10 +1665,22 @@
         tokens = self.expand_streams(tokens)

-        for (pre_attn_rearrange, post_attn_rearrange, attn, ff), layer_is_time in zip(self.layers, self.is_time):
+        for (pre_attn_rearrange, post_attn_rearrange, attn, ff), maybe_rnn, layer_is_time in zip(self.layers, self.rnn_layers, self.is_time):

             tokens = pre_attn_rearrange(tokens)

+            # maybe rnn for time
+
+            if layer_is_time and exists(maybe_rnn):
+                tokens, inverse_pack_batch = pack_one(tokens, '* t d')
+
+                tokens, layer_rnn_hiddens = maybe_rnn(tokens, next(iter_rnn_prev_hiddens, None)) # todo, handle rnn cache
+
+                tokens = inverse_pack_batch(tokens)
+
+                rnn_hiddens.append(layer_rnn_hiddens)
+
             # when is a axial time attention block, should be causal

             attend_fn = time_attend if layer_is_time else space_attend
@@ -1570,6 +1691,10 @@
             maybe_kv_cache = next(iter_kv_cache, None) if layer_is_time else None

+            # residual values
+
+            layer_residual_values = maybe(pre_attn_rearrange)(residual_values)
+
             # attention layer

             tokens, attn_intermediates = attn(
@@ -1577,6 +1702,7 @@
                 rotary_pos_emb = layer_rotary_pos_emb,
                 attend_fn = attend_fn,
                 kv_cache = maybe_kv_cache,
+                residual_values = layer_residual_values,
                 return_intermediates = True
             )
@@ -1611,7 +1737,8 @@
         intermediates = TransformerIntermediates(
             stack(time_attn_kv_caches),
             safe_stack(normed_time_attn_inputs),
-            safe_stack(normed_space_attn_inputs)
+            safe_stack(normed_space_attn_inputs),
+            safe_stack(rnn_hiddens)
         )

         return out, intermediates
@@ -1643,11 +1770,6 @@ class VideoTokenizer(Module):
         encoder_add_decor_aux_loss = False,
         decor_auxx_loss_weight = 0.1,
         decorr_sample_frac = 0.25,
-        nd_rotary_kwargs: dict = dict(
-            rope_min_freq = 1.,
-            rope_max_freq = 10000.,
-            rope_p_zero_freqs = 0.
-        ),
         num_residual_streams = 1,
     ):
         super().__init__()
@@ -1820,11 +1942,23 @@
     def forward(
         self,
-        video, # (b c t h w)
+        video_or_image, # (b c t h w) | (b c h w)
         return_latents = False,
         mask_patches = None,
-        return_all_losses = False
+        return_intermediates = False,
     ):
+        # handle image pretraining
+
+        is_image = video_or_image.ndim == 4
+
+        if is_image:
+            video = rearrange(video_or_image, 'b c h w -> b c 1 h w')
+        else:
+            video = video_or_image
+
+        # shapes
+
         batch, _, time, height, width = video.shape
         patch_size, device = self.patch_size, video.device
@@ -1864,7 +1998,7 @@
         # encoder attention

-        tokens, (_, time_attn_normed_inputs, space_attn_normed_inputs) = self.encoder_transformer(tokens, return_intermediates = True)
+        tokens, (_, time_attn_normed_inputs, space_attn_normed_inputs, _) = self.encoder_transformer(tokens, return_intermediates = True)

         # latent bottleneck
@@ -1904,12 +2038,21 @@
             space_decorr_loss * self.decorr_aux_loss_weight
         )

-        if not return_all_losses:
+        if not return_intermediates:
             return total_loss

-        losses = (recon_loss, lpips_loss, decorr_loss)
+        losses = TokenizerLosses(recon_loss, lpips_loss, time_decorr_loss, space_decorr_loss)

-        return total_loss, TokenizerLosses(*losses)
+        out = losses
+
+        # handle returning of reconstructed, and image pretraining
+
+        if is_image:
+            recon_video = rearrange(recon_video, 'b c 1 h w -> b c h w')
+
+        out = (losses, recon_video)
+
+        return total_loss, out
@@ -1931,22 +2074,24 @@ class DynamicsWorldModel(Module):
         depth = 4,
         pred_orig_latent = True, # directly predicting the original x0 data yield better results, rather than velocity (x-space vs v-space)
         time_block_every = 4, # every 4th block is time
-        attn_kwargs: dict = dict(
-            heads = 8,
-        ),
+        attn_kwargs: dict = dict(),
         transformer_kwargs: dict = dict(),
+        attn_heads = 8,
         attn_dim_head = 64,
         attn_softclamp_value = 50.,
         ff_kwargs: dict = dict(),
+        use_time_rnn = True,
         loss_weight_fn: Callable = ramp_weight,
-        num_future_predictions = 8, # they do multi-token prediction of 8 steps forward
         prob_no_shortcut_train = None, # probability of no shortcut training, defaults to 1 / num_step_sizes
         add_reward_embed_to_agent_token = False,
         add_reward_embed_dropout = 0.1,
+        add_state_pred_head = False,
+        state_pred_loss_weight = 0.1,
+        state_entropy_bonus_weight = 0.05,
         num_discrete_actions: int | tuple[int, ...] = 0,
         num_continuous_actions = 0,
         continuous_norm_stats = None,
-        multi_token_pred_len = 8,
+        multi_token_pred_len = 8, # they do multi-token prediction of 8 steps forward
         value_head_mlp_depth = 3,
         policy_head_mlp_depth = 3,
         latent_flow_loss_weight = 1.,
@@ -2066,6 +2211,25 @@
         self.pred_orig_latent = pred_orig_latent # x-space or v-space
         self.loss_weight_fn = loss_weight_fn

+        # state prediction, for state entropy bonus
+
+        self.add_state_pred_head = add_state_pred_head
+        self.state_pred_loss_weight = state_pred_loss_weight
+
+        self.should_pred_state = add_state_pred_head and state_pred_loss_weight > 0.
+
+        if self.should_pred_state:
+            self.state_pred_token = nn.Parameter(torch.randn(dim) * 1e-2)
+
+            self.to_state_pred = Sequential(
+                RMSNorm(dim),
+                nn.Linear(dim, num_latent_tokens * dim_latent * 2),
+                Rearrange('... (n d two) -> ... n d two', n = num_latent_tokens, two = 2)
+            )
+
+        self.state_entropy_bonus_weight = state_entropy_bonus_weight
+        self.add_state_entropy_bonus = self.should_pred_state and state_entropy_bonus_weight > 0.
+
         # reinforcement related

         # they sum all the actions into a single token
@@ -2147,6 +2311,7 @@
         self.transformer = AxialSpaceTimeTransformer(
             dim = dim,
             depth = depth,
+            attn_heads = attn_heads,
             attn_dim_head = attn_dim_head,
             attn_softclamp_value = attn_softclamp_value,
             attn_kwargs = attn_kwargs,
@@ -2155,6 +2320,7 @@
             num_special_spatial_tokens = num_agents,
             time_block_every = time_block_every,
             final_norm = False,
+            rnn_time = use_time_rnn,
             **transformer_kwargs
         )
@@ -2298,10 +2464,10 @@
         env,
         seed = None,
         agent_index = 0,
-        step_size = 4,
+        num_steps = 4,
         max_timesteps = 16,
         env_is_vectorized = False,
-        use_time_kv_cache = True,
+        use_time_cache = True,
         store_agent_embed = True,
         store_old_action_unembeds = True,
     ):
@@ -2338,9 +2504,14 @@
         episode_lens = full((batch,), 0, device = device)

+        # derive step size
+
+        assert divisible_by(self.max_steps, num_steps)
+        step_size = self.max_steps // num_steps
+
         # maybe time kv cache

-        time_kv_cache = None
+        time_cache = None

         step_index = 0
@@ -2349,14 +2520,14 @@
             latents = self.video_tokenizer(video, return_latents = True)

-            _, (agent_embed, next_time_kv_cache) = self.forward(
+            _, (embeds, next_time_cache) = self.forward(
                 latents = latents,
                 signal_levels = self.max_steps - 1,
                 step_sizes = step_size,
                 rewards = rewards,
                 discrete_actions = discrete_actions,
                 continuous_actions = continuous_actions,
-                time_kv_cache = time_kv_cache,
+                time_cache = time_cache,
                 latent_is_noised = True,
                 return_pred_only = True,
                 return_intermediates = True
@@ -2364,11 +2535,13 @@
             # time kv cache

-            if use_time_kv_cache:
-                time_kv_cache = next_time_kv_cache
+            if use_time_cache:
+                time_cache = next_time_cache

             # get one agent

+            agent_embed = embeds.agent
+
             one_agent_embed = agent_embed[..., -1:, agent_index, :]

             # values
@@ -2423,6 +2596,17 @@
             elif len(env_step_out) == 5:
                 next_frame, reward, terminated, truncated, info = env_step_out

+            # maybe add state entropy bonus
+
+            if self.add_state_entropy_bonus:
+                state_pred_token = embeds.state_pred
+
+                state_pred = self.to_state_pred(state_pred_token)
+                state_pred_log_variance = state_pred[..., 1].sum()
+
+                reward = reward + state_pred_log_variance * self.state_entropy_bonus_weight
+
             # update episode lens

             episode_lens = torch.where(is_terminated, episode_lens, episode_lens + 1)
@@ -2601,7 +2785,7 @@
         ):
             with world_model_forward_context():

-                _, (agent_embeds, _) = self.forward(
+                _, (embeds, _) = self.forward(
                     latents = latents,
                     signal_levels = self.max_steps - 1,
                     step_sizes = step_size,
@@ -2613,7 +2797,7 @@
                     return_intermediates = True
                 )

-                agent_embeds = agent_embeds[..., agent_index, :]
+                agent_embeds = embeds.agent[..., agent_index, :]

                 # maybe detach agent embed
@@ -2757,13 +2941,13 @@
         image_width = None,
         return_decoded_video = None,
         context_signal_noise = 0.1, # they do a noising of the past, this was from an old diffusion world modeling paper from EPFL iirc
-        time_kv_cache: Tensor | None = None,
-        use_time_kv_cache = True,
+        time_cache: Tensor | None = None,
+        use_time_cache = True,
         return_rewards_per_frame = False,
         return_agent_actions = False,
         return_log_probs_and_values = False,
         return_for_policy_optimization = False,
-        return_time_kv_cache = False,
+        return_time_cache = False,
         store_agent_embed = True,
         store_old_action_unembeds = True
@@ -2852,7 +3036,7 @@
         # (2) decoding anything off agent embedding (rewards, actions, etc)

         take_extra_step = (
-            use_time_kv_cache or
+            use_time_cache or
             return_rewards_per_frame or
             store_agent_embed or
             return_agent_actions
@@ -2893,7 +3077,7 @@
             signal_levels_with_context = F.pad(signal_levels, (curr_time_steps, 0), value = self.max_steps - 1)

-            pred, (agent_embed, next_time_kv_cache) = self.forward(
+            pred, (embeds, next_time_cache) = self.forward(
                 latents = noised_latent_with_context,
                 signal_levels = signal_levels_with_context,
                 step_sizes = step_size,
@@ -2903,15 +3087,15 @@
                 discrete_actions = decoded_discrete_actions,
                 continuous_actions = decoded_continuous_actions,
                 proprio = noised_proprio_with_context,
-                time_kv_cache = time_kv_cache,
+                time_cache = time_cache,
                 latent_is_noised = True,
                 latent_has_view_dim = True,
                 return_pred_only = True,
                 return_intermediates = True,
             )

-            if use_time_kv_cache and is_last_step:
-                time_kv_cache = next_time_kv_cache
+            if use_time_cache and is_last_step:
+                time_cache = next_time_cache

             # early break if taking an extra step for agent embedding off cleaned latents for decoding
@@ -2920,8 +3104,10 @@
             # maybe proprio

-            if has_proprio:
-                pred, pred_proprio = pred
+            pred_proprio = pred.proprioception
+            pred = pred.flow

             # unpack pred
@@ -2958,6 +3144,8 @@
         # take care of the rewards by predicting on the agent token embedding on the last denoising step

         if return_rewards_per_frame:
+            agent_embed = embeds.agent
+
             one_agent_embed = agent_embed[:, -1:, agent_index]

             reward_logits = self.to_reward_pred.forward_one(one_agent_embed, id = 0)
@@ -2968,6 +3156,8 @@
         # maybe store agent embed

         if store_agent_embed:
+            agent_embed = embeds.agent
+
             one_agent_embed = agent_embed[:, -1:, agent_index]

             acc_agent_embed = safe_cat((acc_agent_embed, one_agent_embed), dim = 1)
@@ -3060,10 +3250,10 @@
         if not has_at_least_one(return_rewards_per_frame, return_agent_actions, has_proprio):
             out = video if return_decoded_video else latents

-            if not return_time_kv_cache:
+            if not return_time_cache:
                 return out

-            return out, time_kv_cache
+            return out, time_cache

         # returning agent actions, rewards, and log probs + values for policy optimization
@@ -3093,10 +3283,10 @@
             gen.values = decoded_values

-        if not return_time_kv_cache:
+        if not return_time_cache:
             return gen

-        return gen, time_kv_cache
+        return gen, time_cache

     def forward(
         self,
@@ -3115,7 +3305,7 @@
         discrete_action_types = None, # (na)
         continuous_action_types = None, # (na)
         proprio = None, # (b t dp)
-        time_kv_cache = None,
+        time_cache = None,
         return_pred_only = False,
         latent_is_noised = False,
         return_all_losses = False,
@@ -3265,9 +3455,13 @@
         agent_tokens = repeat(agent_tokens, 'b ... d -> b t ... d', t = time)

+        # empty token
+
+        empty_token = agent_tokens[:, :, 0:0]
+
         # maybe reward tokens

-        reward_tokens = agent_tokens[:, :, 0:0]
+        reward_tokens = empty_token

         if exists(rewards):
             two_hot_encoding = self.reward_encoder(rewards)
@@ -3307,6 +3501,13 @@
         else:
             noised_proprio = proprio

+        # maybe state prediction token
+
+        if self.should_pred_state:
+            state_pred_token = repeat(self.state_pred_token, 'd -> b t 1 d', b = batch, t = time)
+        else:
+            state_pred_token = empty_token
+
         # maybe create the action tokens

         if exists(discrete_actions) or exists(continuous_actions):
@@ -3331,11 +3532,11 @@
             action_tokens = torch.zeros_like(agent_tokens[:, :, 0:1])
         else:
-            action_tokens = agent_tokens[:, :, 0:0] # else empty off agent tokens
+            action_tokens = empty_token # else empty off agent tokens

         # main function, needs to be defined as such for shortcut training - additional calls for consistency loss

-        def get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = False, return_time_kv_cache = False):
+        def get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, state_pred_token, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = False, return_time_cache = False):
@@ -3384,15 +3585,15 @@
             # pack to tokens for attending

-            tokens, packed_tokens_shape = pack([flow_token, space_tokens, proprio_token, registers, action_tokens, reward_tokens, agent_tokens], 'b t * d')
+            tokens, packed_tokens_shape = pack([flow_token, space_tokens, proprio_token, state_pred_token, registers, action_tokens, reward_tokens, agent_tokens], 'b t * d')

             # attention

-            tokens, (next_time_kv_cache, *_) = self.transformer(tokens, kv_cache = time_kv_cache, return_intermediates = True)
+            tokens, intermediates = self.transformer(tokens, cache = time_cache, return_intermediates = True)

             # unpack

-            flow_token, space_tokens, proprio_token, register_tokens, action_tokens, reward_tokens, agent_tokens = unpack(tokens, packed_tokens_shape, 'b t * d')
+            flow_token, space_tokens, proprio_token, state_pred_token, register_tokens, action_tokens, reward_tokens, agent_tokens = unpack(tokens, packed_tokens_shape, 'b t * d')
@@ -3404,42 +3605,54 @@
             if self.has_proprio:
                 pred_proprio = self.to_proprio_pred(proprio_token)
+            else:
+                pred_proprio = None

-                pred = (pred, pred_proprio)
+            # maybe state pred
+
+            if self.should_pred_state:
+                pred_state = self.to_state_pred(state_pred_token)
+            else:
+                pred_state = None

             # returning

+            predictions = Predictions(pred, pred_proprio, pred_state)
+
+            embeds = Embeds(agent_tokens, state_pred_token)
+
             if not return_agent_tokens:
-                return pred
+                return predictions

-            if not return_time_kv_cache:
-                return pred, agent_tokens
+            if not return_time_cache:
+                return predictions, embeds

-            return pred, (agent_tokens, next_time_kv_cache)
+            return predictions, (embeds, intermediates)

         # curry into get_prediction what does not change during first call as well as the shortcut ones

-        _get_prediction = partial(get_prediction, action_tokens = action_tokens, reward_tokens = reward_tokens, agent_tokens = agent_tokens)
+        _get_prediction = partial(get_prediction, state_pred_token = state_pred_token, action_tokens = action_tokens, reward_tokens = reward_tokens, agent_tokens = agent_tokens)

         # forward the network

-        pred, (encoded_agent_tokens, next_time_kv_cache) = _get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, return_agent_tokens = True, return_time_kv_cache = True)
+        pred, (embeds, intermediates) = _get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, return_agent_tokens = True, return_time_cache = True)

         if return_pred_only:
             if not return_intermediates:
                 return pred

-            return pred, (encoded_agent_tokens, next_time_kv_cache)
+            return pred, (embeds, intermediates)

         # pack the predictions to calculate flow for different modalities all at once

         if self.has_proprio:
-            pred, for_flow_loss_packed_shape = pack(pred, 'b t *')
+            packed_pred, for_flow_loss_packed_shape = pack((pred.flow, pred.proprioception), 'b t *')

             noised, _ = pack((noised_latents, noised_proprio), 'b t *')
             data, _ = pack((latents, proprio), 'b t *')
             noise, _ = pack((noise, proprio_noise), 'b t *')
         else:
+            packed_pred = pred.flow
             noised = noised_latents
             data = latents
@@ -3458,9 +3671,10 @@
                 pred = fn(noised, noised_proprio, *args, **kwargs)

                 if self.has_proprio:
-                    pred, _ = pack(pred, 'b t *')
+                    packed_flow, _ = pack((pred.flow, pred.proprioception), 'b t *')
+                    return packed_flow

-                return pred
+                return pred.flow

             return inner

         wrapped_get_prediction = maybe_pack_unpack(_get_prediction)
@@ -3527,12 +3741,12 @@
         # need to convert x-space to v-space

         if is_x_space:
-            pred = (pred - noised) / (1. - first_times)
+            packed_pred = (packed_pred - noised) / (1. - first_times)
             maybe_shortcut_loss_weight = (1. - first_times) ** 2

         # mse loss

-        flow_losses = F.mse_loss(pred, pred_target, reduction = 'none')
+        flow_losses = F.mse_loss(packed_pred, pred_target, reduction = 'none')

         flow_losses = flow_losses * maybe_shortcut_loss_weight # handle the (1-t)^2 in eq(7)
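A quick note on the x-space to v-space conversion just above: assuming the usual linear interpolation x_t = t * x0 + (1 - t) * noise with signal level t, the velocity target is v = x0 - noise, so an x-space prediction x0_hat converts as v_hat = (x0_hat - x_t) / (1 - t); multiplying the squared error by (1 - t)^2 then undoes the 1 / (1 - t) blow-up near t = 1, which is the weighting eq (7) calls for.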
@@ -3564,6 +3778,8 @@
         if exists(rewards):
+            encoded_agent_tokens = embeds.agent
+
             if rewards.ndim == 2: # (b t)
                 encoded_agent_tokens = reduce(encoded_agent_tokens, 'b t g d -> b t d', 'mean')
@@ -3584,6 +3800,18 @@
         else:
             reward_loss = reduce(reward_losses, '... mtp -> mtp', 'mean') # they sum across the prediction steps (mtp dimension) - eq(9)

+        # maybe autoregressive state prediction loss
+
+        state_pred_loss = self.zero
+
+        if self.should_pred_state:
+            pred_latent, latent_to_pred = pred.state[:, :-1], latents[:, 1:]
+
+            pred_latent_mean, pred_latent_log_var = pred_latent.unbind(dim = -1)
+            pred_latent_var = pred_latent_log_var.exp()
+
+            state_pred_loss = F.gaussian_nll_loss(pred_latent_mean, latent_to_pred, var = pred_latent_var)
+
         # maybe autoregressive action loss

         discrete_action_loss = self.zero
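The state prediction head introduced in this diff predicts a mean and log variance for the next frame's latents and is trained with a Gaussian NLL; at interaction time the summed predicted log variance is added to the reward as an exploration bonus (see the interact_with_env hunk earlier). A standalone sketch with illustrative shapes, using plain torch in place of the repo's einops/Sequential helpers:

import torch
import torch.nn.functional as F
from torch import nn

batch, time, num_latent_tokens, dim_latent, dim = 2, 6, 4, 8, 64

state_pred_token = torch.randn(batch, time, dim)           # token read out of the transformer
latents = torch.randn(batch, time, num_latent_tokens, dim_latent)

to_state_pred = nn.Sequential(
    nn.RMSNorm(dim),
    nn.Linear(dim, num_latent_tokens * dim_latent * 2)
)

pred = to_state_pred(state_pred_token)
pred = pred.view(batch, time, num_latent_tokens, dim_latent, 2)
mean, log_var = pred.unbind(dim = -1)

# training loss: gaussian NLL of the next frame's latents, predicted autoregressively
loss = F.gaussian_nll_loss(mean[:, :-1], latents[:, 1:], var = log_var[:, :-1].exp())

# exploration bonus at interaction time: higher predicted variance -> larger reward bonus
state_entropy_bonus_weight = 0.05
bonus = log_var.sum() * state_entropy_bonus_weight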
@@ -3651,7 +3879,7 @@
         # handle loss normalization

-        losses = WorldModelLosses(flow_loss, reward_loss, discrete_action_loss, continuous_action_loss)
+        losses = WorldModelLosses(flow_loss, reward_loss, discrete_action_loss, continuous_action_loss, state_pred_loss)

         if exists(self.flow_loss_normalizer):
             flow_loss = self.flow_loss_normalizer(flow_loss, update_ema = update_loss_ema)
@@ -3671,7 +3899,8 @@
             flow_loss * self.latent_flow_loss_weight +
             (reward_loss * self.reward_loss_weight).sum() +
             (discrete_action_loss * self.discrete_action_loss_weight).sum() +
-            (continuous_action_loss * self.continuous_action_loss_weight).sum()
+            (continuous_action_loss * self.continuous_action_loss_weight).sum() +
+            (state_pred_loss * self.state_pred_loss_weight)
         )

         if not return_all_losses:

View File

@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.1.12"
+version = "0.1.24"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

View File

@@ -15,7 +15,8 @@ def exists(v):
 @param('condition_on_actions', (False, True))
 @param('num_residual_streams', (1, 4))
 @param('add_reward_embed_to_agent_token', (False, True))
-@param('use_time_kv_cache', (False, True))
+@param('add_state_pred_head', (False, True))
+@param('use_time_cache', (False, True))
 @param('var_len', (False, True))
 def test_e2e(
     pred_orig_latent,
@@ -28,7 +29,8 @@ def test_e2e(
     condition_on_actions,
     num_residual_streams,
     add_reward_embed_to_agent_token,
-    use_time_kv_cache,
+    add_state_pred_head,
+    use_time_cache,
     var_len
 ):
     from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
@@ -71,12 +73,13 @@ def test_e2e(
         pred_orig_latent = pred_orig_latent,
         num_discrete_actions = 4,
         attn_dim_head = 16,
+        attn_heads = heads,
         attn_kwargs = dict(
-            heads = heads,
             query_heads = query_heads,
         ),
         prob_no_shortcut_train = prob_no_shortcut_train,
         add_reward_embed_to_agent_token = add_reward_embed_to_agent_token,
+        add_state_pred_head = add_state_pred_head,
         num_residual_streams = num_residual_streams
     )
@@ -123,7 +126,7 @@ def test_e2e(
         image_width = 128,
         batch_size = 2,
         return_rewards_per_frame = True,
-        use_time_kv_cache = use_time_kv_cache
+        use_time_cache = use_time_cache
     )

     assert generations.video.shape == (2, 3, 10, 128, 128)
@@ -617,9 +620,9 @@ def test_cache_generate():
         num_residual_streams = 1
     )

-    generated, time_kv_cache = dynamics.generate(1, return_time_kv_cache = True)
-    generated, time_kv_cache = dynamics.generate(1, time_kv_cache = time_kv_cache, return_time_kv_cache = True)
-    generated, time_kv_cache = dynamics.generate(1, time_kv_cache = time_kv_cache, return_time_kv_cache = True)
+    generated, time_cache = dynamics.generate(1, return_time_cache = True)
+    generated, time_cache = dynamics.generate(1, time_cache = time_cache, return_time_cache = True)
+    generated, time_cache = dynamics.generate(1, time_cache = time_cache, return_time_cache = True)

 @param('vectorized', (False, True))
 @param('use_pmpo', (False, True))
@@ -643,7 +646,9 @@ def test_online_rl(
         dim_latent = 16,
         patch_size = 32,
         attn_dim_head = 16,
-        num_latent_tokens = 1
+        num_latent_tokens = 1,
+        image_height = 256,
+        image_width = 256,
     )

     world_model_and_policy = DynamicsWorldModel(
@@ -677,10 +682,12 @@ def test_online_rl(
     # manually

+    dream_experience = world_model_and_policy.generate(10, batch_size = 1, store_agent_embed = store_agent_embed, return_for_policy_optimization = True)
+
     one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = vectorized, store_agent_embed = store_agent_embed)
     another_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized, store_agent_embed = store_agent_embed)

-    combined_experience = combine_experiences([one_experience, another_experience])
+    combined_experience = combine_experiences([dream_experience, one_experience, another_experience])

     # quick test moving the experience to different devices
@@ -803,3 +810,22 @@ def test_epo():
     fitness = torch.randn(16,)

     dynamics.evolve_(fitness)
+
+def test_images_to_video_tokenizer():
+    import torch
+    from dreamer4 import VideoTokenizer, DynamicsWorldModel, AxialSpaceTimeTransformer
+
+    tokenizer = VideoTokenizer(
+        dim = 512,
+        dim_latent = 32,
+        patch_size = 32,
+        image_height = 256,
+        image_width = 256,
+        encoder_add_decor_aux_loss = True
+    )
+
+    images = torch.randn(2, 3, 256, 256)
+
+    loss, (losses, recon_images) = tokenizer(images, return_intermediates = True)
+    loss.backward()
+
+    assert images.shape == recon_images.shape