Compare commits
5 Commits
| Author | SHA1 | Date |
|---|---|---|
| | 5bb027b386 | |
| | 9efe269688 | |
| | fb8c3793b4 | |
| | fb6d69f43a | |
| | 125693ce1c | |
@@ -71,14 +71,20 @@ except ImportError:

 LinearNoBias = partial(Linear, bias = False)

 VideoTokenizerIntermediates = namedtuple('VideoTokenizerIntermediates', ('losses', 'recon'))

 TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips', 'time_decorr', 'space_decorr'))

-WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'rewards', 'discrete_actions', 'continuous_actions'))
+WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'rewards', 'discrete_actions', 'continuous_actions', 'state_pred'))

 AttentionIntermediates = namedtuple('AttentionIntermediates', ('next_kv_cache', 'normed_inputs'))

 TransformerIntermediates = namedtuple('TransformerIntermediates', ('next_kv_cache', 'normed_time_inputs', 'normed_space_inputs', 'next_rnn_hiddens'))

 Predictions = namedtuple('Predictions', ('flow', 'proprioception', 'state'))

 Embeds = namedtuple('Embeds', ['agent', 'state_pred'])

 MaybeTensor = Tensor | None

 @dataclass
@@ -1936,11 +1942,23 @@ class VideoTokenizer(Module):

     def forward(
         self,
-        video, # (b c t h w)
+        video_or_image, # (b c t h w) | (b c h w)
         return_latents = False,
         mask_patches = None,
-        return_all_losses = False
+        return_intermediates = False,
     ):

+        # handle image pretraining

+        is_image = video_or_image.ndim == 4

+        if is_image:
+            video = rearrange(video_or_image, 'b c h w -> b c 1 h w')
+        else:
+            video = video_or_image

         # shapes

         batch, _, time, height, width = video.shape
         patch_size, device = self.patch_size, video.device
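The image-pretraining path added in the hunk above is just a shape trick: an image batch is promoted to a one-frame video before patching, and the reconstruction is squeezed back at the end. A minimal standalone sketch of that trick (illustrative shapes only, not the library's code):

```python
import torch
from einops import rearrange

images = torch.randn(2, 3, 256, 256)                 # (b c h w)

# promote to a single-frame video so the video code path can be reused
video = rearrange(images, 'b c h w -> b c 1 h w')    # (b c t h w) with t = 1

# ... the tokenizer would encode / decode `video` here ...

# squeeze the time axis back out of the reconstruction
recon_images = rearrange(video, 'b c 1 h w -> b c h w')
assert recon_images.shape == images.shape
```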
@@ -2020,12 +2038,21 @@ class VideoTokenizer(Module):

             space_decorr_loss * self.decorr_aux_loss_weight
         )

-        if not return_all_losses:
+        if not return_intermediates:
             return total_loss

-        losses = (recon_loss, lpips_loss, decorr_loss)
+        losses = TokenizerLosses(recon_loss, lpips_loss, time_decorr_loss, space_decorr_loss)

-        return total_loss, TokenizerLosses(*losses)
+        out = losses

+        # handle returning of reconstructed, and image pretraining

+        if is_image:
+            recon_video = rearrange(recon_video, 'b c 1 h w -> b c h w')

+        out = (losses, recon_video)

+        return total_loss, out

 # dynamics model, axial space-time transformer
@@ -2058,6 +2085,9 @@ class DynamicsWorldModel(Module):

         prob_no_shortcut_train = None, # probability of no shortcut training, defaults to 1 / num_step_sizes
         add_reward_embed_to_agent_token = False,
         add_reward_embed_dropout = 0.1,
+        add_state_pred_head = False,
+        state_pred_loss_weight = 0.1,
+        state_entropy_bonus_weight = 0.05,
         num_discrete_actions: int | tuple[int, ...] = 0,
         num_continuous_actions = 0,
         continuous_norm_stats = None,
@@ -2181,6 +2211,25 @@ class DynamicsWorldModel(Module):

         self.pred_orig_latent = pred_orig_latent # x-space or v-space
         self.loss_weight_fn = loss_weight_fn

+        # state prediction, for state entropy bonus

+        self.add_state_pred_head = add_state_pred_head
+        self.state_pred_loss_weight = state_pred_loss_weight

+        self.should_pred_state = add_state_pred_head and state_pred_loss_weight > 0.

+        if self.should_pred_state:
+            self.state_pred_token = nn.Parameter(torch.randn(dim) * 1e-2)

+            self.to_state_pred = Sequential(
+                RMSNorm(dim),
+                nn.Linear(dim, num_latent_tokens * dim_latent * 2),
+                Rearrange('... (n d two) -> ... n d two', n = num_latent_tokens, two = 2)
+            )

+        self.state_entropy_bonus_weight = state_entropy_bonus_weight
+        self.add_state_entropy_bonus = self.should_pred_state and state_entropy_bonus_weight > 0.

         # reinforcement related

         # they sum all the actions into a single token
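To show the layout the new `to_state_pred` head registered in the hunk above produces, here is a rough, self-contained sketch with made-up dimensions; `nn.RMSNorm` and `nn.Sequential` stand in for the repo's own `RMSNorm` and `Sequential`. The head emits one (mean, log-variance) pair per latent dimension of every latent token:

```python
import torch
from torch import nn
from einops.layers.torch import Rearrange

dim, num_latent_tokens, dim_latent = 512, 4, 32       # illustrative values only

to_state_pred = nn.Sequential(
    nn.RMSNorm(dim),                                   # stand-in for the repo's RMSNorm
    nn.Linear(dim, num_latent_tokens * dim_latent * 2),
    Rearrange('... (n d two) -> ... n d two', n = num_latent_tokens, two = 2)
)

state_pred_token = torch.randn(2, 10, 1, dim)          # (batch, time, 1, dim)
out = to_state_pred(state_pred_token)                  # (batch, time, 1, n, d, 2)

mean, log_var = out.unbind(dim = -1)                   # each (batch, time, 1, n, d)
print(mean.shape, log_var.shape)
```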
@@ -2471,7 +2520,7 @@ class DynamicsWorldModel(Module):

         latents = self.video_tokenizer(video, return_latents = True)

-        _, (agent_embed, next_time_cache) = self.forward(
+        _, (embeds, next_time_cache) = self.forward(
             latents = latents,
             signal_levels = self.max_steps - 1,
             step_sizes = step_size,
@@ -2491,6 +2540,8 @@ class DynamicsWorldModel(Module):

         # get one agent

+        agent_embed = embeds.agent

         one_agent_embed = agent_embed[..., -1:, agent_index, :]

         # values
@@ -2545,6 +2596,17 @@ class DynamicsWorldModel(Module):

         elif len(env_step_out) == 5:
             next_frame, reward, terminated, truncated, info = env_step_out

+        # maybe add state entropy bonus

+        if self.add_state_entropy_bonus:
+            state_pred_token = embeds.state_pred

+            state_pred = self.to_state_pred(state_pred_token)

+            state_pred_log_variance = state_pred[..., 1].sum()

+            reward = reward + state_pred_log_variance * self.state_entropy_bonus_weight

         # update episode lens

         episode_lens = torch.where(is_terminated, episode_lens, episode_lens + 1)
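The bonus added in the hunk above treats the summed predicted log-variance of the next-state distribution as an exploration signal and folds it into the environment reward. A hedged sketch with assumed tensor shapes (the helper name is hypothetical):

```python
import torch

def add_state_entropy_bonus(reward, state_pred, weight = 0.05):
    # state_pred: (..., num_latent_tokens, dim_latent, 2); index 1 holds the log-variance
    state_pred_log_variance = state_pred[..., 1].sum()
    return reward + state_pred_log_variance * weight

reward = torch.tensor(1.0)
state_pred = torch.randn(1, 4, 32, 2)
print(add_state_entropy_bonus(reward, state_pred))
```

Higher predicted variance over the next state yields a larger bonus, which is the intent of a state-entropy style objective.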
@@ -2723,7 +2785,7 @@ class DynamicsWorldModel(Module):

     ):

         with world_model_forward_context():
-            _, (agent_embeds, _) = self.forward(
+            _, (embeds, _) = self.forward(
                 latents = latents,
                 signal_levels = self.max_steps - 1,
                 step_sizes = step_size,
@@ -2735,7 +2797,7 @@ class DynamicsWorldModel(Module):

                 return_intermediates = True
             )

-        agent_embeds = agent_embeds[..., agent_index, :]
+        agent_embeds = embeds.agent[..., agent_index, :]

         # maybe detach agent embed
@@ -3015,7 +3077,7 @@ class DynamicsWorldModel(Module):

         signal_levels_with_context = F.pad(signal_levels, (curr_time_steps, 0), value = self.max_steps - 1)

-        pred, (agent_embed, next_time_cache) = self.forward(
+        pred, (embeds, next_time_cache) = self.forward(
             latents = noised_latent_with_context,
             signal_levels = signal_levels_with_context,
             step_sizes = step_size,
@@ -3042,8 +3104,10 @@ class DynamicsWorldModel(Module):

-        # maybe proprio

-        if has_proprio:
-            pred, pred_proprio = pred

+        # maybe proprio

+        pred_proprio = pred.proprioception
+        pred = pred.flow

         # unpack pred
@@ -3080,6 +3144,8 @@ class DynamicsWorldModel(Module):

         # take care of the rewards by predicting on the agent token embedding on the last denoising step

         if return_rewards_per_frame:
+            agent_embed = embeds.agent

             one_agent_embed = agent_embed[:, -1:, agent_index]

             reward_logits = self.to_reward_pred.forward_one(one_agent_embed, id = 0)
@@ -3090,6 +3156,8 @@ class DynamicsWorldModel(Module):

         # maybe store agent embed

         if store_agent_embed:
+            agent_embed = embeds.agent

             one_agent_embed = agent_embed[:, -1:, agent_index]
             acc_agent_embed = safe_cat((acc_agent_embed, one_agent_embed), dim = 1)
@@ -3387,9 +3455,13 @@ class DynamicsWorldModel(Module):

         agent_tokens = repeat(agent_tokens, 'b ... d -> b t ... d', t = time)

+        # empty token

+        empty_token = agent_tokens[:, :, 0:0]

         # maybe reward tokens

-        reward_tokens = agent_tokens[:, :, 0:0]
+        reward_tokens = empty_token

         if exists(rewards):
             two_hot_encoding = self.reward_encoder(rewards)
@@ -3429,6 +3501,13 @@ class DynamicsWorldModel(Module):

         else:
             noised_proprio = proprio

+        # maybe state prediction token

+        if self.should_pred_state:
+            state_pred_token = repeat(self.state_pred_token, 'd -> b t 1 d', b = batch, t = time)
+        else:
+            state_pred_token = empty_token

         # maybe create the action tokens

         if exists(discrete_actions) or exists(continuous_actions):
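For readers unfamiliar with einops, the `repeat` call in the hunk above broadcasts a single learned `(dim,)` parameter into one extra token per (batch, time) position. A tiny sketch with placeholder sizes:

```python
import torch
from einops import repeat

dim, batch, time = 512, 2, 10                 # placeholder sizes
state_pred_token = torch.randn(dim)           # stands in for self.state_pred_token

tokens = repeat(state_pred_token, 'd -> b t 1 d', b = batch, t = time)
print(tokens.shape)                           # torch.Size([2, 10, 1, 512])
```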
@@ -3453,11 +3532,11 @@ class DynamicsWorldModel(Module):

             action_tokens = torch.zeros_like(agent_tokens[:, :, 0:1])

         else:
-            action_tokens = agent_tokens[:, :, 0:0] # else empty off agent tokens
+            action_tokens = empty_token # else empty off agent tokens

         # main function, needs to be defined as such for shortcut training - additional calls for consistency loss

-        def get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = False, return_time_cache = False):
+        def get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, state_pred_token, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = False, return_time_cache = False):

             # latents to spatial tokens
@@ -3506,7 +3585,7 @@ class DynamicsWorldModel(Module):

             # pack to tokens for attending

-            tokens, packed_tokens_shape = pack([flow_token, space_tokens, proprio_token, registers, action_tokens, reward_tokens, agent_tokens], 'b t * d')
+            tokens, packed_tokens_shape = pack([flow_token, space_tokens, proprio_token, state_pred_token, registers, action_tokens, reward_tokens, agent_tokens], 'b t * d')

             # attention
@@ -3514,7 +3593,7 @@ class DynamicsWorldModel(Module):

             # unpack

-            flow_token, space_tokens, proprio_token, register_tokens, action_tokens, reward_tokens, agent_tokens = unpack(tokens, packed_tokens_shape, 'b t * d')
+            flow_token, space_tokens, proprio_token, state_pred_token, register_tokens, action_tokens, reward_tokens, agent_tokens = unpack(tokens, packed_tokens_shape, 'b t * d')

             # pooling
@@ -3526,42 +3605,54 @@ class DynamicsWorldModel(Module):

             if self.has_proprio:
                 pred_proprio = self.to_proprio_pred(proprio_token)
             else:
                 pred_proprio = None

-            pred = (pred, pred_proprio)

+            # maybe state pred

+            if self.should_pred_state:
+                pred_state = self.to_state_pred(state_pred_token)
+            else:
+                pred_state = None

+            # returning

+            predictions = Predictions(pred, pred_proprio, pred_state)

+            embeds = Embeds(agent_tokens, state_pred_token)

             if not return_agent_tokens:
-                return pred
+                return predictions

             if not return_time_cache:
-                return pred, agent_tokens
+                return predictions, embeds

-            return pred, (agent_tokens, intermediates)
+            return predictions, (embeds, intermediates)

         # curry into get_prediction what does not change during first call as well as the shortcut ones

-        _get_prediction = partial(get_prediction, action_tokens = action_tokens, reward_tokens = reward_tokens, agent_tokens = agent_tokens)
+        _get_prediction = partial(get_prediction, state_pred_token = state_pred_token, action_tokens = action_tokens, reward_tokens = reward_tokens, agent_tokens = agent_tokens)

         # forward the network

-        pred, (encoded_agent_tokens, intermediates) = _get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, return_agent_tokens = True, return_time_cache = True)
+        pred, (embeds, intermediates) = _get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, return_agent_tokens = True, return_time_cache = True)

-        if return_pred_only:
+        if not return_intermediates:
             return pred

-        return pred, (encoded_agent_tokens, intermediates)
+        return pred, (embeds, intermediates)

         # pack the predictions to calculate flow for different modalities all at once

         if self.has_proprio:
-            pred, for_flow_loss_packed_shape = pack(pred, 'b t *')
+            packed_pred, for_flow_loss_packed_shape = pack((pred.flow, pred.proprioception), 'b t *')

             noised, _ = pack((noised_latents, noised_proprio), 'b t *')
             data, _ = pack((latents, proprio), 'b t *')
             noise, _ = pack((noise, proprio_noise), 'b t *')
         else:
+            packed_pred = pred.flow
             noised = noised_latents
             data = latents
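The hunk above replaces positional tuples with the two namedtuples defined earlier, so downstream call sites read named fields instead of indexing. A small illustration of the new containers (tensor shapes are placeholders):

```python
from collections import namedtuple
import torch

Predictions = namedtuple('Predictions', ('flow', 'proprioception', 'state'))
Embeds = namedtuple('Embeds', ['agent', 'state_pred'])

pred = Predictions(flow = torch.randn(2, 10, 4, 32), proprioception = None, state = None)
embeds = Embeds(agent = torch.randn(2, 10, 1, 512), state_pred = torch.randn(2, 10, 1, 512))

packed_pred = pred.flow        # what the flow loss consumes when there is no proprio
agent_embed = embeds.agent     # what reward prediction and the agent heads consume
print(packed_pred.shape, agent_embed.shape)
```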
@@ -3580,9 +3671,10 @@ class DynamicsWorldModel(Module):

             pred = fn(noised, noised_proprio, *args, **kwargs)

             if self.has_proprio:
-                pred, _ = pack(pred, 'b t *')
+                packed_flow, _ = pack((pred.flow, pred.proprioception), 'b t *')
+                return packed_flow

-            return pred
+            return pred.flow

         return inner

     wrapped_get_prediction = maybe_pack_unpack(_get_prediction)
@@ -3649,12 +3741,12 @@ class DynamicsWorldModel(Module):

         # need to convert x-space to v-space

         if is_x_space:
-            pred = (pred - noised) / (1. - first_times)
+            packed_pred = (packed_pred - noised) / (1. - first_times)
             maybe_shortcut_loss_weight = (1. - first_times) ** 2

         # mse loss

-        flow_losses = F.mse_loss(pred, pred_target, reduction = 'none')
+        flow_losses = F.mse_loss(packed_pred, pred_target, reduction = 'none')

         flow_losses = flow_losses * maybe_shortcut_loss_weight # handle the (1-t)^2 in eq(7)
@@ -3686,6 +3778,8 @@ class DynamicsWorldModel(Module):

         if exists(rewards):

+            encoded_agent_tokens = embeds.agent

             if rewards.ndim == 2: # (b t)
                 encoded_agent_tokens = reduce(encoded_agent_tokens, 'b t g d -> b t d', 'mean')
@@ -3706,6 +3800,18 @@ class DynamicsWorldModel(Module):

         else:
             reward_loss = reduce(reward_losses, '... mtp -> mtp', 'mean') # they sum across the prediction steps (mtp dimension) - eq(9)

+        # maybe autoregressive state prediction loss

+        state_pred_loss = self.zero

+        if self.should_pred_state:
+            pred_latent, latent_to_pred = pred.state[:, :-1], latents[:, 1:]

+            pred_latent_mean, pred_latent_log_var = pred_latent.unbind(dim = -1)
+            pred_latent_var = pred_latent_log_var.exp()

+            state_pred_loss = F.gaussian_nll_loss(pred_latent_mean, latent_to_pred, var = pred_latent_var)

         # maybe autoregressive action loss

         discrete_action_loss = self.zero
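The state prediction loss added in the hunk above is teacher-forced, next-frame prediction: the state head's output at frame t parameterizes a Gaussian over the clean latent of frame t + 1, scored with `F.gaussian_nll_loss`. A self-contained sketch with assumed shapes:

```python
import torch
import torch.nn.functional as F

latents = torch.randn(2, 10, 4, 32)        # (batch, time, num_latent_tokens, dim_latent)
state_pred = torch.randn(2, 10, 4, 32, 2)  # last dim holds (mean, log-variance)

# shift by one frame: predictions at t target the latent at t + 1
pred_latent, latent_to_pred = state_pred[:, :-1], latents[:, 1:]

pred_mean, pred_log_var = pred_latent.unbind(dim = -1)
pred_var = pred_log_var.exp()

state_pred_loss = F.gaussian_nll_loss(pred_mean, latent_to_pred, var = pred_var)
print(state_pred_loss)
```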
@@ -3773,7 +3879,7 @@ class DynamicsWorldModel(Module):

         # handle loss normalization

-        losses = WorldModelLosses(flow_loss, reward_loss, discrete_action_loss, continuous_action_loss)
+        losses = WorldModelLosses(flow_loss, reward_loss, discrete_action_loss, continuous_action_loss, state_pred_loss)

         if exists(self.flow_loss_normalizer):
             flow_loss = self.flow_loss_normalizer(flow_loss, update_ema = update_loss_ema)
@@ -3793,7 +3899,8 @@ class DynamicsWorldModel(Module):

             flow_loss * self.latent_flow_loss_weight +
             (reward_loss * self.reward_loss_weight).sum() +
             (discrete_action_loss * self.discrete_action_loss_weight).sum() +
-            (continuous_action_loss * self.continuous_action_loss_weight).sum()
+            (continuous_action_loss * self.continuous_action_loss_weight).sum() +
+            (state_pred_loss * self.state_pred_loss_weight)
         )

         if not return_all_losses:
pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.1.20"
+version = "0.1.24"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -15,6 +15,7 @@ def exists(v):

 @param('condition_on_actions', (False, True))
 @param('num_residual_streams', (1, 4))
 @param('add_reward_embed_to_agent_token', (False, True))
+@param('add_state_pred_head', (False, True))
 @param('use_time_cache', (False, True))
 @param('var_len', (False, True))
 def test_e2e(
@@ -28,6 +29,7 @@ def test_e2e(

     condition_on_actions,
     num_residual_streams,
     add_reward_embed_to_agent_token,
+    add_state_pred_head,
     use_time_cache,
     var_len
 ):
@@ -77,6 +79,7 @@ def test_e2e(

         ),
         prob_no_shortcut_train = prob_no_shortcut_train,
         add_reward_embed_to_agent_token = add_reward_embed_to_agent_token,
+        add_state_pred_head = add_state_pred_head,
         num_residual_streams = num_residual_streams
     )
@@ -807,3 +810,22 @@ def test_epo():

     fitness = torch.randn(16,)
     dynamics.evolve_(fitness)

+def test_images_to_video_tokenizer():
+    import torch
+    from dreamer4 import VideoTokenizer, DynamicsWorldModel, AxialSpaceTimeTransformer

+    tokenizer = VideoTokenizer(
+        dim = 512,
+        dim_latent = 32,
+        patch_size = 32,
+        image_height = 256,
+        image_width = 256,
+        encoder_add_decor_aux_loss = True
+    )

+    images = torch.randn(2, 3, 256, 256)
+    loss, (losses, recon_images) = tokenizer(images, return_intermediates = True)
+    loss.backward()

+    assert images.shape == recon_images.shape