diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index 02db5b9..dc81bb3 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -5,7 +5,7 @@ from math import ceil, log2
 from random import random
 from contextlib import nullcontext
 from collections import namedtuple
-from functools import partial
+from functools import partial, wraps
 from dataclasses import dataclass, asdict
 
 import torch
@@ -75,6 +75,7 @@ WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'rewards', 'discrete_
 class Experience:
     latents: Tensor
     video: Tensor | None = None
+    proprio: Tensor | None = None
     rewards: Tensor | None = None
     actions: tuple[Tensor, Tensor] | None = None
     log_probs: tuple[Tensor, Tensor] | None = None
@@ -130,6 +131,9 @@ def default(v, d):
 def first(arr):
     return arr[0]
 
+def xnor(x, y):
+    return not (x ^ y)
+
 def has_at_least_one(*bools):
     return sum([*map(int, bools)]) > 0
 
@@ -1731,6 +1735,7 @@ class DynamicsWorldModel(Module):
         num_latent_tokens = None,
         num_agents = 1,
         num_tasks = 0,
+        dim_proprio = None,
         reward_encoder_kwargs: dict = dict(),
         depth = 4,
         pred_orig_latent = True, # directly predicting the original x0 data yield better results, rather than velocity (x-space vs v-space)
@@ -1816,6 +1821,19 @@ class DynamicsWorldModel(Module):
             Rearrange('b t s (n d) -> b t (s n) d', n = latent_tokens_to_space)
         )
 
+        # proprioception
+
+        self.has_proprio = exists(dim_proprio)
+        self.dim_proprio = dim_proprio
+
+        if self.has_proprio:
+            self.to_proprio_token = nn.Linear(dim_proprio, dim)
+
+            self.to_proprio_pred = Sequential(
+                RMSNorm(dim),
+                nn.Linear(dim, dim_proprio)
+            )
+
         # register tokens
 
         self.num_register_tokens = num_register_tokens
@@ -2221,7 +2239,7 @@ class DynamicsWorldModel(Module):
             entropy_loss * self.policy_entropy_weight
         )
 
-        # maye take policy optimizer step
+        # maybe take policy optimizer step
 
         if exists(policy_optim):
             total_policy_loss.backward()
@@ -2271,10 +2289,11 @@
         return_rewards_per_frame = False,
         return_agent_actions = False,
         return_log_probs_and_values = False,
-        return_time_kv_cache = False
+        return_time_kv_cache = False,
     ): # (b t n d) | (b c t h w)
 
+        has_proprio = self.has_proprio
         was_training = self.training
         self.eval()
 
@@ -2301,7 +2320,14 @@ class DynamicsWorldModel(Module):
 
         latents = empty((batch_size, 0, *latent_shape), device = self.device)
 
-        past_context_noise = latents.clone()
+        past_latents_context_noise = latents.clone()
+
+        # maybe internal state
+
+        if has_proprio:
+            proprio = empty((batch_size, 0, self.dim_proprio), device = self.device)
+
+            past_proprio_context_noise = proprio.clone()
 
         # maybe return actions
 
@@ -2327,17 +2353,35 @@ class DynamicsWorldModel(Module):
         while latents.shape[1] < time_steps:
             curr_time_steps = latents.shape[1]
+
             noised_latent = randn((batch_size, 1, *latent_shape), device = self.device)
 
+            noised_proprio = None
+
+            if has_proprio:
+                noised_proprio = randn((batch_size, 1, self.dim_proprio), device = self.device)
+
             for step in range(num_steps):
                 is_last_step = (step + 1) == num_steps
 
                 signal_levels = full((batch_size, 1), step * step_size, dtype = torch.long, device = self.device)
 
-                noised_context = latents.lerp(past_context_noise, context_signal_noise) # the paragraph after eq (8)
+                # noising past latent context
+
+                noised_context = latents.lerp(past_latents_context_noise, context_signal_noise) # the paragraph after eq (8)
 
                 noised_latent_with_context, pack_context_shape = pack((noised_context, noised_latent), 'b * n d')
 
+                # handle proprio
+
+                noised_proprio_with_context = None
+
+                if has_proprio:
+                    noised_proprio_context = proprio.lerp(past_proprio_context_noise, context_signal_noise)
+                    noised_proprio_with_context, _ = pack((noised_proprio_context, noised_proprio), 'b * d')
+
+                # proper signal levels
+
                 signal_levels_with_context = F.pad(signal_levels, (curr_time_steps, 0), value = self.max_steps - 1)
 
                 pred, (agent_embed, next_time_kv_cache) = self.forward(
@@ -2348,6 +2392,7 @@ class DynamicsWorldModel(Module):
                     tasks = tasks,
                     discrete_actions = decoded_discrete_actions,
                     continuous_actions = decoded_continuous_actions,
+                    proprio = noised_proprio_with_context,
                     time_kv_cache = time_kv_cache,
                     latent_is_noised = True,
                     return_pred_only = True,
@@ -2357,24 +2402,44 @@ class DynamicsWorldModel(Module):
                 if use_time_kv_cache and is_last_step:
                     time_kv_cache = next_time_kv_cache
 
+                # maybe proprio
+
+                if has_proprio:
+                    pred, pred_proprio = pred
+
                 # unpack pred
 
                 _, pred = unpack(pred, pack_context_shape, 'b * n d')
 
+                if has_proprio:
+                    _, pred_proprio = unpack(pred_proprio, pack_context_shape, 'b * d')
+
                 # derive flow, based on whether in x-space or not
 
-                if self.pred_orig_latent:
-                    times = self.get_times_from_signal_level(signal_levels, noised_latent)
-                    flow = (pred - noised_latent) / (1. - times)
-                else:
-                    flow = pred
+                def denoise_step(pred, noised, signal_levels):
+                    if self.pred_orig_latent:
+                        times = self.get_times_from_signal_level(signal_levels)
+                        aligned_times = align_dims_left(times, noised)
+
+                        flow = (pred - noised) / (1. - aligned_times)
+                    else:
+                        flow = pred
+
+                    return flow * (step_size / self.max_steps)
 
                 # denoise
 
-                noised_latent += flow * (step_size / self.max_steps)
+                noised_latent += denoise_step(pred, noised_latent, signal_levels)
+
+                if has_proprio:
+                    noised_proprio += denoise_step(pred_proprio, noised_proprio, signal_levels)
 
             denoised_latent = noised_latent # it is now denoised
 
+            if has_proprio:
+                denoised_proprio = noised_proprio
+
             # take care of the rewards by predicting on the agent token embedding on the last denoising step
 
             if return_rewards_per_frame:
@@ -2421,7 +2486,14 @@ class DynamicsWorldModel(Module):
 
             # add new fixed context noise for the temporal consistency
 
-            past_context_noise = cat((past_context_noise, randn_like(denoised_latent)), dim = 1)
+            past_latents_context_noise = cat((past_latents_context_noise, randn_like(denoised_latent)), dim = 1)
+
+            # handle proprio
+
+            if has_proprio:
+                proprio = cat((proprio, denoised_proprio), dim = 1)
+
+                past_proprio_context_noise = cat((past_proprio_context_noise, randn_like(denoised_proprio)), dim = 1)
 
         # restore state
 
@@ -2443,7 +2515,7 @@ class DynamicsWorldModel(Module):
 
         # only return video or latent if not requesting anything else, for first stage training
 
-        if not has_at_least_one(return_rewards_per_frame, return_agent_actions):
+        if not has_at_least_one(return_rewards_per_frame, return_agent_actions, has_proprio):
             out = video if return_decoded_video else latents
 
            if not return_time_kv_cache:
@@ -2456,6 +2528,7 @@ class DynamicsWorldModel(Module):
         gen = Experience(
             latents = latents,
             video = video,
+            proprio = proprio if has_proprio else None,
             step_size = step_size,
             agent_index = agent_index,
             is_from_world_model = True
@@ -2492,6 +2565,7 @@ class DynamicsWorldModel(Module):
         continuous_actions = None, # (b t na) | (b t-1 na)
         discrete_action_types = None, # (na)
         continuous_action_types = None, # (na)
+        proprio = None, # (b t dp)
         time_kv_cache = None,
         return_pred_only = False,
         latent_is_noised = False,
@@ -2587,16 +2661,17 @@ class DynamicsWorldModel(Module):
 
         # times is from 0 to 1
 
-        times = self.get_times_from_signal_level(signal_levels, latents)
+        times = self.get_times_from_signal_level(signal_levels)
 
         if not latent_is_noised:
             # get the noise
 
             noise = randn_like(latents)
+            aligned_times = align_dims_left(times, latents)
 
             # noise from 0 as noise to 1 as data
 
-            noised_latents = noise.lerp(latents, times)
+            noised_latents = noise.lerp(latents, aligned_times)
         else:
             noised_latents = latents
@@ -2644,6 +2719,27 @@ class DynamicsWorldModel(Module):
 
             reward_tokens = add('1 d, b t d', self.reward_learned_embed, reward_tokens)
 
+        # maybe proprioception
+
+        assert xnor(self.has_proprio, exists(proprio)), 'proprio must be passed in if `dim_proprio` is set and vice versa'
+
+        noised_proprio = None
+
+        if self.has_proprio:
+
+            if not latent_is_noised:
+                # get the noise
+
+                proprio_noise = randn_like(proprio)
+                aligned_times = align_dims_left(times, proprio)
+
+                # noise from 0 as noise to 1 as data
+
+                noised_proprio = proprio_noise.lerp(proprio, aligned_times)
+
+            else:
+                noised_proprio = proprio
+
         # maybe create the action tokens
 
         if exists(discrete_actions) or exists(continuous_actions):
@@ -2672,7 +2768,7 @@ class DynamicsWorldModel(Module):
 
         # main function, needs to be defined as such for shortcut training - additional calls for consistency loss
 
-        def get_prediction(noised_latents, signal_levels, step_sizes_log2, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = False, return_time_kv_cache = False):
+        def get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = False, return_time_kv_cache = False):
 
             # latents to spatial tokens
 
             space_tokens = self.latents_to_spatial_tokens(noised_latents)
@@ -2694,6 +2790,13 @@ class DynamicsWorldModel(Module):
 
             registers = repeat(self.register_tokens, 's d -> b t s d', b = batch, t = time)
 
+            # maybe proprio
+
+            if exists(noised_proprio):
+                proprio_token = self.to_proprio_token(noised_proprio)
+            else:
+                proprio_token = registers[:, :, 0:0]
+
             # determine signal + step size embed for their diffusion forcing + shortcut
 
             signal_embed = self.signal_levels_embed(signal_levels)
@@ -2706,7 +2809,7 @@ class DynamicsWorldModel(Module):
 
             # pack to tokens for attending
 
-            tokens, packed_tokens_shape = pack([flow_token, space_tokens, registers, action_tokens, reward_tokens, agent_tokens], 'b t * d')
+            tokens, packed_tokens_shape = pack([flow_token, space_tokens, proprio_token, registers, action_tokens, reward_tokens, agent_tokens], 'b t * d')
 
             # attention
@@ -2714,7 +2817,7 @@ class DynamicsWorldModel(Module):
 
             # unpack
 
-            flow_token, space_tokens, register_tokens, action_tokens, reward_tokens, agent_tokens = unpack(tokens, packed_tokens_shape, 'b t * d')
+            flow_token, space_tokens, proprio_token, register_tokens, action_tokens, reward_tokens, agent_tokens = unpack(tokens, packed_tokens_shape, 'b t * d')
 
             # pooling
@@ -2722,6 +2825,15 @@ class DynamicsWorldModel(Module):
 
             pred = self.to_latent_pred(space_tokens)
 
+            # maybe proprio
+
+            if self.has_proprio:
+                pred_proprio = self.to_proprio_pred(proprio_token)
+
+                pred = (pred, pred_proprio)
+
+            # returning
+
             if not return_agent_tokens:
                 return pred
@@ -2736,7 +2848,7 @@ class DynamicsWorldModel(Module):
 
         # forward the network
 
-        pred, (encoded_agent_tokens, next_time_kv_cache) = _get_prediction(noised_latents, signal_levels, step_sizes_log2, return_agent_tokens = True, return_time_kv_cache = True)
+        pred, (encoded_agent_tokens, next_time_kv_cache) = _get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, return_agent_tokens = True, return_time_kv_cache = True)
 
         if return_pred_only:
             if not return_intermediates:
@@ -2744,6 +2856,40 @@
                 return pred
 
             return pred, (encoded_agent_tokens, next_time_kv_cache)
 
+        # pack the predictions to calculate flow for different modalities all at once
+
+        if self.has_proprio:
+            pred, for_flow_loss_packed_shape = pack(pred, 'b t *')
+
+            noised, _ = pack((noised_latents, noised_proprio), 'b t *')
+            data, _ = pack((latents, proprio), 'b t *')
+            noise, _ = pack((noise, proprio_noise), 'b t *')
+        else:
+            noised = noised_latents
+            data = latents
+
+        # wrapper function for maybe unpacking and packing modalities for doing flow math in unison
+
+        def maybe_pack_unpack(fn):
+            @wraps(fn)
+            @torch.no_grad()
+            def inner(noised, *args, **kwargs):
+
+                noised_proprio = None
+
+                if self.has_proprio:
+                    noised, noised_proprio = unpack(noised, for_flow_loss_packed_shape, 'b t *')
+
+                pred = fn(noised, noised_proprio, *args, **kwargs)
+
+                if self.has_proprio:
+                    pred, _ = pack(pred, 'b t *')
+
+                return pred
+            return inner
+
+        wrapped_get_prediction = maybe_pack_unpack(_get_prediction)
+
         # determine the target for the loss
 
         pred_target = None
@@ -2759,46 +2905,45 @@ class DynamicsWorldModel(Module):
             # x-space as in paper is in else clause
 
             if is_v_space_pred:
-                pred_target = flow = latents - noise
+                pred_target = flow = data - noise
             else:
-                pred_target = latents
+                pred_target = data
         else:
             # shortcut training - Frans et al. https://arxiv.org/abs/2410.12557
             # basically a consistency loss where you ensure quantity of two half steps equals one step
             # dreamer then makes it works for x-space with some math
 
-            get_prediction_no_grad = torch.no_grad()(_get_prediction)
-
             step_sizes_log2_minus_one = step_sizes_log2 - 1 # which equals d / 2
 
             half_step_size = 2 ** step_sizes_log2_minus_one
 
-            first_step_pred = get_prediction_no_grad(noised_latents, signal_levels, step_sizes_log2_minus_one)
+            first_step_pred = wrapped_get_prediction(noised, signal_levels, step_sizes_log2_minus_one)
 
             # first derive b'
 
             if is_v_space_pred:
                 first_step_pred_flow = first_step_pred
             else:
-                first_times = self.get_times_from_signal_level(signal_levels, noised_latents)
-                first_step_pred_flow = (first_step_pred - noised_latents) / (1. - first_times)
+                first_times = self.get_times_from_signal_level(signal_levels, noised)
+
+                first_step_pred_flow = (first_step_pred - noised) / (1. - first_times)
 
             # take a half step
 
-            half_step_size_align_left = align_dims_left(half_step_size, noised_latents)
+            half_step_size_align_left = align_dims_left(half_step_size, noised)
 
-            denoised_latent = noised_latents + first_step_pred_flow * (half_step_size_align_left / self.max_steps)
+            denoised = noised + first_step_pred_flow * (half_step_size_align_left / self.max_steps)
 
             # get second prediction for b''
 
             signal_levels_plus_half_step = signal_levels + half_step_size[:, None]
 
-            second_step_pred = get_prediction_no_grad(denoised_latent, signal_levels_plus_half_step, step_sizes_log2_minus_one)
+            second_step_pred = wrapped_get_prediction(denoised, signal_levels_plus_half_step, step_sizes_log2_minus_one)
 
             if is_v_space_pred:
                 second_step_pred_flow = second_step_pred
             else:
-                second_times = self.get_times_from_signal_level(signal_levels_plus_half_step, denoised_latent)
-                second_step_pred_flow = (second_step_pred - denoised_latent) / (1. - second_times)
+                second_times = self.get_times_from_signal_level(signal_levels_plus_half_step, denoised)
+                second_step_pred_flow = (second_step_pred - denoised) / (1. - second_times)
 
             # pred target is sg(b' + b'') / 2
@@ -2807,7 +2952,7 @@ class DynamicsWorldModel(Module):
             # need to convert x-space to v-space
 
             if is_x_space:
-                pred = (pred - noised_latents) / (1. - first_times)
+                pred = (pred - noised) / (1. - first_times)
 
                 maybe_shortcut_loss_weight = (1. - first_times) ** 2
 
         # mse loss
@@ -2820,6 +2965,8 @@ class DynamicsWorldModel(Module):
         if exists(self.loss_weight_fn):
             loss_weight = self.loss_weight_fn(times)
 
+            loss_weight = align_dims_left(loss_weight, flow_losses)
+
             flow_losses = flow_losses * loss_weight
 
         flow_loss = flow_losses.mean()
diff --git a/pyproject.toml b/pyproject.toml
index 0d7b0af..a11bf77 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.69"
+version = "0.0.70"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py
index 5afe5eb..39ed0ce 100644
--- a/tests/test_dreamer.py
+++ b/tests/test_dreamer.py
@@ -2,6 +2,9 @@ import pytest
 param = pytest.mark.parametrize
 
 import torch
 
+def exists(v):
+    return v is not None
+
 @param('pred_orig_latent', (False, True))
 @param('grouped_query_attn', (False, True))
 @param('dynamics_with_video_input', (False, True))
@@ -664,3 +667,55 @@ def test_online_rl(
     )
 
     trainer(mock_env, num_episodes = 2, env_is_vectorized = vectorized)
+
+def test_proprioception():
+    from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
+
+    tokenizer = VideoTokenizer(
+        512,
+        dim_latent = 32,
+        patch_size = 32,
+        encoder_depth = 2,
+        decoder_depth = 2,
+        time_block_every = 2,
+        attn_heads = 8,
+        image_height = 256,
+        image_width = 256,
+        attn_kwargs = dict(
+            query_heads = 16
+        )
+    )
+
+    dynamics = DynamicsWorldModel(
+        512,
+        num_agents = 1,
+        video_tokenizer = tokenizer,
+        dim_latent = 32,
+        dim_proprio = 21,
+        num_tasks = 4,
+        num_discrete_actions = 4,
+        num_residual_streams = 1
+    )
+
+    video = torch.randn(2, 3, 10, 256, 256)
+    rewards = torch.randn(2, 10)
+    proprio = torch.randn(2, 10, 21)
+    discrete_actions = torch.randint(0, 4, (2, 10, 1))
+    tasks = torch.randint(0, 4, (2,))
+
+    loss = dynamics(
+        video = video,
+        rewards = rewards,
+        tasks = tasks,
+        proprio = proprio,
+        discrete_actions = discrete_actions
+    )
+
+    loss.backward()
+
+    generations = dynamics.generate(
+        4,
+        batch_size = 2,
+    )
+
+    assert exists(generations.proprio)
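
For readers skimming the patch: the core of the change is that the new proprioceptive stream is noised and denoised with the same diffusion-forcing lerp as the latents. Below is a minimal standalone sketch of that noising step, not taken from the library itself: the shapes are illustrative, `align_dims_left` is emulated inline with a reshape, and the `xnor` helper mirrors the guard added to `forward()`.

```python
# standalone sketch of the proprio noising pattern used in forward(); shapes are illustrative only
import torch

def xnor(x, y):
    # true when both or neither flag is set - mirrors the assert tying `proprio` to `dim_proprio`
    return not (x ^ y)

assert xnor(True, True) and xnor(False, False)

batch, time, dim_proprio = 2, 10, 21

proprio = torch.randn(batch, time, dim_proprio)   # clean proprioceptive data, shape (b t dp)
times = torch.rand(batch, time)                   # per-frame signal level mapped to [0, 1)

# emulate align_dims_left: append singleton dims so times broadcasts over the feature dim
aligned_times = times.reshape(*times.shape, *((1,) * (proprio.ndim - times.ndim)))

noise = torch.randn_like(proprio)

# pure noise at t = 0, clean data at t = 1, same convention as the latent stream
noised_proprio = noise.lerp(proprio, aligned_times)
```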