Compare commits
No commits in common. "main" and "0.1.12" have entirely different histories.
@ -1,7 +1,6 @@
|
||||
from dreamer4.dreamer4 import (
|
||||
VideoTokenizer,
|
||||
DynamicsWorldModel,
|
||||
AxialSpaceTimeTransformer
|
||||
DynamicsWorldModel
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from __future__ import annotations
|
||||
from typing import Callable
|
||||
|
||||
import math
|
||||
from math import ceil, log2
|
||||
@ -71,19 +70,13 @@ except ImportError:
|
||||
|
||||
LinearNoBias = partial(Linear, bias = False)
|
||||
|
||||
VideoTokenizerIntermediates = namedtuple('VideoTokenizerIntermediates', ('losses', 'recon'))
|
||||
|
||||
TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips', 'time_decorr', 'space_decorr'))
|
||||
|
||||
WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'rewards', 'discrete_actions', 'continuous_actions', 'state_pred'))
|
||||
WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'rewards', 'discrete_actions', 'continuous_actions'))
|
||||
|
||||
AttentionIntermediates = namedtuple('AttentionIntermediates', ('next_kv_cache', 'normed_inputs'))
|
||||
|
||||
TransformerIntermediates = namedtuple('TransformerIntermediates', ('next_kv_cache', 'normed_time_inputs', 'normed_space_inputs', 'next_rnn_hiddens'))
|
||||
|
||||
Predictions = namedtuple('Predictions', ('flow', 'proprioception', 'state'))
|
||||
|
||||
Embeds = namedtuple('Embeds', ['agent', 'state_pred'])
|
||||
TransformerIntermediates = namedtuple('TransformerIntermediates', ('next_kv_cache', 'normed_time_inputs', 'normed_space_inputs'))
|
||||
|
||||
MaybeTensor = Tensor | None
|
||||
|
||||
@ -102,7 +95,7 @@ class Experience:
|
||||
lens: MaybeTensor = None
|
||||
is_truncated: MaybeTensor = None
|
||||
agent_index: int = 0
|
||||
is_from_world_model: bool | Tensor = True
|
||||
is_from_world_model: bool = True
|
||||
|
||||
def cpu(self):
|
||||
return self.to(torch.device('cpu'))
|
||||
@ -130,9 +123,6 @@ def combine_experiences(
|
||||
if not exists(exp.is_truncated):
|
||||
exp.is_truncated = full((batch,), True, device = device)
|
||||
|
||||
if isinstance(exp.is_from_world_model, bool):
|
||||
exp.is_from_world_model = tensor(exp.is_from_world_model)
|
||||
|
||||
# convert to dictionary
|
||||
|
||||
exps_dict = [asdict(exp) for exp in exps]
|
||||
@ -154,15 +144,11 @@ def combine_experiences(
|
||||
|
||||
for field_values in all_field_values:
|
||||
|
||||
first_value = first(field_values)
|
||||
|
||||
if is_tensor(first_value):
|
||||
if is_tensor(first(field_values)):
|
||||
|
||||
field_values = pad_tensors_at_dim_to_max_len(field_values, dims = (1, 2))
|
||||
|
||||
cat_or_stack = cat if first_value.ndim > 0 else stack
|
||||
|
||||
new_field_value = cat_or_stack(field_values)
|
||||
new_field_value = cat(field_values)
|
||||
else:
|
||||
new_field_value = first(list(set(field_values)))
|
||||
|
||||
@ -203,13 +189,6 @@ def sample_prob(prob):
|
||||
def is_power_two(num):
|
||||
return log2(num).is_integer()
|
||||
|
||||
def maybe(fn):
|
||||
def inner(t, *args, **kwargs):
|
||||
if not exists(t) or not exists(fn):
|
||||
return None
|
||||
return fn(t)
|
||||
return inner
|
||||
|
||||
# tensor helpers
|
||||
|
||||
def is_empty(t):
|
||||
@ -1305,8 +1284,7 @@ class Attention(Module):
|
||||
pre_rmsnorm = True,
|
||||
gate_values = True,
|
||||
rmsnorm_query = False, # a paper claims that it is better to just norm only the keys https://openreview.net/forum?id=HkztQWZfl2
|
||||
rmsnorm_key = True,
|
||||
value_residual = True
|
||||
rmsnorm_key = True
|
||||
):
|
||||
super().__init__()
|
||||
self.norm = RMSNorm(dim) if pre_rmsnorm else Identity()
|
||||
@ -1345,14 +1323,6 @@ class Attention(Module):
|
||||
self.q_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = query_heads) if rmsnorm_query else nn.Identity()
|
||||
self.k_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = heads) if rmsnorm_key else nn.Identity()
|
||||
|
||||
# value residual
|
||||
|
||||
self.to_learned_value_residual_mix = nn.Sequential(
|
||||
nn.Linear(dim, heads),
|
||||
Rearrange('b n h -> b h n 1'),
|
||||
nn.Sigmoid()
|
||||
) if value_residual else None
|
||||
|
||||
def muon_parameters(self):
|
||||
# omit the queries and keys for now given what we learned from kimi 2 paper
|
||||
|
||||
@ -1367,7 +1337,6 @@ class Attention(Module):
|
||||
kv_cache = None,
|
||||
return_intermediates = False,
|
||||
rotary_pos_emb = None,
|
||||
residual_values = None, # (b n h d)
|
||||
attend_fn: Callable | None = None
|
||||
):
|
||||
tokens, inverse_packed_batch = pack_one(tokens, '* n d')
|
||||
@ -1380,17 +1349,6 @@ class Attention(Module):
|
||||
|
||||
q, k, v = map(self.split_heads, (q, k, v))
|
||||
|
||||
# handle maybe value residual
|
||||
|
||||
if exists(residual_values):
|
||||
residual_values = rearrange(residual_values, '... n h d -> (...) h n d')
|
||||
|
||||
assert exists(self.to_learned_value_residual_mix)
|
||||
|
||||
learned_mix = self.to_learned_value_residual_mix(tokens)
|
||||
|
||||
v = v.lerp(residual_values, learned_mix)
|
||||
|
||||
# qk rmsnorm
|
||||
|
||||
q = self.q_heads_rmsnorm(q)
|
||||
@ -1467,29 +1425,6 @@ class SwiGLUFeedforward(Module):
|
||||
|
||||
return self.proj_out(x)
|
||||
|
||||
# rnn
|
||||
|
||||
class GRULayer(Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
dim_out
|
||||
):
|
||||
super().__init__()
|
||||
self.norm = nn.RMSNorm(dim)
|
||||
self.gru = nn.GRU(dim, dim_out, batch_first = True)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
prev_hiddens = None
|
||||
):
|
||||
x = self.norm(x)
|
||||
|
||||
x, hiddens = self.gru(x, prev_hiddens)
|
||||
|
||||
return x, hiddens
|
||||
|
||||
# axial space time transformer
|
||||
|
||||
class AxialSpaceTimeTransformer(Module):
|
||||
@ -1497,7 +1432,6 @@ class AxialSpaceTimeTransformer(Module):
|
||||
self,
|
||||
dim,
|
||||
depth,
|
||||
attn_heads = 8,
|
||||
attn_dim_head = 64,
|
||||
attn_softclamp_value = 50.,
|
||||
time_block_every = 4,
|
||||
@ -1506,9 +1440,7 @@ class AxialSpaceTimeTransformer(Module):
|
||||
num_residual_streams = 1,
|
||||
num_special_spatial_tokens = 1,
|
||||
special_attend_only_itself = False, # this is set to True for the video tokenizer decoder (latents can only attend to itself while spatial modalities attend to the latents and everything)
|
||||
final_norm = True,
|
||||
value_residual = True, # https://arxiv.org/abs/2410.17897 - but with learned mixing from OSS
|
||||
rnn_time = True
|
||||
final_norm = True
|
||||
):
|
||||
super().__init__()
|
||||
assert depth >= time_block_every, f'depth must be at least {time_block_every}'
|
||||
@ -1529,24 +1461,6 @@ class AxialSpaceTimeTransformer(Module):
|
||||
|
||||
self.time_rotary = Rotary1D(attn_dim_head)
|
||||
|
||||
# project initial for value residuals
|
||||
|
||||
self.value_residual = value_residual
|
||||
|
||||
if value_residual:
|
||||
dim_inner = attn_dim_head * attn_heads
|
||||
|
||||
self.to_value_residual = nn.Sequential(
|
||||
nn.RMSNorm(dim),
|
||||
nn.Linear(dim, dim_inner, bias = False),
|
||||
Rearrange('... (h d) -> ... h d', h = attn_heads)
|
||||
)
|
||||
|
||||
# a gru layer across time
|
||||
|
||||
self.rnn_time = rnn_time
|
||||
rnn_layers = []
|
||||
|
||||
# transformer
|
||||
|
||||
layers = []
|
||||
@ -1558,21 +1472,17 @@ class AxialSpaceTimeTransformer(Module):
|
||||
is_time_block = divisible_by(layer_index, time_block_every)
|
||||
is_time.append(is_time_block)
|
||||
|
||||
rearrange_to_attend = Rearrange('b t s ... -> b s t ...') if is_time_block else Identity()
|
||||
rearrange_from_attend = Rearrange('b s t ... -> b t s ...') if is_time_block else Identity()
|
||||
rearrange_to_attend = Rearrange('b t s d -> b s t d') if is_time_block else Identity()
|
||||
rearrange_from_attend = Rearrange('b s t d -> b t s d') if is_time_block else Identity()
|
||||
|
||||
layers.append(ModuleList([
|
||||
rearrange_to_attend,
|
||||
rearrange_from_attend,
|
||||
hyper_conn(branch = Attention(dim = dim, heads = attn_heads, dim_head = attn_dim_head, value_residual = value_residual, **attn_kwargs)),
|
||||
hyper_conn(branch = Attention(dim = dim, dim_head = attn_dim_head, **attn_kwargs)),
|
||||
hyper_conn(branch = SwiGLUFeedforward(dim = dim, **ff_kwargs))
|
||||
]))
|
||||
|
||||
rnn_layers.append(hyper_conn(branch = GRULayer(dim, dim)) if is_time_block and rnn_time else None)
|
||||
|
||||
self.layers = ModuleList(layers)
|
||||
self.rnn_layers = ModuleList(rnn_layers)
|
||||
|
||||
self.is_time = is_time
|
||||
|
||||
# final norm
|
||||
@ -1594,8 +1504,8 @@ class AxialSpaceTimeTransformer(Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
tokens, # (b t s d)
|
||||
cache: TransformerIntermediates | None = None,
|
||||
tokens, # (b t s d)
|
||||
kv_cache: Tensor | None = None, # (y 2 b h t d)
|
||||
return_intermediates = False
|
||||
|
||||
): # (b t s d) | (y 2 b h t d)
|
||||
@ -1604,14 +1514,6 @@ class AxialSpaceTimeTransformer(Module):
|
||||
|
||||
assert tokens.ndim == 4
|
||||
|
||||
# destruct intermediates to cache for attention and rnn respectively
|
||||
|
||||
kv_cache = rnn_prev_hiddens = None
|
||||
|
||||
if exists(cache):
|
||||
kv_cache = cache.next_kv_cache
|
||||
rnn_prev_hiddens = cache.next_rnn_hiddens
|
||||
|
||||
# attend functions for space and time
|
||||
|
||||
has_kv_cache = exists(kv_cache)
|
||||
@ -1626,7 +1528,7 @@ class AxialSpaceTimeTransformer(Module):
|
||||
# prepare cache
|
||||
|
||||
time_attn_kv_caches = []
|
||||
rnn_hiddens = []
|
||||
|
||||
|
||||
if has_kv_cache:
|
||||
past_tokens, tokens = tokens[:, :-1], tokens[:, -1:]
|
||||
@ -1641,21 +1543,10 @@ class AxialSpaceTimeTransformer(Module):
|
||||
|
||||
iter_kv_cache = iter(kv_cache)
|
||||
|
||||
rnn_prev_hiddens = default(rnn_prev_hiddens, (None,))
|
||||
|
||||
iter_rnn_prev_hiddens = iter(rnn_prev_hiddens)
|
||||
|
||||
# rotary
|
||||
|
||||
rotary_pos_emb = self.time_rotary(rotary_seq_len, offset = rotary_pos_offset)
|
||||
|
||||
# value residual
|
||||
|
||||
residual_values = None
|
||||
|
||||
if self.value_residual:
|
||||
residual_values = self.to_value_residual(tokens)
|
||||
|
||||
# normed attention inputs
|
||||
|
||||
normed_time_attn_inputs = []
|
||||
@ -1665,22 +1556,10 @@ class AxialSpaceTimeTransformer(Module):
|
||||
|
||||
tokens = self.expand_streams(tokens)
|
||||
|
||||
for (pre_attn_rearrange, post_attn_rearrange, attn, ff), maybe_rnn, layer_is_time in zip(self.layers, self.rnn_layers, self.is_time):
|
||||
for (pre_attn_rearrange, post_attn_rearrange, attn, ff), layer_is_time in zip(self.layers, self.is_time):
|
||||
|
||||
tokens = pre_attn_rearrange(tokens)
|
||||
|
||||
# maybe rnn for time
|
||||
|
||||
if layer_is_time and exists(maybe_rnn):
|
||||
|
||||
tokens, inverse_pack_batch = pack_one(tokens, '* t d')
|
||||
|
||||
tokens, layer_rnn_hiddens = maybe_rnn(tokens, next(iter_rnn_prev_hiddens, None)) # todo, handle rnn cache
|
||||
|
||||
tokens = inverse_pack_batch(tokens)
|
||||
|
||||
rnn_hiddens.append(layer_rnn_hiddens)
|
||||
|
||||
# when is a axial time attention block, should be causal
|
||||
|
||||
attend_fn = time_attend if layer_is_time else space_attend
|
||||
@ -1691,10 +1570,6 @@ class AxialSpaceTimeTransformer(Module):
|
||||
|
||||
maybe_kv_cache = next(iter_kv_cache, None) if layer_is_time else None
|
||||
|
||||
# residual values
|
||||
|
||||
layer_residual_values = maybe(pre_attn_rearrange)(residual_values)
|
||||
|
||||
# attention layer
|
||||
|
||||
tokens, attn_intermediates = attn(
|
||||
@ -1702,7 +1577,6 @@ class AxialSpaceTimeTransformer(Module):
|
||||
rotary_pos_emb = layer_rotary_pos_emb,
|
||||
attend_fn = attend_fn,
|
||||
kv_cache = maybe_kv_cache,
|
||||
residual_values = layer_residual_values,
|
||||
return_intermediates = True
|
||||
)
|
||||
|
||||
@ -1737,8 +1611,7 @@ class AxialSpaceTimeTransformer(Module):
|
||||
intermediates = TransformerIntermediates(
|
||||
stack(time_attn_kv_caches),
|
||||
safe_stack(normed_time_attn_inputs),
|
||||
safe_stack(normed_space_attn_inputs),
|
||||
safe_stack(rnn_hiddens)
|
||||
safe_stack(normed_space_attn_inputs)
|
||||
)
|
||||
|
||||
return out, intermediates
|
||||
@ -1770,6 +1643,11 @@ class VideoTokenizer(Module):
|
||||
encoder_add_decor_aux_loss = False,
|
||||
decor_auxx_loss_weight = 0.1,
|
||||
decorr_sample_frac = 0.25,
|
||||
nd_rotary_kwargs: dict = dict(
|
||||
rope_min_freq = 1.,
|
||||
rope_max_freq = 10000.,
|
||||
rope_p_zero_freqs = 0.
|
||||
),
|
||||
num_residual_streams = 1,
|
||||
):
|
||||
super().__init__()
|
||||
@ -1942,23 +1820,11 @@ class VideoTokenizer(Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
video_or_image, # (b c t h w) | (b c h w)
|
||||
video, # (b c t h w)
|
||||
return_latents = False,
|
||||
mask_patches = None,
|
||||
return_intermediates = False,
|
||||
return_all_losses = False
|
||||
):
|
||||
|
||||
# handle image pretraining
|
||||
|
||||
is_image = video_or_image.ndim == 4
|
||||
|
||||
if is_image:
|
||||
video = rearrange(video_or_image, 'b c h w -> b c 1 h w')
|
||||
else:
|
||||
video = video_or_image
|
||||
|
||||
# shapes
|
||||
|
||||
batch, _, time, height, width = video.shape
|
||||
patch_size, device = self.patch_size, video.device
|
||||
|
||||
@ -1998,7 +1864,7 @@ class VideoTokenizer(Module):
|
||||
|
||||
# encoder attention
|
||||
|
||||
tokens, (_, time_attn_normed_inputs, space_attn_normed_inputs, _) = self.encoder_transformer(tokens, return_intermediates = True)
|
||||
tokens, (_, time_attn_normed_inputs, space_attn_normed_inputs) = self.encoder_transformer(tokens, return_intermediates = True)
|
||||
|
||||
# latent bottleneck
|
||||
|
||||
@ -2038,21 +1904,12 @@ class VideoTokenizer(Module):
|
||||
space_decorr_loss * self.decorr_aux_loss_weight
|
||||
)
|
||||
|
||||
if not return_intermediates:
|
||||
if not return_all_losses:
|
||||
return total_loss
|
||||
|
||||
losses = TokenizerLosses(recon_loss, lpips_loss, time_decorr_loss, space_decorr_loss)
|
||||
losses = (recon_loss, lpips_loss, decorr_loss)
|
||||
|
||||
out = losses
|
||||
|
||||
# handle returning of reconstructed, and image pretraining
|
||||
|
||||
if is_image:
|
||||
recon_video = rearrange(recon_video, 'b c 1 h w -> b c h w')
|
||||
|
||||
out = (losses, recon_video)
|
||||
|
||||
return total_loss, out
|
||||
return total_loss, TokenizerLosses(*losses)
|
||||
|
||||
# dynamics model, axial space-time transformer
|
||||
|
||||
@ -2074,24 +1931,22 @@ class DynamicsWorldModel(Module):
|
||||
depth = 4,
|
||||
pred_orig_latent = True, # directly predicting the original x0 data yield better results, rather than velocity (x-space vs v-space)
|
||||
time_block_every = 4, # every 4th block is time
|
||||
attn_kwargs: dict = dict(),
|
||||
attn_kwargs: dict = dict(
|
||||
heads = 8,
|
||||
),
|
||||
transformer_kwargs: dict = dict(),
|
||||
attn_heads = 8,
|
||||
attn_dim_head = 64,
|
||||
attn_softclamp_value = 50.,
|
||||
ff_kwargs: dict = dict(),
|
||||
use_time_rnn = True,
|
||||
loss_weight_fn: Callable = ramp_weight,
|
||||
num_future_predictions = 8, # they do multi-token prediction of 8 steps forward
|
||||
prob_no_shortcut_train = None, # probability of no shortcut training, defaults to 1 / num_step_sizes
|
||||
add_reward_embed_to_agent_token = False,
|
||||
add_reward_embed_dropout = 0.1,
|
||||
add_state_pred_head = False,
|
||||
state_pred_loss_weight = 0.1,
|
||||
state_entropy_bonus_weight = 0.05,
|
||||
num_discrete_actions: int | tuple[int, ...] = 0,
|
||||
num_continuous_actions = 0,
|
||||
continuous_norm_stats = None,
|
||||
multi_token_pred_len = 8, # they do multi-token prediction of 8 steps forward
|
||||
multi_token_pred_len = 8,
|
||||
value_head_mlp_depth = 3,
|
||||
policy_head_mlp_depth = 3,
|
||||
latent_flow_loss_weight = 1.,
|
||||
@ -2211,25 +2066,6 @@ class DynamicsWorldModel(Module):
|
||||
self.pred_orig_latent = pred_orig_latent # x-space or v-space
|
||||
self.loss_weight_fn = loss_weight_fn
|
||||
|
||||
# state prediction, for state entropy bonus
|
||||
|
||||
self.add_state_pred_head = add_state_pred_head
|
||||
self.state_pred_loss_weight = state_pred_loss_weight
|
||||
|
||||
self.should_pred_state = add_state_pred_head and state_pred_loss_weight > 0.
|
||||
|
||||
if self.should_pred_state:
|
||||
self.state_pred_token = nn.Parameter(torch.randn(dim) * 1e-2)
|
||||
|
||||
self.to_state_pred = Sequential(
|
||||
RMSNorm(dim),
|
||||
nn.Linear(dim, num_latent_tokens * dim_latent * 2),
|
||||
Rearrange('... (n d two) -> ... n d two', n = num_latent_tokens, two = 2)
|
||||
)
|
||||
|
||||
self.state_entropy_bonus_weight = state_entropy_bonus_weight
|
||||
self.add_state_entropy_bonus = self.should_pred_state and state_entropy_bonus_weight > 0.
|
||||
|
||||
# reinforcement related
|
||||
|
||||
# they sum all the actions into a single token
|
||||
@ -2311,7 +2147,6 @@ class DynamicsWorldModel(Module):
|
||||
self.transformer = AxialSpaceTimeTransformer(
|
||||
dim = dim,
|
||||
depth = depth,
|
||||
attn_heads = attn_heads,
|
||||
attn_dim_head = attn_dim_head,
|
||||
attn_softclamp_value = attn_softclamp_value,
|
||||
attn_kwargs = attn_kwargs,
|
||||
@ -2320,7 +2155,6 @@ class DynamicsWorldModel(Module):
|
||||
num_special_spatial_tokens = num_agents,
|
||||
time_block_every = time_block_every,
|
||||
final_norm = False,
|
||||
rnn_time = use_time_rnn,
|
||||
**transformer_kwargs
|
||||
)
|
||||
|
||||
@ -2464,10 +2298,10 @@ class DynamicsWorldModel(Module):
|
||||
env,
|
||||
seed = None,
|
||||
agent_index = 0,
|
||||
num_steps = 4,
|
||||
step_size = 4,
|
||||
max_timesteps = 16,
|
||||
env_is_vectorized = False,
|
||||
use_time_cache = True,
|
||||
use_time_kv_cache = True,
|
||||
store_agent_embed = True,
|
||||
store_old_action_unembeds = True,
|
||||
):
|
||||
@ -2504,14 +2338,9 @@ class DynamicsWorldModel(Module):
|
||||
|
||||
episode_lens = full((batch,), 0, device = device)
|
||||
|
||||
# derive step size
|
||||
|
||||
assert divisible_by(self.max_steps, num_steps)
|
||||
step_size = self.max_steps // num_steps
|
||||
|
||||
# maybe time kv cache
|
||||
|
||||
time_cache = None
|
||||
time_kv_cache = None
|
||||
|
||||
step_index = 0
|
||||
|
||||
@ -2520,14 +2349,14 @@ class DynamicsWorldModel(Module):
|
||||
|
||||
latents = self.video_tokenizer(video, return_latents = True)
|
||||
|
||||
_, (embeds, next_time_cache) = self.forward(
|
||||
_, (agent_embed, next_time_kv_cache) = self.forward(
|
||||
latents = latents,
|
||||
signal_levels = self.max_steps - 1,
|
||||
step_sizes = step_size,
|
||||
rewards = rewards,
|
||||
discrete_actions = discrete_actions,
|
||||
continuous_actions = continuous_actions,
|
||||
time_cache = time_cache,
|
||||
time_kv_cache = time_kv_cache,
|
||||
latent_is_noised = True,
|
||||
return_pred_only = True,
|
||||
return_intermediates = True
|
||||
@ -2535,13 +2364,11 @@ class DynamicsWorldModel(Module):
|
||||
|
||||
# time kv cache
|
||||
|
||||
if use_time_cache:
|
||||
time_cache = next_time_cache
|
||||
if use_time_kv_cache:
|
||||
time_kv_cache = next_time_kv_cache
|
||||
|
||||
# get one agent
|
||||
|
||||
agent_embed = embeds.agent
|
||||
|
||||
one_agent_embed = agent_embed[..., -1:, agent_index, :]
|
||||
|
||||
# values
|
||||
@ -2596,17 +2423,6 @@ class DynamicsWorldModel(Module):
|
||||
elif len(env_step_out) == 5:
|
||||
next_frame, reward, terminated, truncated, info = env_step_out
|
||||
|
||||
# maybe add state entropy bonus
|
||||
|
||||
if self.add_state_entropy_bonus:
|
||||
state_pred_token = embeds.state_pred
|
||||
|
||||
state_pred = self.to_state_pred(state_pred_token)
|
||||
|
||||
state_pred_log_variance = state_pred[..., 1].sum()
|
||||
|
||||
reward = reward + state_pred_log_variance * self.state_entropy_bonus_weight
|
||||
|
||||
# update episode lens
|
||||
|
||||
episode_lens = torch.where(is_terminated, episode_lens, episode_lens + 1)
|
||||
@ -2785,7 +2601,7 @@ class DynamicsWorldModel(Module):
|
||||
):
|
||||
|
||||
with world_model_forward_context():
|
||||
_, (embeds, _) = self.forward(
|
||||
_, (agent_embeds, _) = self.forward(
|
||||
latents = latents,
|
||||
signal_levels = self.max_steps - 1,
|
||||
step_sizes = step_size,
|
||||
@ -2797,7 +2613,7 @@ class DynamicsWorldModel(Module):
|
||||
return_intermediates = True
|
||||
)
|
||||
|
||||
agent_embeds = embeds.agent[..., agent_index, :]
|
||||
agent_embeds = agent_embeds[..., agent_index, :]
|
||||
|
||||
# maybe detach agent embed
|
||||
|
||||
@ -2941,13 +2757,13 @@ class DynamicsWorldModel(Module):
|
||||
image_width = None,
|
||||
return_decoded_video = None,
|
||||
context_signal_noise = 0.1, # they do a noising of the past, this was from an old diffusion world modeling paper from EPFL iirc
|
||||
time_cache: Tensor | None = None,
|
||||
use_time_cache = True,
|
||||
time_kv_cache: Tensor | None = None,
|
||||
use_time_kv_cache = True,
|
||||
return_rewards_per_frame = False,
|
||||
return_agent_actions = False,
|
||||
return_log_probs_and_values = False,
|
||||
return_for_policy_optimization = False,
|
||||
return_time_cache = False,
|
||||
return_time_kv_cache = False,
|
||||
store_agent_embed = True,
|
||||
store_old_action_unembeds = True
|
||||
|
||||
@ -3036,7 +2852,7 @@ class DynamicsWorldModel(Module):
|
||||
# (2) decoding anything off agent embedding (rewards, actions, etc)
|
||||
|
||||
take_extra_step = (
|
||||
use_time_cache or
|
||||
use_time_kv_cache or
|
||||
return_rewards_per_frame or
|
||||
store_agent_embed or
|
||||
return_agent_actions
|
||||
@ -3077,7 +2893,7 @@ class DynamicsWorldModel(Module):
|
||||
|
||||
signal_levels_with_context = F.pad(signal_levels, (curr_time_steps, 0), value = self.max_steps - 1)
|
||||
|
||||
pred, (embeds, next_time_cache) = self.forward(
|
||||
pred, (agent_embed, next_time_kv_cache) = self.forward(
|
||||
latents = noised_latent_with_context,
|
||||
signal_levels = signal_levels_with_context,
|
||||
step_sizes = step_size,
|
||||
@ -3087,15 +2903,15 @@ class DynamicsWorldModel(Module):
|
||||
discrete_actions = decoded_discrete_actions,
|
||||
continuous_actions = decoded_continuous_actions,
|
||||
proprio = noised_proprio_with_context,
|
||||
time_cache = time_cache,
|
||||
time_kv_cache = time_kv_cache,
|
||||
latent_is_noised = True,
|
||||
latent_has_view_dim = True,
|
||||
return_pred_only = True,
|
||||
return_intermediates = True,
|
||||
)
|
||||
|
||||
if use_time_cache and is_last_step:
|
||||
time_cache = next_time_cache
|
||||
if use_time_kv_cache and is_last_step:
|
||||
time_kv_cache = next_time_kv_cache
|
||||
|
||||
# early break if taking an extra step for agent embedding off cleaned latents for decoding
|
||||
|
||||
@ -3104,10 +2920,8 @@ class DynamicsWorldModel(Module):
|
||||
|
||||
# maybe proprio
|
||||
|
||||
# maybe proprio
|
||||
|
||||
pred_proprio = pred.proprioception
|
||||
pred = pred.flow
|
||||
if has_proprio:
|
||||
pred, pred_proprio = pred
|
||||
|
||||
# unpack pred
|
||||
|
||||
@ -3144,8 +2958,6 @@ class DynamicsWorldModel(Module):
|
||||
# take care of the rewards by predicting on the agent token embedding on the last denoising step
|
||||
|
||||
if return_rewards_per_frame:
|
||||
agent_embed = embeds.agent
|
||||
|
||||
one_agent_embed = agent_embed[:, -1:, agent_index]
|
||||
|
||||
reward_logits = self.to_reward_pred.forward_one(one_agent_embed, id = 0)
|
||||
@ -3156,8 +2968,6 @@ class DynamicsWorldModel(Module):
|
||||
# maybe store agent embed
|
||||
|
||||
if store_agent_embed:
|
||||
agent_embed = embeds.agent
|
||||
|
||||
one_agent_embed = agent_embed[:, -1:, agent_index]
|
||||
acc_agent_embed = safe_cat((acc_agent_embed, one_agent_embed), dim = 1)
|
||||
|
||||
@ -3250,10 +3060,10 @@ class DynamicsWorldModel(Module):
|
||||
if not has_at_least_one(return_rewards_per_frame, return_agent_actions, has_proprio):
|
||||
out = video if return_decoded_video else latents
|
||||
|
||||
if not return_time_cache:
|
||||
if not return_time_kv_cache:
|
||||
return out
|
||||
|
||||
return out, time_cache
|
||||
return out, time_kv_cache
|
||||
|
||||
# returning agent actions, rewards, and log probs + values for policy optimization
|
||||
|
||||
@ -3283,10 +3093,10 @@ class DynamicsWorldModel(Module):
|
||||
|
||||
gen.values = decoded_values
|
||||
|
||||
if not return_time_cache:
|
||||
if not return_time_kv_cache:
|
||||
return gen
|
||||
|
||||
return gen, time_cache
|
||||
return gen, time_kv_cache
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -3305,7 +3115,7 @@ class DynamicsWorldModel(Module):
|
||||
discrete_action_types = None, # (na)
|
||||
continuous_action_types = None, # (na)
|
||||
proprio = None, # (b t dp)
|
||||
time_cache = None,
|
||||
time_kv_cache = None,
|
||||
return_pred_only = False,
|
||||
latent_is_noised = False,
|
||||
return_all_losses = False,
|
||||
@ -3455,13 +3265,9 @@ class DynamicsWorldModel(Module):
|
||||
|
||||
agent_tokens = repeat(agent_tokens, 'b ... d -> b t ... d', t = time)
|
||||
|
||||
# empty token
|
||||
|
||||
empty_token = agent_tokens[:, :, 0:0]
|
||||
|
||||
# maybe reward tokens
|
||||
|
||||
reward_tokens = empty_token
|
||||
reward_tokens = agent_tokens[:, :, 0:0]
|
||||
|
||||
if exists(rewards):
|
||||
two_hot_encoding = self.reward_encoder(rewards)
|
||||
@ -3501,13 +3307,6 @@ class DynamicsWorldModel(Module):
|
||||
else:
|
||||
noised_proprio = proprio
|
||||
|
||||
# maybe state prediction token
|
||||
|
||||
if self.should_pred_state:
|
||||
state_pred_token = repeat(self.state_pred_token, 'd -> b t 1 d', b = batch, t = time)
|
||||
else:
|
||||
state_pred_token = empty_token
|
||||
|
||||
# maybe create the action tokens
|
||||
|
||||
if exists(discrete_actions) or exists(continuous_actions):
|
||||
@ -3532,11 +3331,11 @@ class DynamicsWorldModel(Module):
|
||||
action_tokens = torch.zeros_like(agent_tokens[:, :, 0:1])
|
||||
|
||||
else:
|
||||
action_tokens = empty_token # else empty off agent tokens
|
||||
action_tokens = agent_tokens[:, :, 0:0] # else empty off agent tokens
|
||||
|
||||
# main function, needs to be defined as such for shortcut training - additional calls for consistency loss
|
||||
|
||||
def get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, state_pred_token, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = False, return_time_cache = False):
|
||||
def get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = False, return_time_kv_cache = False):
|
||||
|
||||
# latents to spatial tokens
|
||||
|
||||
@ -3585,15 +3384,15 @@ class DynamicsWorldModel(Module):
|
||||
|
||||
# pack to tokens for attending
|
||||
|
||||
tokens, packed_tokens_shape = pack([flow_token, space_tokens, proprio_token, state_pred_token, registers, action_tokens, reward_tokens, agent_tokens], 'b t * d')
|
||||
tokens, packed_tokens_shape = pack([flow_token, space_tokens, proprio_token, registers, action_tokens, reward_tokens, agent_tokens], 'b t * d')
|
||||
|
||||
# attention
|
||||
|
||||
tokens, intermediates = self.transformer(tokens, cache = time_cache, return_intermediates = True)
|
||||
tokens, (next_time_kv_cache, *_) = self.transformer(tokens, kv_cache = time_kv_cache, return_intermediates = True)
|
||||
|
||||
# unpack
|
||||
|
||||
flow_token, space_tokens, proprio_token, state_pred_token, register_tokens, action_tokens, reward_tokens, agent_tokens = unpack(tokens, packed_tokens_shape, 'b t * d')
|
||||
flow_token, space_tokens, proprio_token, register_tokens, action_tokens, reward_tokens, agent_tokens = unpack(tokens, packed_tokens_shape, 'b t * d')
|
||||
|
||||
# pooling
|
||||
|
||||
@ -3605,54 +3404,42 @@ class DynamicsWorldModel(Module):
|
||||
|
||||
if self.has_proprio:
|
||||
pred_proprio = self.to_proprio_pred(proprio_token)
|
||||
else:
|
||||
pred_proprio = None
|
||||
|
||||
# maybe state pred
|
||||
|
||||
if self.should_pred_state:
|
||||
pred_state = self.to_state_pred(state_pred_token)
|
||||
else:
|
||||
pred_state = None
|
||||
pred = (pred, pred_proprio)
|
||||
|
||||
# returning
|
||||
|
||||
predictions = Predictions(pred, pred_proprio, pred_state)
|
||||
|
||||
embeds = Embeds(agent_tokens, state_pred_token)
|
||||
|
||||
if not return_agent_tokens:
|
||||
return predictions
|
||||
return pred
|
||||
|
||||
if not return_time_cache:
|
||||
return predictions, embeds
|
||||
if not return_time_kv_cache:
|
||||
return pred, agent_tokens
|
||||
|
||||
return predictions, (embeds, intermediates)
|
||||
return pred, (agent_tokens, next_time_kv_cache)
|
||||
|
||||
# curry into get_prediction what does not change during first call as well as the shortcut ones
|
||||
|
||||
_get_prediction = partial(get_prediction, state_pred_token = state_pred_token, action_tokens = action_tokens, reward_tokens = reward_tokens, agent_tokens = agent_tokens)
|
||||
_get_prediction = partial(get_prediction, action_tokens = action_tokens, reward_tokens = reward_tokens, agent_tokens = agent_tokens)
|
||||
|
||||
# forward the network
|
||||
|
||||
pred, (embeds, intermediates) = _get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, return_agent_tokens = True, return_time_cache = True)
|
||||
pred, (encoded_agent_tokens, next_time_kv_cache) = _get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, return_agent_tokens = True, return_time_kv_cache = True)
|
||||
|
||||
if return_pred_only:
|
||||
if not return_intermediates:
|
||||
return pred
|
||||
|
||||
return pred, (embeds, intermediates)
|
||||
return pred, (encoded_agent_tokens, next_time_kv_cache)
|
||||
|
||||
# pack the predictions to calculate flow for different modalities all at once
|
||||
|
||||
if self.has_proprio:
|
||||
packed_pred, for_flow_loss_packed_shape = pack((pred.flow, pred.proprioception), 'b t *')
|
||||
pred, for_flow_loss_packed_shape = pack(pred, 'b t *')
|
||||
|
||||
noised, _ = pack((noised_latents, noised_proprio), 'b t *')
|
||||
data, _ = pack((latents, proprio), 'b t *')
|
||||
noise, _ = pack((noise, proprio_noise), 'b t *')
|
||||
else:
|
||||
packed_pred = pred.flow
|
||||
noised = noised_latents
|
||||
data = latents
|
||||
|
||||
@ -3671,10 +3458,9 @@ class DynamicsWorldModel(Module):
|
||||
pred = fn(noised, noised_proprio, *args, **kwargs)
|
||||
|
||||
if self.has_proprio:
|
||||
packed_flow, _ = pack((pred.flow, pred.proprioception), 'b t *')
|
||||
return packed_flow
|
||||
pred, _ = pack(pred, 'b t *')
|
||||
|
||||
return pred.flow
|
||||
return pred
|
||||
return inner
|
||||
|
||||
wrapped_get_prediction = maybe_pack_unpack(_get_prediction)
|
||||
@ -3741,12 +3527,12 @@ class DynamicsWorldModel(Module):
|
||||
# need to convert x-space to v-space
|
||||
|
||||
if is_x_space:
|
||||
packed_pred = (packed_pred - noised) / (1. - first_times)
|
||||
pred = (pred - noised) / (1. - first_times)
|
||||
maybe_shortcut_loss_weight = (1. - first_times) ** 2
|
||||
|
||||
# mse loss
|
||||
|
||||
flow_losses = F.mse_loss(packed_pred, pred_target, reduction = 'none')
|
||||
flow_losses = F.mse_loss(pred, pred_target, reduction = 'none')
|
||||
|
||||
flow_losses = flow_losses * maybe_shortcut_loss_weight # handle the (1-t)^2 in eq(7)
|
||||
|
||||
@ -3778,8 +3564,6 @@ class DynamicsWorldModel(Module):
|
||||
|
||||
if exists(rewards):
|
||||
|
||||
encoded_agent_tokens = embeds.agent
|
||||
|
||||
if rewards.ndim == 2: # (b t)
|
||||
encoded_agent_tokens = reduce(encoded_agent_tokens, 'b t g d -> b t d', 'mean')
|
||||
|
||||
@ -3800,18 +3584,6 @@ class DynamicsWorldModel(Module):
|
||||
else:
|
||||
reward_loss = reduce(reward_losses, '... mtp -> mtp', 'mean') # they sum across the prediction steps (mtp dimension) - eq(9)
|
||||
|
||||
# maybe autoregressive state prediction loss
|
||||
|
||||
state_pred_loss = self.zero
|
||||
|
||||
if self.should_pred_state:
|
||||
pred_latent, latent_to_pred = pred.state[:, :-1], latents[:, 1:]
|
||||
|
||||
pred_latent_mean, pred_latent_log_var = pred_latent.unbind(dim = -1)
|
||||
pred_latent_var = pred_latent_log_var.exp()
|
||||
|
||||
state_pred_loss = F.gaussian_nll_loss(pred_latent_mean, latent_to_pred, var = pred_latent_var)
|
||||
|
||||
# maybe autoregressive action loss
|
||||
|
||||
discrete_action_loss = self.zero
|
||||
@ -3879,7 +3651,7 @@ class DynamicsWorldModel(Module):
|
||||
|
||||
# handle loss normalization
|
||||
|
||||
losses = WorldModelLosses(flow_loss, reward_loss, discrete_action_loss, continuous_action_loss, state_pred_loss)
|
||||
losses = WorldModelLosses(flow_loss, reward_loss, discrete_action_loss, continuous_action_loss)
|
||||
|
||||
if exists(self.flow_loss_normalizer):
|
||||
flow_loss = self.flow_loss_normalizer(flow_loss, update_ema = update_loss_ema)
|
||||
@ -3899,8 +3671,7 @@ class DynamicsWorldModel(Module):
|
||||
flow_loss * self.latent_flow_loss_weight +
|
||||
(reward_loss * self.reward_loss_weight).sum() +
|
||||
(discrete_action_loss * self.discrete_action_loss_weight).sum() +
|
||||
(continuous_action_loss * self.continuous_action_loss_weight).sum() +
|
||||
(state_pred_loss * self.state_pred_loss_weight)
|
||||
(continuous_action_loss * self.continuous_action_loss_weight).sum()
|
||||
)
|
||||
|
||||
if not return_all_losses:
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "dreamer4"
|
||||
version = "0.1.24"
|
||||
version = "0.1.12"
|
||||
description = "Dreamer 4"
|
||||
authors = [
|
||||
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
||||
|
||||
@ -15,8 +15,7 @@ def exists(v):
|
||||
@param('condition_on_actions', (False, True))
|
||||
@param('num_residual_streams', (1, 4))
|
||||
@param('add_reward_embed_to_agent_token', (False, True))
|
||||
@param('add_state_pred_head', (False, True))
|
||||
@param('use_time_cache', (False, True))
|
||||
@param('use_time_kv_cache', (False, True))
|
||||
@param('var_len', (False, True))
|
||||
def test_e2e(
|
||||
pred_orig_latent,
|
||||
@ -29,8 +28,7 @@ def test_e2e(
|
||||
condition_on_actions,
|
||||
num_residual_streams,
|
||||
add_reward_embed_to_agent_token,
|
||||
add_state_pred_head,
|
||||
use_time_cache,
|
||||
use_time_kv_cache,
|
||||
var_len
|
||||
):
|
||||
from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
|
||||
@ -73,13 +71,12 @@ def test_e2e(
|
||||
pred_orig_latent = pred_orig_latent,
|
||||
num_discrete_actions = 4,
|
||||
attn_dim_head = 16,
|
||||
attn_heads = heads,
|
||||
attn_kwargs = dict(
|
||||
heads = heads,
|
||||
query_heads = query_heads,
|
||||
),
|
||||
prob_no_shortcut_train = prob_no_shortcut_train,
|
||||
add_reward_embed_to_agent_token = add_reward_embed_to_agent_token,
|
||||
add_state_pred_head = add_state_pred_head,
|
||||
num_residual_streams = num_residual_streams
|
||||
)
|
||||
|
||||
@ -126,7 +123,7 @@ def test_e2e(
|
||||
image_width = 128,
|
||||
batch_size = 2,
|
||||
return_rewards_per_frame = True,
|
||||
use_time_cache = use_time_cache
|
||||
use_time_kv_cache = use_time_kv_cache
|
||||
)
|
||||
|
||||
assert generations.video.shape == (2, 3, 10, 128, 128)
|
||||
@ -620,9 +617,9 @@ def test_cache_generate():
|
||||
num_residual_streams = 1
|
||||
)
|
||||
|
||||
generated, time_cache = dynamics.generate(1, return_time_cache = True)
|
||||
generated, time_cache = dynamics.generate(1, time_cache = time_cache, return_time_cache = True)
|
||||
generated, time_cache = dynamics.generate(1, time_cache = time_cache, return_time_cache = True)
|
||||
generated, time_kv_cache = dynamics.generate(1, return_time_kv_cache = True)
|
||||
generated, time_kv_cache = dynamics.generate(1, time_kv_cache = time_kv_cache, return_time_kv_cache = True)
|
||||
generated, time_kv_cache = dynamics.generate(1, time_kv_cache = time_kv_cache, return_time_kv_cache = True)
|
||||
|
||||
@param('vectorized', (False, True))
|
||||
@param('use_pmpo', (False, True))
|
||||
@ -646,9 +643,7 @@ def test_online_rl(
|
||||
dim_latent = 16,
|
||||
patch_size = 32,
|
||||
attn_dim_head = 16,
|
||||
num_latent_tokens = 1,
|
||||
image_height = 256,
|
||||
image_width = 256,
|
||||
num_latent_tokens = 1
|
||||
)
|
||||
|
||||
world_model_and_policy = DynamicsWorldModel(
|
||||
@ -682,12 +677,10 @@ def test_online_rl(
|
||||
|
||||
# manually
|
||||
|
||||
dream_experience = world_model_and_policy.generate(10, batch_size = 1, store_agent_embed = store_agent_embed, return_for_policy_optimization = True)
|
||||
|
||||
one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = vectorized, store_agent_embed = store_agent_embed)
|
||||
another_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized, store_agent_embed = store_agent_embed)
|
||||
|
||||
combined_experience = combine_experiences([dream_experience, one_experience, another_experience])
|
||||
combined_experience = combine_experiences([one_experience, another_experience])
|
||||
|
||||
# quick test moving the experience to different devices
|
||||
|
||||
@ -810,22 +803,3 @@ def test_epo():
|
||||
|
||||
fitness = torch.randn(16,)
|
||||
dynamics.evolve_(fitness)
|
||||
|
||||
def test_images_to_video_tokenizer():
|
||||
import torch
|
||||
from dreamer4 import VideoTokenizer, DynamicsWorldModel, AxialSpaceTimeTransformer
|
||||
|
||||
tokenizer = VideoTokenizer(
|
||||
dim = 512,
|
||||
dim_latent = 32,
|
||||
patch_size = 32,
|
||||
image_height = 256,
|
||||
image_width = 256,
|
||||
encoder_add_decor_aux_loss = True
|
||||
)
|
||||
|
||||
images = torch.randn(2, 3, 256, 256)
|
||||
loss, (losses, recon_images) = tokenizer(images, return_intermediates = True)
|
||||
loss.backward()
|
||||
|
||||
assert images.shape == recon_images.shape
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user